|
| 1 | +import math |
| 2 | +from typing import List, Tuple |
| 3 | + |
| 4 | + |
| 5 | +def default_matrix_multiplication(a: List, b: List) -> List: |
| 6 | + """ |
| 7 | + Multiplication only for 2x2 matrices |
| 8 | + """ |
| 9 | + if len(a) != 2 or len(a[0]) != 2 or len(b) != 2 or len(b[0]) != 2: |
| 10 | + raise Exception("Matrices are not 2x2") |
| 11 | + new_matrix = [ |
| 12 | + [a[0][0] * b[0][0] + a[0][1] * b[1][0], a[0][0] * b[0][1] + a[0][1] * b[1][1]], |
| 13 | + [a[1][0] * b[0][0] + a[1][1] * b[1][0], a[1][0] * b[0][1] + a[1][1] * b[1][1]], |
| 14 | + ] |
| 15 | + return new_matrix |
| 16 | + |
| 17 | + |
| 18 | +def matrix_addition(matrix_a: List, matrix_b: List): |
| 19 | + return [ |
| 20 | + [matrix_a[row][col] + matrix_b[row][col] for col in range(len(matrix_a[row]))] |
| 21 | + for row in range(len(matrix_a)) |
| 22 | + ] |
| 23 | + |
| 24 | + |
| 25 | +def matrix_subtraction(matrix_a: List, matrix_b: List): |
| 26 | + return [ |
| 27 | + [matrix_a[row][col] - matrix_b[row][col] for col in range(len(matrix_a[row]))] |
| 28 | + for row in range(len(matrix_a)) |
| 29 | + ] |
| 30 | + |
| 31 | + |
| 32 | +def split_matrix(a: List,) -> Tuple[List, List, List, List]: |
| 33 | + """ |
| 34 | + Given an even length matrix, returns the top_left, top_right, bot_left, bot_right quadrant. |
| 35 | +
|
| 36 | + >>> split_matrix([[4,3,2,4],[2,3,1,1],[6,5,4,3],[8,4,1,6]]) |
| 37 | + ([[4, 3], [2, 3]], [[2, 4], [1, 1]], [[6, 5], [8, 4]], [[4, 3], [1, 6]]) |
| 38 | + >>> split_matrix([[4,3,2,4,4,3,2,4],[2,3,1,1,2,3,1,1],[6,5,4,3,6,5,4,3],[8,4,1,6,8,4,1,6],[4,3,2,4,4,3,2,4],[2,3,1,1,2,3,1,1],[6,5,4,3,6,5,4,3],[8,4,1,6,8,4,1,6]]) |
| 39 | + ([[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]], [[4, 3, 2, 4], [2, 3, 1, 1], [6, 5, 4, 3], [8, 4, 1, 6]]) |
| 40 | + """ |
| 41 | + if len(a) % 2 != 0 or len(a[0]) % 2 != 0: |
| 42 | + raise Exception("Odd matrices are not supported!") |
| 43 | + |
| 44 | + matrix_length = len(a) |
| 45 | + mid = matrix_length // 2 |
| 46 | + |
| 47 | + top_right = [[a[i][j] for j in range(mid, matrix_length)] for i in range(mid)] |
| 48 | + bot_right = [ |
| 49 | + [a[i][j] for j in range(mid, matrix_length)] for i in range(mid, matrix_length) |
| 50 | + ] |
| 51 | + |
| 52 | + top_left = [[a[i][j] for j in range(mid)] for i in range(mid)] |
| 53 | + bot_left = [[a[i][j] for j in range(mid)] for i in range(mid, matrix_length)] |
| 54 | + |
| 55 | + return top_left, top_right, bot_left, bot_right |
| 56 | + |
| 57 | + |
| 58 | +def matrix_dimensions(matrix: List) -> Tuple[int, int]: |
| 59 | + return len(matrix), len(matrix[0]) |
| 60 | + |
| 61 | + |
| 62 | +def print_matrix(matrix: List) -> None: |
| 63 | + for i in range(len(matrix)): |
| 64 | + print(matrix[i]) |
| 65 | + |
| 66 | + |
| 67 | +def actual_strassen(matrix_a: List, matrix_b: List) -> List: |
| 68 | + """ |
| 69 | + Recursive function to calculate the product of two matrices, using the Strassen Algorithm. |
| 70 | + It only supports even length matrices. |
| 71 | + """ |
| 72 | + if matrix_dimensions(matrix_a) == (2, 2): |
| 73 | + return default_matrix_multiplication(matrix_a, matrix_b) |
| 74 | + |
| 75 | + a, b, c, d = split_matrix(matrix_a) |
| 76 | + e, f, g, h = split_matrix(matrix_b) |
| 77 | + |
| 78 | + t1 = actual_strassen(a, matrix_subtraction(f, h)) |
| 79 | + t2 = actual_strassen(matrix_addition(a, b), h) |
| 80 | + t3 = actual_strassen(matrix_addition(c, d), e) |
| 81 | + t4 = actual_strassen(d, matrix_subtraction(g, e)) |
| 82 | + t5 = actual_strassen(matrix_addition(a, d), matrix_addition(e, h)) |
| 83 | + t6 = actual_strassen(matrix_subtraction(b, d), matrix_addition(g, h)) |
| 84 | + t7 = actual_strassen(matrix_subtraction(a, c), matrix_addition(e, f)) |
| 85 | + |
| 86 | + top_left = matrix_addition(matrix_subtraction(matrix_addition(t5, t4), t2), t6) |
| 87 | + top_right = matrix_addition(t1, t2) |
| 88 | + bot_left = matrix_addition(t3, t4) |
| 89 | + bot_right = matrix_subtraction(matrix_subtraction(matrix_addition(t1, t5), t3), t7) |
| 90 | + |
| 91 | + # construct the new matrix from our 4 quadrants |
| 92 | + new_matrix = [] |
| 93 | + for i in range(len(top_right)): |
| 94 | + new_matrix.append(top_left[i] + top_right[i]) |
| 95 | + for i in range(len(bot_right)): |
| 96 | + new_matrix.append(bot_left[i] + bot_right[i]) |
| 97 | + return new_matrix |
| 98 | + |
| 99 | + |
| 100 | +def strassen(matrix1: List, matrix2: List) -> List: |
| 101 | + """ |
| 102 | + >>> strassen([[2,1,3],[3,4,6],[1,4,2],[7,6,7]], [[4,2,3,4],[2,1,1,1],[8,6,4,2]]) |
| 103 | + [[34, 23, 19, 15], [68, 46, 37, 28], [28, 18, 15, 12], [96, 62, 55, 48]] |
| 104 | + >>> strassen([[3,7,5,6,9],[1,5,3,7,8],[1,4,4,5,7]], [[2,4],[5,2],[1,7],[5,5],[7,8]]) |
| 105 | + [[139, 163], [121, 134], [100, 121]] |
| 106 | + """ |
| 107 | + if matrix_dimensions(matrix1)[1] != matrix_dimensions(matrix2)[0]: |
| 108 | + raise Exception( |
| 109 | + f"Unable to multiply these matrices, please check the dimensions. \nMatrix A:{matrix1} \nMatrix B:{matrix2}" |
| 110 | + ) |
| 111 | + dimension1 = matrix_dimensions(matrix1) |
| 112 | + dimension2 = matrix_dimensions(matrix2) |
| 113 | + |
| 114 | + if dimension1[0] == dimension1[1] and dimension2[0] == dimension2[1]: |
| 115 | + return matrix1, matrix2 |
| 116 | + |
| 117 | + maximum = max(max(dimension1), max(dimension2)) |
| 118 | + maxim = int(math.pow(2, math.ceil(math.log2(maximum)))) |
| 119 | + new_matrix1 = matrix1 |
| 120 | + new_matrix2 = matrix2 |
| 121 | + |
| 122 | + # Adding zeros to the matrices so that the arrays dimensions are the same and also power of 2 |
| 123 | + for i in range(0, maxim): |
| 124 | + if i < dimension1[0]: |
| 125 | + for j in range(dimension1[1], maxim): |
| 126 | + new_matrix1[i].append(0) |
| 127 | + else: |
| 128 | + new_matrix1.append([0] * maxim) |
| 129 | + if i < dimension2[0]: |
| 130 | + for j in range(dimension2[1], maxim): |
| 131 | + new_matrix2[i].append(0) |
| 132 | + else: |
| 133 | + new_matrix2.append([0] * maxim) |
| 134 | + |
| 135 | + final_matrix = actual_strassen(new_matrix1, new_matrix2) |
| 136 | + |
| 137 | + # Removing the additional zeros |
| 138 | + for i in range(0, maxim): |
| 139 | + if i < dimension1[0]: |
| 140 | + for j in range(dimension2[1], maxim): |
| 141 | + final_matrix[i].pop() |
| 142 | + else: |
| 143 | + final_matrix.pop() |
| 144 | + return final_matrix |
| 145 | + |
| 146 | + |
| 147 | +if __name__ == "__main__": |
| 148 | + matrix1= [ |
| 149 | + [2, 3, 4, 5], |
| 150 | + [6, 4, 3, 1], |
| 151 | + [2, 3, 6, 7], |
| 152 | + [3, 1, 2, 4], |
| 153 | + [2, 3, 4, 5], |
| 154 | + [6, 4, 3, 1], |
| 155 | + [2, 3, 6, 7], |
| 156 | + [3, 1, 2, 4], |
| 157 | + [2, 3, 4, 5], |
| 158 | + [6, 2, 3, 1], |
| 159 | + ] |
| 160 | + matrix2 = [[0, 2, 1, 1], [16, 2, 3, 3], [2, 2, 7, 7], [13, 11, 22, 4]] |
| 161 | + print(strassen(matrix1, matrix2)) |
0 commit comments