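Matrix Multiplication Example Written in Java (Strassen's Algorithm)
Strassen's algorithm was proposed in 1969 by the German mathematician Volker Strassen. For two matrices split into 2×2 blocks, it introduces seven intermediate variables, each of which needs only one multiplication, whereas the naive algorithm needs 8 multiplications.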
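To see where the count of 8 comes from, it may help to write the product in block form. The block names A through H are an assumption chosen to match the symbols used in the sympy check in this article; each block is an N/2 × N/2 matrix.

$$
\begin{pmatrix} A & B \\ C & D \end{pmatrix}
\begin{pmatrix} E & F \\ G & H \end{pmatrix}
=
\begin{pmatrix} AE + BG & AF + BH \\ CE + DG & CF + DH \end{pmatrix}
$$

Each of the four result blocks is a sum of two block products, which gives eight block multiplications in total.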
How it works
Strassen's algorithm works as shown below. The script uses sympy to verify that the algorithm is correct.
import sympy as s

# The blocks stand for matrices, so the symbols are declared non-commutative.
A, B, C, D, E, F, G, H = s.symbols("A B C D E F G H", commutative=False)

# The seven intermediate products, one multiplication each.
p1 = A * (F - H)
p2 = (A + B) * H
p3 = (C + D) * E
p4 = D * (G - E)
p5 = (A + D) * (E + H)
p6 = (B - D) * (G + H)
p7 = (A - C) * (E + F)

# Each line prints the naive block formula, then the expanded Strassen
# combination; the two expressions on a line should be identical.
print(A * E + B * G, (p5 + p4 - p2 + p6).expand())
print(A * F + B * H, (p1 + p2).expand())
print(C * E + D * G, (p3 + p4).expand())
print(C * F + D * H, (p1 + p5 - p3 - p7).expand())
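The same seven products can also be checked numerically. The following is a minimal sketch for plain 2×2 integer matrices; the class name `Strassen2x2` and the method `multiply` are hypothetical names introduced only for this illustration.

```java
class Strassen2x2 {
    // Multiplies two 2x2 integer matrices using seven multiplications.
    // Layout assumption: x = [[A, B], [C, D]] and y = [[E, F], [G, H]].
    static int[][] multiply(int[][] x, int[][] y) {
        int a = x[0][0], b = x[0][1], c = x[1][0], d = x[1][1];
        int e = y[0][0], f = y[0][1], g = y[1][0], h = y[1][1];

        // The seven intermediate products, one multiplication each.
        int p1 = a * (f - h);
        int p2 = (a + b) * h;
        int p3 = (c + d) * e;
        int p4 = d * (g - e);
        int p5 = (a + d) * (e + h);
        int p6 = (b - d) * (g + h);
        int p7 = (a - c) * (e + f);

        // Combine the products with additions and subtractions only.
        return new int[][] {
            { p5 + p4 - p2 + p6, p1 + p2 },
            { p3 + p4, p1 + p5 - p3 - p7 }
        };
    }

    public static void main(String[] args) {
        int[][] r = multiply(new int[][] {{1, 3}, {5, 7}},
                             new int[][] {{8, 4}, {6, 2}});
        System.out.println(java.util.Arrays.deepToString(r)); // [[26, 10], [82, 34]]
    }
}
```

The inputs are the same two matrices used in the `main` method of the Java implementation in this article, so the results can be compared directly.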
Complexity analysis
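Let $f(N)$ be the number of scalar multiplications needed for an $N\times N$ matrix, with $N$ a power of two. Each level of recursion replaces one product by seven products of half the size:
$$f(N)=7\times f(\frac{N}{2})=7^2\times f(\frac{N}{4})=\cdots=7^k\times f(\frac{N}{2^k})$$
Taking $k=\log_2 N$ and $f(1)=1$, the final complexity is $7^{\log_2 N}=N^{\log_2 7}\approx N^{2.81}$. The extra additions and subtractions cost $O(N^2)$ per level and do not change this exponent.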
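To make the recurrence concrete, the sketch below counts scalar multiplications for a few sizes and compares them with the $N^3$ multiplications of the naive triple loop. The class and method names are hypothetical, and the count assumes $N$ is a power of two and ignores additions.

```java
class MultiplicationCount {
    // Scalar multiplications used by the recursion f(n) = 7 * f(n / 2), f(1) = 1.
    // Assumes n is a power of two.
    static long strassen(int n) {
        return n == 1 ? 1L : 7L * strassen(n / 2);
    }

    // Scalar multiplications used by the naive triple loop: n^3.
    static long naive(int n) {
        return (long) n * n * n;
    }

    public static void main(String[] args) {
        // Sizes 1, 4, 16, 64, 256, 1024.
        for (int n = 1; n <= 1024; n *= 4) {
            System.out.println(n + ": strassen=" + strassen(n) + ", naive=" + naive(n));
        }
    }
}
```

For $N=4$ this gives 49 multiplications against 64; the gap widens as $N$ grows, although the counts say nothing about the cost of the additional additions or of memory allocation.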
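Java matrix multiplication (Strassen's algorithm)
The code is below. Note how the data structure is defined: each matrix recursively stores its four sub-blocks, trading time for space.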
public class Matrix {
    // Four sub-blocks in row-major order: [0] top-left, [1] top-right,
    // [2] bottom-left, [3] bottom-right. null when n == 1.
    private final Matrix[] _matrixArray;
    private final int n;
    // The single entry; only used when n == 1.
    private int element;

    // Creates an n x n zero matrix. n must be a power of two.
    public Matrix(int n) {
        if (n < 1 || (n & (n - 1)) != 0) {
            throw new IllegalArgumentException("n must be a power of two");
        }
        this.n = n;
        if (n != 1) {
            this._matrixArray = new Matrix[4];
            for (int i = 0; i < 4; i++) {
                this._matrixArray[i] = new Matrix(n / 2);
            }
        } else {
            this._matrixArray = null;
        }
    }

    // Creates an n x n matrix whose four blocks are left unset so the caller
    // can assign them directly. needInit is ignored.
    private Matrix(int n, boolean needInit) {
        this.n = n;
        this._matrixArray = (n != 1) ? new Matrix[4] : null;
    }

    public void set(int i, int j, int a) {
        if (n == 1) {
            element = a;
        } else {
            int size = n / 2;
            // Pick the quadrant, then recurse with coordinates local to it.
            this._matrixArray[(i / size) * 2 + (j / size)].set(i % size, j % size, a);
        }
    }

    public int get(int i, int j) {
        if (n == 1) {
            return element;
        } else {
            int size = n / 2;
            return this._matrixArray[(i / size) * 2 + (j / size)].get(i % size, j % size);
        }
    }

    public Matrix multi(Matrix m) {
        Matrix result;
        if (n == 1) {
            result = new Matrix(1);
            result.set(0, 0, element * m.element);
        } else {
            // Compute each of the seven products exactly once.
            Matrix p1 = P1(m);
            Matrix p2 = P2(m);
            Matrix p3 = P3(m);
            Matrix p4 = P4(m);
            Matrix p5 = P5(m);
            Matrix p6 = P6(m);
            Matrix p7 = P7(m);
            result = new Matrix(n, false);
            result._matrixArray[0] = p5.add(p4).minus(p2).add(p6);
            result._matrixArray[1] = p1.add(p2);
            result._matrixArray[2] = p3.add(p4);
            result._matrixArray[3] = p5.add(p1).minus(p3).minus(p7);
        }
        return result;
    }

    public Matrix add(Matrix m) {
        Matrix result;
        if (n == 1) {
            result = new Matrix(1);
            result.set(0, 0, element + m.element);
        } else {
            result = new Matrix(n, false);
            for (int i = 0; i < 4; i++) {
                result._matrixArray[i] = this._matrixArray[i].add(m._matrixArray[i]);
            }
        }
        return result;
    }

    public Matrix minus(Matrix m) {
        Matrix result;
        if (n == 1) {
            result = new Matrix(1);
            result.set(0, 0, element - m.element);
        } else {
            result = new Matrix(n, false);
            for (int i = 0; i < 4; i++) {
                result._matrixArray[i] = this._matrixArray[i].minus(m._matrixArray[i]);
            }
        }
        return result;
    }

    // With this = [[A, B], [C, D]] and m = [[E, F], [G, H]], each method
    // below performs a single recursive multiplication.

    // P1 = A * (F - H)
    protected Matrix P1(Matrix m) {
        return _matrixArray[0].multi(m._matrixArray[1].minus(m._matrixArray[3]));
    }

    // P2 = (A + B) * H
    protected Matrix P2(Matrix m) {
        return _matrixArray[0].add(_matrixArray[1]).multi(m._matrixArray[3]);
    }

    // P3 = (C + D) * E
    protected Matrix P3(Matrix m) {
        return _matrixArray[2].add(_matrixArray[3]).multi(m._matrixArray[0]);
    }

    // P4 = D * (G - E)
    protected Matrix P4(Matrix m) {
        return _matrixArray[3].multi(m._matrixArray[2].minus(m._matrixArray[0]));
    }

    // P5 = (A + D) * (E + H)
    protected Matrix P5(Matrix m) {
        return _matrixArray[0].add(_matrixArray[3]).multi(m._matrixArray[0].add(m._matrixArray[3]));
    }

    // P6 = (B - D) * (G + H)
    protected Matrix P6(Matrix m) {
        return _matrixArray[1].minus(_matrixArray[3]).multi(m._matrixArray[2].add(m._matrixArray[3]));
    }

    // P7 = (A - C) * (E + F)
    protected Matrix P7(Matrix m) {
        return _matrixArray[0].minus(_matrixArray[2]).multi(m._matrixArray[0].add(m._matrixArray[1]));
    }

    public void display() {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print(get(i, j));
                System.out.print(" ");
            }
            System.out.println();
        }
    }

    public static void main(String[] args) {
        Matrix m = new Matrix(2);
        Matrix n = new Matrix(2);
        m.set(0, 0, 1);
        m.set(0, 1, 3);
        m.set(1, 0, 5);
        m.set(1, 1, 7);
        n.set(0, 0, 8);
        n.set(0, 1, 4);
        n.set(1, 0, 6);
        n.set(1, 1, 2);
        // Prints the product, rows "26 10" and "82 34".
        Matrix res = m.multi(n);
        res.display();
    }
}
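The `set` and `get` methods above route an index pair to one of the four child blocks with the expression `(i / size) * 2 + (j / size)`. The sketch below isolates that expression so the mapping is easier to see; the class `QuadrantIndex` and the method `quadrant` are hypothetical helpers, not part of the `Matrix` class.

```java
class QuadrantIndex {
    // Index of the child block that holds entry (i, j) of an n x n matrix:
    // 0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right.
    // Assumes n is a power of two and n >= 2.
    static int quadrant(int i, int j, int n) {
        int size = n / 2;
        return (i / size) * 2 + (j / size);
    }

    public static void main(String[] args) {
        int n = 4;
        // Print the quadrant index of every cell of a 4 x 4 matrix.
        for (int i = 0; i < n; i++) {
            StringBuilder row = new StringBuilder();
            for (int j = 0; j < n; j++) {
                row.append(quadrant(i, j, n)).append(' ');
            }
            System.out.println(row.toString().trim());
        }
    }
}
```

Inside the chosen block, the coordinates become `i % size` and `j % size`, so every lookup walks about $\log_2 N$ levels of the tree. That per-access cost is one reason this layout is slower than a flat array.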