lua-users home
lua-l archive

Re: LuaJIT2 performance for number crunching

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


Hi,
I've modified the implementation of the rkf45 ode integrator in
vectorial form, now it is slightly simpler and I was hoping faster
because I dont pass a pointer to a member instance of an ffi
structure.
The results are still accurate but the program it is slow like before
or similar.
I add in attachment the two files. Please note that I've added a line
to load with ffi the 'libgslcblas-0' library. This should work on
windows with gsl-shell or plain luajit provided that you have
installed the library (included with gsl library).
I hope to get some help because I'm stuck for the moment...
Francesco
local abs, max, min = math.abs, math.max, math.min
local ffi = require "ffi"
local vecsize = 2 * ffi.sizeof('double')
cblas = ffi.load('libgslcblas-0')
ffi.cdef[[
 typedef struct {
 double t;
 double h;
 double y[2];
 double dydt[2];
 } odevec_state;
 void cblas_daxpy (const int N, const double ALPHA,
		 const double * X, const int INCX,
		 double * Y, const int INCY);
]]
local function ode_new()
 return ffi.new('odevec_state')
end
local function ode_init(s, t0, h0, f, y)
 ffi.copy(s.y, y, vecsize)
 f(t0, s.y, s.dydt)
 s.t = t0
 s.h = h0
end
local function hadjust(rmax, h)
 local S = 0.9
 if rmax > 1.1 then
 local r = S / rmax^(1/5)
 r = max(0.2, r)
 return r * h, -1
 elseif rmax < 0.5 then
 local r = S / rmax^(1/(5+1))
 r = max(1, min(r, 5))
 return r * h, 1
 end
 return h, 0
end
local ws_y = ffi.new('double[2]')
local ws_k1 = ffi.new('double[2]')
local ws_k2 = ffi.new('double[2]')
local ws_k3 = ffi.new('double[2]')
local ws_k4 = ffi.new('double[2]')
local ws_k5 = ffi.new('double[2]')
local ws_k6 = ffi.new('double[2]')
local function rkf45_evolve(s, f, t1)
 local t, h = s.t, s.h
 local hadj, inc
 ffi.copy (ws_k1, s.dydt, vecsize)
 if t + h > t1 then h = t1 - t end
 while h > 0 do
 ffi.copy (ws_y, s.y, vecsize)
 local rmax = 0
 do
	 cblas.cblas_daxpy (2, h * 0.25, ws_k1, 1, ws_y, 1)
	 -- k2 step
	 f(t + 0.25 * h, ws_y, ws_k2)
	 ffi.copy (ws_y, s.y, vecsize)
	 cblas.cblas_daxpy (2, h * 0.09375, ws_k1, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * 0.28125, ws_k2, 1, ws_y, 1)
	 -- k3 step
	 f(t + 0.375 * h, ws_y, ws_k3)
	 ffi.copy (ws_y, s.y, vecsize)
	 cblas.cblas_daxpy (2, h * 0.87938097405553, ws_k1, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * -3.2771961766045, ws_k2, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * 3.3208921256259, ws_k3, 1, ws_y, 1)
	 -- k4 step
	 f(t + 0.92307692307692 * h, ws_y, ws_k4)
	 ffi.copy (ws_y, s.y, vecsize)
	 cblas.cblas_daxpy (2, h * 2.0324074074074, ws_k1, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * -8, ws_k2, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * 7.1734892787524, ws_k3, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * -0.20589668615984, ws_k4, 1, ws_y, 1)
	 -- k5 step
	 f(t + 1 * h, ws_y, ws_k5)
	 ffi.copy (ws_y, s.y, vecsize)
	 cblas.cblas_daxpy (2, h * -0.2962962962963, ws_k1, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * 2, ws_k2, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * -1.3816764132554, ws_k3, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * 0.45297270955166, ws_k4, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * -0.275, ws_k5, 1, ws_y, 1)
	 -- k6 step and final sum
	 -- since k2 is no more used we could use k2 to store k6
	 f(t + 0.5 * h, ws_y, ws_k6)
	 ffi.copy (ws_y, s.y, vecsize)
	 cblas.cblas_daxpy (2, h * 0.11851851851852, ws_k1, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * 0.51898635477583, ws_k3, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * 0.50613149034202, ws_k4, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * -0.18, ws_k5, 1, ws_y, 1)
	 cblas.cblas_daxpy (2, h * 0.036363636363636, ws_k6, 1, ws_y, 1)
 
 local yerr, r, d0
 yerr = h * (0.0027777777777778 * ws_k1[0] + -0.029941520467836 * ws_k3[0] + -0.029199893673578 * ws_k4[0] + 0.02 * ws_k5[0] + 0.036363636363636 * ws_k6[0])
 d0 = 0 * (1 * abs(ws_y[0])) + 1e-006
 r = abs(yerr) / abs(d0)
 rmax = max(r, rmax)
 yerr = h * (0.0027777777777778 * ws_k1[1] + -0.029941520467836 * ws_k3[1] + -0.029199893673578 * ws_k4[1] + 0.02 * ws_k5[1] + 0.036363636363636 * ws_k6[1])
 d0 = 0 * (1 * abs(ws_y[1])) + 1e-006
 r = abs(yerr) / abs(d0)
 rmax = max(r, rmax)
 end
 hadj, inc = hadjust(rmax, h)
 if inc >= 0 then break end
 h = hadj
 end
 f(t + h, ws_y, ws_k2)
 ffi.copy (s.dydt, ws_k2, vecsize)
 ffi.copy (s.y, ws_y, vecsize)
 s.t = t + h
 s.h = hadj
 return h
end
return {new= ode_new, init= ode_init, evolve= rkf45_evolve}

Attachment: rkf45vec.lua.in
Description: Binary data


AltStyle によって変換されたページ (->オリジナル) /