Brief introduction to CSR:
The compressed sparse row (CSR) or compressed row storage (CRS) format represents a matrix M by three (one-dimensional) arrays, that respectively contain nonzero values, the extents of rows, and column indices. It is similar to COO, but compresses the row indices, hence the name. This format allows fast row access and matrix-vector multiplications (Mx). The CSR format has been in use since at least the mid-1960s, with the first complete description appearing in 1967.
The CSR format stores a sparse \$m \times n\$ matrix \$M\$ in row form using three (one-dimensional) arrays (\$A\$, \$IA\$, \$JA\$). Let \$NNZ\$ denote the number of nonzero entries in \$M\$. (Note that zero-based indices are used here.)
- The array \$A\$ is of length \$NNZ\$ and holds all the nonzero entries of \$M\$ in left-to-right top-to-bottom ("row-major") order.
- The array \$IA\$ is of length \$m + 1\$. It is defined by this recursive definition:
  - \$IA[0] = 0\$
  - \$IA[i] = IA[i − 1]\$ + (number of nonzero elements on the (\$i − 1\$)th row in the original matrix)
  - Thus, the first \$m\$ elements of \$IA\$ store the index into \$A\$ of the first nonzero element in each row of \$M\$, and the last element \$IA[m]\$ stores \$NNZ\$, the number of elements in \$A\$, which can also be thought of as the index in \$A\$ of the first element of a phantom row just beyond the end of the matrix \$M\$. The values of the \$i\$-th row of the original matrix are read from the elements \$A[IA[i]]\$ to \$A[IA[i + 1] − 1]\$ (inclusive on both ends), i.e. from the start of one row to the last index just before the start of the next.
- The third array, \$JA\$, contains the column index in \$M\$ of each element of \$A\$ and hence is of length \$NNZ\$ as well.
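To make the row-access rule concrete, here is a minimal sketch in plain Python of reading one row and of a matrix-vector product over the three arrays. The function names `csr_row` and `csr_matvec` and the small 2 × 3 test matrix are illustrative assumptions, not part of the description above.

```python
def csr_row(A, IA, JA, i):
    """Return the (column, value) pairs stored for row i."""
    start, end = IA[i], IA[i + 1]  # entries of row i live in A[start:end]
    return list(zip(JA[start:end], A[start:end]))


def csr_matvec(A, IA, JA, x):
    """Compute y = M x for a CSR matrix with len(IA) - 1 rows."""
    num_rows = len(IA) - 1
    y = [0] * num_rows
    for i in range(num_rows):
        # Only the stored (nonzero) entries of row i contribute.
        for k in range(IA[i], IA[i + 1]):
            y[i] += A[k] * x[JA[k]]
    return y


# Hypothetical 2 x 3 matrix [[1, 0, 2], [0, 0, 3]] in CSR form.
A, IA, JA = [1, 2, 3], [0, 2, 3], [0, 2, 2]
print(csr_row(A, IA, JA, 0))                  # [(0, 1), (2, 2)]
print(csr_matvec(A, IA, JA, [1, 10, 100]))    # [201, 300]
```

The product touches each stored entry exactly once, which is why the format suits row access and Mx.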
For example, the matrix:
\$ \left (\begin{matrix} 0 & 0 & 0 & 0 \\ 5 & 8 & 0 & 0 \\ 0 & 0 & 3 & 0 \\ 0 & 6 & 0 & 0 \\ \end{matrix} \right)\$
is a \$4 \times 4\$ matrix with 4 nonzero elements, hence:
- \$A = [ 5 8 3 6 ]\$
- \$IA = [ 0 0 2 3 4 ]\$
- \$JA = [ 0 1 2 1 ]\$
So, in array \$JA\$, the element "5" from \$A\$ has column index 0, "8" and "6" have index 1, and element "3" has index 2.
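The example can be checked mechanically. The following is a small sketch, assuming the matrix is given as a list of row lists; the helper name `dense_to_csr` is made up for illustration.

```python
def dense_to_csr(M):
    """Build the CSR arrays (A, IA, JA) from a dense list-of-rows matrix."""
    A, IA, JA = [], [0], []
    for row in M:
        for col, entry in enumerate(row):
            if entry != 0:
                A.append(entry)   # nonzero value, in row-major order
                JA.append(col)    # its column index
        IA.append(len(A))         # running count of nonzeros so far
    return A, IA, JA


M = [
    [0, 0, 0, 0],
    [5, 8, 0, 0],
    [0, 0, 3, 0],
    [0, 6, 0, 0],
]
A, IA, JA = dense_to_csr(M)
print(A)   # [5, 8, 3, 6]
print(IA)  # [0, 0, 2, 3, 4]
print(JA)  # [0, 1, 2, 1]
```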
My implementation:
class CSRImpl:
    def __init__(self, numRows, numCols):
        self.value = []
        self.IA = [0] * (numRows + 1)
        self.JA = []
        self.numRows = numRows
        self.numCols = numCols
    def get(self, x, y):
        previous_row_values_count = self.IA[x]
        current_row_valid_count = self.IA[x+1]
        for i in range(previous_row_values_count, current_row_valid_count):
            if self.JA[i] == y:
                return self.value[i]
        else:
            # for-else: no stored entry in row x has column y, so it is zero
            return 0.0
    def set(self, x, y, v):
        for i in range(x+1, self.numRows+1):
            self.IA[i] += 1
        previous_row_values_count = self.IA[x]
        inserted = False
        for j in range(previous_row_values_count, self.IA[x+1]-1):
            if self.JA[j] > y:
                self.JA.insert(j, y)
                self.value.insert(j, v)
                inserted = True
                break
            elif self.JA[j] == y:
                inserted = True
                self.value[j] = v
                break
        if not inserted:
            self.JA.insert(self.IA[x+1]-1,y)
            self.value.insert(self.IA[x+1]-1, v)
    def iterate(self):
        result = [] # a list of triple (row, col, value)
        for i,v in enumerate(self.IA):
            if i == 0:
                continue
            current_row_index = 0
            while current_row_index < v-self.IA[i-1]:
                row_value = i - 1
                col_value = self.JA[self.IA[i-1] + current_row_index]
                real_value = self.value[self.IA[i-1] + current_row_index]
                result.append((row_value, col_value, real_value))
                current_row_index += 1
        return result
    def debug_info(self):
        print('value ', self.value)
        print('IA ', self.IA)
        print('JA ', self.JA)
if __name__ == "__main__":
    matrix = CSRImpl(4,4)
    matrix.set(1,0,5)
    matrix.set(1,1,8)
    matrix.set(2,2,3)
    matrix.set(3,1,6)
    matrix.debug_info()
    print(matrix.iterate())
Output:
value [5, 8, 3, 6]
IA [0, 0, 2, 3, 4]
JA [0, 1, 2, 1]
[(1, 0, 5), (1, 1, 8), (2, 2, 3), (3, 1, 6)]
1 Answer
The first thing to change is the name. `CSRMatrix` is more descriptive to people who don't know exactly what it is, and just as useful for people who do. I'm also going to assume that the lack of docstrings and newlines is only for the code review. If you were going to publish this code, both would be good. You should change `debug_info` to `__repr__` and make it return the result, `set` should be `__setitem__(self, coord, v)`, and `get` should be `__getitem__(self, coord)`. This will make everything feel much more Pythonic to use.
With regard to performance, your current code seems pretty optimal. It would probably be a good idea to try making `self.IA` an `np.array`, as its size is fixed and it stores only `int`s. This will be slower for small numbers of items, but should be faster eventually. Here is an unfinished set of edits for these.
import numpy as np

class CSRMatrix:
    def __init__(self, numRows, numCols):
        self.value = []
        # np.int was removed from recent NumPy releases; the builtin int works as a dtype
        self.IA = np.zeros(numRows + 1, dtype=int)  # was: [0] * (numRows + 1)
        self.JA = []
        self.numRows = numRows
        self.numCols = numCols
    def __getitem__(self, coord):
        x, y = coord
        previous_row_values_count = self.IA[x]
        current_row_valid_count = self.IA[x+1]
        for i in range(previous_row_values_count, current_row_valid_count):
            if self.JA[i] == y:
                return self.value[i]
        else:
            # for-else: no stored entry in row x has column y, so it is zero
            return 0.0
    def __setitem__(self, coord, v):
        x, y = coord
        self.IA[x+1: self.numRows+1] += 1
        previous_row_values_count = self.IA[x]
        inserted = False
        for j in range(previous_row_values_count, self.IA[x+1]-1):
            if self.JA[j] > y:
                self.JA.insert(j, y)
                self.value.insert(j, v)
                inserted = True
                break
            elif self.JA[j] == y:
                inserted = True
                self.value[j] = v
                break
        if not inserted:
            self.JA.insert(self.IA[x+1]-1,y)
            self.value.insert(self.IA[x+1]-1, v)
    def iterate(self):
        result = [] # a list of triple (row, col, value)
        for i,v in enumerate(self.IA):
            if i == 0:
                continue
            current_row_index = 0
            while current_row_index < v-self.IA[i-1]:
                row_value = i - 1
                col_value = self.JA[self.IA[i-1] + current_row_index]
                real_value = self.value[self.IA[i-1] + current_row_index]
                result.append((row_value, col_value, real_value))
                current_row_index += 1
        return result
    def __repr__(self):
        return ('value '+ str(self.value) +
                '\nIA ' + str(self.IA) +
                '\nJA '+ str(self.JA))
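One thing worth checking in both versions of the setter: the row offsets in `IA` are incremented before the row is searched, so assigning to a coordinate that already holds a value appears to bump the counts without adding an entry. Below is a minimal sketch of one way to avoid that, by locating the slot first and only then updating `IA`. The class name `CSRSketch`, the snake_case attribute names, and the use of `bisect` are assumptions made for illustration, not part of the code above.

```python
import bisect


class CSRSketch:
    """Hypothetical trimmed-down CSR container showing only the setter."""

    def __init__(self, num_rows, num_cols):
        self.value = []
        self.JA = []
        self.IA = [0] * (num_rows + 1)
        self.num_rows = num_rows
        self.num_cols = num_cols

    def __setitem__(self, coord, v):
        x, y = coord
        start, end = self.IA[x], self.IA[x + 1]
        # Column indices are kept sorted within each row, so a binary
        # search restricted to [start, end) finds the slot for column y.
        j = bisect.bisect_left(self.JA, y, start, end)
        if j < end and self.JA[j] == y:
            self.value[j] = v  # overwrite: the row counts stay unchanged
            return
        self.JA.insert(j, y)
        self.value.insert(j, v)
        # A new entry was added, so shift the start of every later row.
        for i in range(x + 1, self.num_rows + 1):
            self.IA[i] += 1


m = CSRSketch(4, 4)
m[1, 0] = 5
m[1, 1] = 8
m[2, 2] = 3
m[3, 1] = 6
m[1, 0] = 7       # overwrites the 5 instead of adding a fifth entry
print(m.value)    # [7, 8, 3, 6]
print(m.IA)       # [0, 0, 2, 3, 4]
print(m.JA)       # [0, 1, 2, 1]
```

The list inserts are still linear in the number of stored entries, so this changes correctness on overwrite rather than the overall cost of a write.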
- Also, `iterate` really should be a generator. I might add that tomorrow. (Oscar Smith, Oct 28, 2017 at 7:21)
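A generator version of `iterate` could look roughly like the sketch below. It is written as a free function over the three arrays so that it runs on its own; the name `iter_entries` is an assumption, and inside the class the same loop would read `self.value`, `self.IA` and `self.JA` instead.

```python
def iter_entries(value, IA, JA):
    """Lazily yield (row, col, value) triples in row-major order."""
    for row in range(len(IA) - 1):
        # Stored entries of this row occupy positions IA[row] .. IA[row + 1] - 1.
        for k in range(IA[row], IA[row + 1]):
            yield row, JA[k], value[k]


# Arrays from the example matrix in the question.
entries = iter_entries([5, 8, 3, 6], [0, 0, 2, 3, 4], [0, 1, 2, 1])
print(list(entries))  # [(1, 0, 5), (1, 1, 8), (2, 2, 3), (3, 1, 6)]
```

Callers that need a list can still wrap the call in `list(...)`, while loops over large matrices avoid building the full result up front.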