MUYANG GUO / INDEX

LeetCode

LintCode 654 Sparse Matrix Multiplication - Medium

654. Sparse Matrix Multiplication

· 1 min read · #LintCode #Medium #Python

654. Sparse Matrix Multiplication — Medium

Open on LintCode

Problem

654. Sparse Matrix Multiplication

Given two sparse matrices A and B, return the result of AB.

You may assume that the number of columns of A equals the number of rows of B.

Example

Example1

Input: A = [[1,0,0],[-1,0,3]], B = [[7,0,0],[0,0,0],[0,0,1]]

Output: [[7,0,0],[-7,0,3]]

Explanation: A = [ [ 1, 0, 0], [-1, 0, 3] ]

B = [ [ 7, 0, 0 ], [ 0, 0, 0 ], [ 0, 0, 1 ] ]

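$$
AB =
\begin{pmatrix} 1 & 0 & 0 \\ -1 & 0 & 3 \end{pmatrix}
\times
\begin{pmatrix} 7 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 1 \end{pmatrix}
=
\begin{pmatrix} 7 & 0 & 0 \\ -7 & 0 & 3 \end{pmatrix}
$$

Example2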

Input: A = [[1,0],[0,1]], B = [[0,1],[1,0]]

Output: [[0,1],[1,0]]

Solution

class Solution:
    """
    Multiply two sparse matrices using compressed row/column vectors.

    Each row of A and each column of B is reduced to a list of
    (index, value) pairs that keeps only the non-zero entries, ordered by
    index. Each output cell is the dot product of one compressed row of A
    and one compressed column of B, computed with a two-pointer merge.
    """

    def multiply(self, A, B):
        """
        @param A: a sparse matrix (list of rows); its column count must equal
            B's row count
        @param B: a sparse matrix (list of rows)
        @return: the result of A * B as a list of rows. An empty B yields no
            columns (one empty list per row of A) instead of raising.
        """
        row_vectors = self.convert_to_row_vectors(A)
        col_vectors = self.convert_to_col_vectors(B)
        return [
            [
                self.multi_vector(row_vector, col_vector)
                for col_vector in col_vectors
            ]
            for row_vector in row_vectors
        ]

    def convert_to_row_vectors(self, matrix):
        """
        Compress each row into its non-zero (column_index, value) pairs.

        @param matrix: a list of rows
        @return: one list per row, each ordered by ascending column index
        """
        return [
            [(index, value) for index, value in enumerate(row) if value != 0]
            for row in matrix
        ]

    def convert_to_col_vectors(self, matrix):
        """
        Compress each column into its non-zero (row_index, value) pairs.

        @param matrix: a list of rows; the column count is taken from the
            first row, so every row is expected to be at least that long
        @return: one list per column, each ordered by ascending row index;
            [] for an empty matrix
        """
        # An empty matrix has no columns; without this guard matrix[0]
        # would raise IndexError.
        if not matrix:
            return []
        return [
            [(i, row[j]) for i, row in enumerate(matrix) if row[j] != 0]
            for j in range(len(matrix[0]))
        ]

    def multi_vector(self, v1, v2):
        """
        Dot product of two compressed vectors via a two-pointer merge.

        @param v1: (index, value) pairs sorted by ascending index
        @param v2: (index, value) pairs sorted by ascending index
        @return: the sum of v1 value * v2 value over indices present in both
        """
        i, j = 0, 0
        result = 0

        while i < len(v1) and j < len(v2):
            if v1[i][0] < v2[j][0]:
                # v1's index is behind; it cannot match anything further in v2.
                i += 1
            elif v1[i][0] > v2[j][0]:
                j += 1
            else:
                # Both vectors are non-zero at this index.
                result += v1[i][1] * v2[j][1]
                i += 1
                j += 1

        return result

Comments