I recently wrote some Python code for matrix exponentiation.
Here's the code:
from typing import List

Matrix = List[List[int]]
MOD = 10 ** 9 + 7


def identity(n: int) -> Matrix:
    matrix = [[0] * n for _ in range(n)]
    for i in range(n):
        matrix[i][i] = 1
    return matrix


def multiply(mat1: Matrix, mat2: Matrix, copy: Matrix) -> None:
    r1, r2 = len(mat1), len(mat2)
    c1, c2 = len(mat1[0]), len(mat2[0])
    result = [[0] * c2 for _ in range(r1)]
    for i in range(r1):
        for j in range(c2):
            for k in range(r2):
                result[i][j] = (result[i][j] + mat1[i][k] * mat2[k][j]) % MOD
    for i in range(r1):
        for j in range(c2):
            copy[i][j] = result[i][j]


def power(mat: Matrix, n: int) -> Matrix:
    res = identity(len(mat))
    while n:
        if n & 1:
            multiply(res, mat, res)
        multiply(mat, mat, mat)
        n >>= 1
    return res


def fib(n: int) -> int:
    if n == 0:
        return 0
    magic = [[1, 1],
             [1, 0]]
    mat = power(magic, n - 1)
    return mat[0][0]


if __name__ == '__main__':
    print(fib(10 ** 18))
I would mainly like to improve this code performance-wise, but any comments on style are also much appreciated.
Thanks!
- Nothing wrong with your code, but this isn't actually the matrix exponential, it's a matrix power. en.wikipedia.org/wiki/Matrix_exponential – Bijan, Mar 1, 2020 at 15:29
- @Bijan Matrix exponentiation is completely different from Matrix exponential. Matrix exponentiation is a popular topic in Competitive Programming. – Sriv, Mar 1, 2020 at 15:44
1 Answer
Using the numpy module for numerical computations is often better as the code is generally simpler and much faster. Here is a version adapted to use numpy:
from typing import List
import numpy as np

Matrix = np.matrix
MOD = 10 ** 9 + 7


def power(mat: Matrix, n: int) -> Matrix:
    res = np.identity(len(mat), dtype=np.int64)
    while n:
        if n & 1:
            np.matmul(res, mat, out=res)
            res %= MOD
        np.matmul(mat, mat, out=mat)
        mat %= MOD  # Required for numpy if you want correct results
        n >>= 1
    return res


def fib(n: int) -> int:
    if n == 0:
        return 0
    magic = np.matrix([[1, 1], [1, 0]], dtype=np.int64)
    mat = power(magic, n - 1)
    return mat[0,0]


if __name__ == '__main__':
    print(fib(10 ** 18))
As you can see, using numpy significantly reduces the size of the code, since it provides primitives for creating and computing with matrices (e.g. matrix multiplication).
Numpy is also usually faster since it uses fast native code to perform the computations. However, here, your matrices are too small for numpy to provide any speed-up.
Also note that numpy has a downside: it does not support arbitrarily large integers as Python does. So the code above works only as long as every intermediate sum of products fits in a 64-bit integer, which means you cannot increase the value of MOD (or the matrix size) too much. You can use dtype=object to force numpy to store Python integers, but it will be slower (especially on bigger matrices).
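To make the integer-size caveat concrete, here is a minimal sketch, not a definitive implementation. The names max_safe_dim, power_object and fib_object are hypothetical helpers introduced only for illustration. The first one estimates how many products of reduced entries can be summed before an int64 could overflow; the other two show the dtype=object fallback. np.dot is used rather than np.matmul on the assumption that it accepts object arrays on a wider range of NumPy versions.

```python
import numpy as np

MOD = 10 ** 9 + 7  # same modulus as in the answer
INT64_MAX = 2 ** 63 - 1


def max_safe_dim(mod: int) -> int:
    """Largest k such that a sum of k products of values in [0, mod - 1]
    still fits in a signed 64-bit integer (assumes mod >= 2)."""
    return INT64_MAX // ((mod - 1) ** 2)


def power_object(mat, n: int):
    """Compute mat**n % MOD using Python integers stored in an object array."""
    base = np.array(mat, dtype=object)  # copy, so the argument is not modified
    res = np.identity(len(base), dtype=object)
    while n:
        if n & 1:
            res = np.dot(res, base) % MOD
        base = np.dot(base, base) % MOD
        n >>= 1
    return res


def fib_object(n: int) -> int:
    """Hypothetical object-dtype counterpart of fib() above."""
    if n == 0:
        return 0
    return int(power_object([[1, 1], [1, 0]], n - 1)[0, 0])
```

With MOD = 10 ** 9 + 7, max_safe_dim(MOD) evaluates to 9, so under this estimate the int64 version is only safe for matrices up to 9x9. The object-dtype version has no such limit, at the cost of doing every multiplication on Python integers.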
Besides using numpy, you can also specialize your code to compute 2x2 matrices much faster in this specific case. Here is the result:
from typing import List

# A 2x2 matrix stored as a flat, row-major list: [a00, a01, a10, a11]
Matrix = List[int]
MOD = 10 ** 9 + 7


def identity_2x2() -> Matrix:
    return [1, 0, 0, 1]


def multiply_2x2(mat1: Matrix, mat2: Matrix, copy: Matrix) -> None:
    # Unpack both operands first so that `copy` may alias mat1 or mat2
    a00, a01, a10, a11 = mat1
    b00, b01, b10, b11 = mat2
    copy[0] = (a00 * b00 + a01 * b10) % MOD
    copy[1] = (a00 * b01 + a01 * b11) % MOD
    copy[2] = (a10 * b00 + a11 * b10) % MOD
    copy[3] = (a10 * b01 + a11 * b11) % MOD


def power_2x2(mat: Matrix, n: int) -> Matrix:
    res = identity_2x2()
    while n:
        if n & 1:
            multiply_2x2(res, mat, res)
        multiply_2x2(mat, mat, mat)
        n >>= 1
    return res


def fib(n: int) -> int:
    if n == 0:
        return 0
    magic = [1, 1, 1, 0]
    mat = power_2x2(magic, n - 1)
    return mat[0]


if __name__ == '__main__':
    print(fib(10 ** 18))
This is faster because the default interpreter (CPython) executes loops very slowly, so it is better to avoid them as much as possible. It is also faster because no additional list is created.
Please note that if you want your code to run faster, you could use the PyPy interpreter.
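If the code has to stay general (square matrices of any size) and in pure Python, one possible sketch is shown below. It keeps the names from the question but changes multiply to return a new matrix instead of writing into a copy argument, which is my own design choice rather than something from the original code. It replaces the explicit index loops with zip and sum, and it no longer modifies the matrix passed to power. I have not benchmarked it, so treat any speed difference as an assumption to verify.

```python
from typing import List

Matrix = List[List[int]]
MOD = 10 ** 9 + 7  # same modulus as in the question


def identity(n: int) -> Matrix:
    """Return the n x n identity matrix."""
    return [[int(i == j) for j in range(n)] for i in range(n)]


def multiply(a: Matrix, b: Matrix) -> Matrix:
    """Return (a * b) % MOD as a new matrix; the inputs are left untouched."""
    b_cols = list(zip(*b))  # transpose once so each column is a tuple
    return [
        [sum(x * y for x, y in zip(row, col)) % MOD for col in b_cols]
        for row in a
    ]


def power(mat: Matrix, n: int) -> Matrix:
    """Raise a square matrix to the n-th power (n >= 0) by repeated squaring."""
    res = identity(len(mat))
    base = mat  # only the local name is rebound, the caller's list is not modified
    while n:
        if n & 1:
            res = multiply(res, base)
        base = multiply(base, base)
        n >>= 1
    return res


def fib(n: int) -> int:
    if n == 0:
        return 0
    return power([[1, 1], [1, 0]], n - 1)[0][0]
```

For example, power([[1, 1], [1, 0]], 5) gives [[8, 5], [5, 3]], and the list passed in is still [[1, 1], [1, 0]] afterwards.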
- Great answer! But why does the first code output 0? Also, I'm using this as a template for matrix exponentiation problems in general, so the code has to be flexible! – Sriv, Mar 1, 2020 at 14:50
- It does not return 0 on my machine, nor on an online Python IDE. Maybe you are running on a 32-bit system where np.int64 is not supported? Try testing with smaller inputs to see if this is due to an integer overflow. – Jérôme Richard, Mar 1, 2020 at 15:23
- Yes, you're right. I have a 32-bit system and the code does work for smaller values. Is there any way I can make it work on a 32-bit system? Thanks! – Sriv, Mar 1, 2020 at 15:39
- I just discovered that you can use dtype=object as a fallback solution to solve this, but note that it could be much slower (although probably not more than if it was done in pure Python). – Jérôme Richard, Mar 1, 2020 at 18:14
- It shows TypeError: Object arrays are not currently supported when I run it. Full error: hastebin.com/inuxureyoz.txt – Sriv, Mar 1, 2020 at 18:21