This is an algorithm regarding the RKF method:
import numpy as np
class rkf():
def __init__(self,f, a, b, x0, atol, rtol, hmax, hmin):
self.f=f
self.a=a
self.b=b
self.x0=x0
self.atol=atol
self.rtol=rtol
self.hmax=hmax
self.hmin=hmin
def solve(self):
a2 = 2.500000000000000e-01 # 1/4
a3 = 3.750000000000000e-01 # 3/8
a4 = 9.230769230769231e-01 # 12/13
a5 = 1.000000000000000e+00 # 1
a6 = 5.000000000000000e-01 # 1/2
b21 = 2.500000000000000e-01 # 1/4
b31 = 9.375000000000000e-02 # 3/32
b32 = 2.812500000000000e-01 # 9/32
b41 = 8.793809740555303e-01 # 1932/2197
b42 = -3.277196176604461e+00 # -7200/2197
b43 = 3.320892125625853e+00 # 7296/2197
b51 = 2.032407407407407e+00 # 439/216
b52 = -8.000000000000000e+00 # -8
b53 = 7.173489278752436e+00 # 3680/513
b54 = -2.058966861598441e-01 # -845/4104
b61 = -2.962962962962963e-01 # -8/27
b62 = 2.000000000000000e+00 # 2
b63 = -1.381676413255361e+00 # -3544/2565
b64 = 4.529727095516569e-01 # 1859/4104
b65 = -2.750000000000000e-01 # -11/40
r1 = 2.777777777777778e-03 # 1/360
r3 = -2.994152046783626e-02 # -128/4275
r4 = -2.919989367357789e-02 # -2197/75240
r5 = 2.000000000000000e-02 # 1/50
r6 = 3.636363636363636e-02 # 2/55
c1 = 1.157407407407407e-01 # 25/216
c3 = 5.489278752436647e-01 # 1408/2565
c4 = 5.353313840155945e-01 # 2197/4104
c5 = -2.000000000000000e-01 # -1/5
t = self.a
x = np.array(self.x0)
h = self.hmax
T = np.array( [t] )
X = np.array( [x] )
while t < self.b:
if t + h > self.b:
h = self.b - t
k1 = h * self.f(t, x)
k2 = h * self.f(t + a2 * h, x + b21 * k1 )
k3 = h * self.f(t + a3 * h, x + b31 * k1 + b32 * k2)
k4 = h * self.f(t + a4 * h, x + b41 * k1 + b42 * k2 + b43 * k3)
k5 = h * self.f(t + a5 * h, x + b51 * k1 + b52 * k2 + b53 * k3 + b54 * k4)
k6 = h * self.f(t + a6 * h, x + b61 * k1 + b62 * k2 + b63 * k3 + b64 * k4 + b65 * k5)
r = abs( r1 * k1 + r3 * k3 + r4 * k4 + r5 * k5 + r6 * k6 ) / h
r = r / (self.atol+self.rtol*(abs(x)+abs(k1)))
if len( np.shape( r ) ) > 0:
r = max( r )
if r <= 1:
t = t + h
x = x + c1 * k1 + c3 * k3 + c4 * k4 + c5 * k5
T = np.append( T, t )
X = np.append( X, [x], 0 )
h = h * min( max( 0.94 * ( 1 / r )**0.25, 0.1 ), 4.0 )
if h > self.hmax:
h = self.hmax
elif h < self.hmin or t==t-h:
raise RuntimeError("Error: Could not converge to the required tolerance.")
break
return (T,X)
Which works just fine, but I was wondering if is it possible to make this even faster and more efficient?
3 Answers 3
If you're already using Numpy and you find that you are motivated to do loop unrolling in an attempt to make things fast, it's time to switch to C and use lower-level vectorized libraries
Your class does not deserve to be a class, and should just be a function
You should add type hints
There is really no reason to pre-compute your fractions as you have. This makes so marginal a speed difference, at a cost of so worse a legibility and maintainability, that it isn't worth it compared to other efforts like switching language
k
,A
,R
andC
are obviously vectors, andB
is obviously a triangular matrix. Best to actually represent them as such.Since
T
andX
are being frequently reallocated, there's no advantage to using numpy - just use Python listsYour calculation for
k
is actually a series of dot-products, and so it's best to just call intonp.dot
You're not using in-place operators where you should, i.e.
t = t + h
should just bet += h
This condition:
if t + h > self.b: h = self.b - t
is more legible as
if h > b - t:
h = b - t
When doing all of the above, I experience a marginal slowdown of 4.2 us in exchange for greater legibility and maintainability, and centralized constants.
Alternate implementation
from functools import partial
from timeit import timeit
from typing import Callable, Tuple, Sequence
import numpy as np
class rkf_old():
def __init__(self, f, a, b, x0, atol, rtol, hmax, hmin):
self.f = f
self.a = a
self.b = b
self.x0 = x0
self.atol = atol
self.rtol = rtol
self.hmax = hmax
self.hmin = hmin
def solve(self):
a2 = 2.500000000000000e-01 # 1/4
a3 = 3.750000000000000e-01 # 3/8
a4 = 9.230769230769231e-01 # 12/13
a5 = 1.000000000000000e+00 # 1
a6 = 5.000000000000000e-01 # 1/2
b21 = 2.500000000000000e-01 # 1/4
b31 = 9.375000000000000e-02 # 3/32
b32 = 2.812500000000000e-01 # 9/32
b41 = 8.793809740555303e-01 # 1932/2197
b42 = -3.277196176604461e+00 # -7200/2197
b43 = 3.320892125625853e+00 # 7296/2197
b51 = 2.032407407407407e+00 # 439/216
b52 = -8.000000000000000e+00 # -8
b53 = 7.173489278752436e+00 # 3680/513
b54 = -2.058966861598441e-01 # -845/4104
b61 = -2.962962962962963e-01 # -8/27
b62 = 2.000000000000000e+00 # 2
b63 = -1.381676413255361e+00 # -3544/2565
b64 = 4.529727095516569e-01 # 1859/4104
b65 = -2.750000000000000e-01 # -11/40
r1 = 2.777777777777778e-03 # 1/360
r3 = -2.994152046783626e-02 # -128/4275
r4 = -2.919989367357789e-02 # -2197/75240
r5 = 2.000000000000000e-02 # 1/50
r6 = 3.636363636363636e-02 # 2/55
c1 = 1.157407407407407e-01 # 25/216
c3 = 5.489278752436647e-01 # 1408/2565
c4 = 5.353313840155945e-01 # 2197/4104
c5 = -2.000000000000000e-01 # -1/5
t = self.a
x = np.array(self.x0)
h = self.hmax
T = np.array([t])
X = np.array([x])
while t < self.b:
if t + h > self.b:
h = self.b - t
k1 = h * self.f(t, x)
k2 = h * self.f(t + a2 * h, x + b21 * k1)
k3 = h * self.f(t + a3 * h, x + b31 * k1 + b32 * k2)
k4 = h * self.f(t + a4 * h, x + b41 * k1 + b42 * k2 + b43 * k3)
k5 = h * self.f(t + a5 * h, x + b51 * k1 + b52 * k2 + b53 * k3 + b54 * k4)
k6 = h * self.f(t + a6 * h, x + b61 * k1 + b62 * k2 + b63 * k3 + b64 * k4 + b65 * k5)
r = abs(r1 * k1 + r3 * k3 + r4 * k4 + r5 * k5 + r6 * k6) / h
r = r / (self.atol + self.rtol * (abs(x) + abs(k1)))
if len(np.shape(r)) > 0:
r = max(r)
if r <= 1:
t = t + h
x = x + c1 * k1 + c3 * k3 + c4 * k4 + c5 * k5
T = np.append(T, t)
X = np.append(X, [x], 0)
h = h * min(max(0.94 * (1 / r) ** 0.25, 0.1), 4.0)
if h > self.hmax:
h = self.hmax
elif h < self.hmin or t == t - h:
raise RuntimeError("Error: Could not converge to the required tolerance.")
break
return (T, X)
def rkf(
f: Callable[[float, float], float],
a: float, b: float, x0: float,
atol: float, rtol: float,
hmax: float, hmin: float,
) -> Tuple[
Sequence[float], Sequence[float],
]:
A = np.array((0, 1/4, 3/8, 12/13, 1, 1/2))
B = np.array((
( 0, 0, 0, 0, 0, 0),
( 1/4, 0, 0, 0, 0, 0),
( 3/32, 9/32, 0, 0, 0, 0),
(1932/2197, -7200/2197, 7296/2197, 0, 0, 0),
( 439/216, -8, 3680/513, -845/4104, 0, 0),
( -8/27, 2, -3544/2565, 1859/4104, -11/40, 0),
))
R = np.array((1/360, 0, -128/4275, -2197/75240, 1/50, 2/55))
C = np.array((25/216, 0, 1408/2565, 2197/4104, -1/5))
k = np.empty((6,))
t = a
x = x0
h = hmax
T = [t]
X = [x0]
while t < b:
if h > b - t:
h = b - t
Ta = A*h + t
for i, ta in enumerate(Ta):
k[i] = h * f(ta, x + np.dot(
B[i, :i],
k[:i],
))
r = np.abs(np.dot(R, k)) / h
r /= atol + rtol * (np.abs(x) + np.abs(k[0]))
if len(np.shape(r)) > 0:
r = max(r)
if r <= 1:
t += h
x += np.dot(C, k[:5])
T.append(t)
X.append(x)
h *= min(max(0.94 * (1 / r) ** 0.25, 0.1), 4.0)
if h > hmax:
h = hmax
elif h < hmin or t == t - h:
raise ValueError("Error: Could not converge to the required tolerance.")
return T, X
def test_fun(t: float, k: float) -> float:
return 3*t - 2*k + 1/(t**2 + k**2)
def main():
args = dict(f=test_fun, a=-3, b=11, x0=-1, atol=1e-3, rtol=-3, hmax=100, hmin=-100)
old = rkf_old(**args).solve
new = partial(rkf, **args)
for method in (old, new):
t, x = method()
print(t)
print(x)
N = 20_000
print(f'{timeit(method, number=N)/N*1e6:.1f} us')
main()
This outputs
[-3 11]
[-1.00000000e+00 -6.00218231e+05]
53.9 us
[-3, 11]
[-1, -600218.2310934969]
58.1 us
-
1\$\begingroup\$ I did suspect those coefficients to be vectors and matrix, but I was stumbling on how to implement it, and your answer is exactly what I was looking for. Many thanks. \$\endgroup\$Amirhossein Rezaei– Amirhossein Rezaei2021年03月30日 01:50:18 +00:00Commented Mar 30, 2021 at 1:50
All these lines are really weird:
a4 = 9.230769230769231e-01 # 12/13
Unless you have a good reason (which I'd then state in the code as a comment) to do that, just write a4 = 12/13
instead.
Gonna be the same anyway:
>>> import dis
>>> dis.dis('a4 = 12/13')
1 0 LOAD_CONST 0 (0.9230769230769231)
2 STORE_NAME 0 (a4)
4 LOAD_CONST 1 (None)
6 RETURN_VALUE
>>> dis.dis('a4 = 9.230769230769231e-01')
1 0 LOAD_CONST 0 (0.9230769230769231)
2 STORE_NAME 0 (a4)
4 LOAD_CONST 1 (None)
6 RETURN_VALUE
This line for example is not right:
b51 = 2.032407407407407e+00 # 439/216
The values differ slightly, your value being less accurate:
>>> 2.032407407407407e+00
2.032407407407407
>>> 439/216
2.0324074074074074
-
\$\begingroup\$ I'm not sure your if suggestion would make the code more efficient. Do you have any thoughts on the main algorithm? (By main algorithm, I mean what happens in the while loop) \$\endgroup\$Amirhossein Rezaei– Amirhossein Rezaei2021年03月27日 22:05:22 +00:00Commented Mar 27, 2021 at 22:05
-
1\$\begingroup\$ Since I showed that it's gonna be the same (except for the minor value differences), you should be sure that it does not make it more efficient :-). No other thoughts, looks too complicated for me right now and I'm a newb at numpy. \$\endgroup\$Manuel– Manuel2021年03月28日 00:29:50 +00:00Commented Mar 28, 2021 at 0:29
A significant improvement is to use lists and Python's built in append and convert the final list to array, instead of using np.append. I've run a test to demonstrate the performance enhancement:
def lorenz(t,u):
s=10
r=24
b=8/3
x,y,z=u
vx=s*y-s*x
vy=r*x-x*z-y
vz=x*y-b*z
return np.array([vx,vy,vz])
x0=[2,2,2]
t, u = rkf( f=lorenz, a=0, b=1e+3, x0=x0, atol=1e-8, rtol=1e-6 , hmax=1e-1, hmin=1e-40,show_info=True).solve()
Now, when using numpy arrays and np.append I get:
Execution time: 56.7198397 seconds
Number of data points: 120732
Using list and Python's append:
Execution time: 8.3110496 seconds
Number of data points: 120732
Which is a huge difference on the performance. Also another slight improvement is to use sqrt(sqrt()) instead of **0.25 :
h = h * min( max( 0.94 * sqrt(sqrt( 1 / r )), 0.1 ), 4.0 )
Feel free to add your thoughts and suggestions.
-
2\$\begingroup\$ I suspect that the performance boost of using the built-in isn't due to stock Python being faster, but rather misuse of numpy. Rather than appends, having a fixed-size array and doing slice assignments to it should help. \$\endgroup\$Reinderien– Reinderien2021年03月29日 13:54:34 +00:00Commented Mar 29, 2021 at 13:54
Explore related questions
See similar questions with these tags.
class rkf():
-->class RKF:
. \$\endgroup\$