lua-users home
lua-l archive

Re: LuaJIT2 performance for number crunching

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


On 16 February 2011 11:15, steve donovan <steve.j.donovan@gmail.com> wrote:
> On Wed, Feb 16, 2011 at 1:04 PM, T T <t34www@googlemail.com> wrote:
>> of code I would like to write or work with.  One would hope that such
>> a (common?) case as unrolling small loops could be done automatically
>> for the programmer, rather than require writing special templating
>> solutions.
>
> One issue is that we do not have this:
>
> const n = 3
>
> With that kind of guarantee, unrolling becomes possible.
OK, I went out of my way and templated Francesco's original code in a
simple way to see if I could do any better; the code is attached
(rk4-unroll4.lua). Basically, I unrolled the loops up to 4 levels
like this:
 if dim > 4 then
 for i = 1,dim do
 y0[i] = y[i]
 end
 else
 if dim >= 1 then
 y0[1] = y[1]
 end
 if dim >= 2 then
 y0[2] = y[2]
 end
 if dim >= 3 then
 y0[3] = y[3]
 end
 if dim >= 4 then
 y0[4] = y[4]
 end
 end
Note that this works for any dimension (thus, it is not specialized
and doesn't require run-time generation).
Timings with luajit -Omaxsnap=300:
 dim=2 time=0.25 sec
 dim=3 time=0.29 sec
 dim=4 time=0.41 sec
 dim=5 time=1.02 sec
 dim=6 time=1.23 sec
Hey, now that's not too bad, is it? Mike Pall's code posted earlier
(rk_2d.lua) runs in 0.23 sec on my machine. That's darn close to what
I've got for dim=2.
> One approach to this is to use an integrated preprocessor that uses
> the token-filter patch. That would make our 'n' to be a macro which is
> expanded as '3' wherever used.
Why not let the machine do it based on the runtime value of 'n' (with
some safeguards to fall back on the interpreter if it changes)?
> But the situation that Francesco is dealing with is template
> specialization, as a C++ programmer would understand it.  That
> requires re-compiling different cases.
I'm not convinced that this is the best way forward for stuff like simple loops.
Cheers,
Tomek
-- Storage backend for the numeric work vectors used below:
--   'FFI'       : LuaJIT FFI variable-length double arrays (0-based)
--   'GSL'       : calls the global `new(n, 1)`, which this file does not
--                 define (NOTE(review): presumably supplied by the host
--                 environment, e.g. GSL Shell -- confirm)
--   anything else (the default 'Lua'): plain zero-filled Lua tables
use = 'Lua'

if use == 'FFI' then
  ffi = require 'ffi'
  darray = ffi.typeof("double[?]")
elseif use == 'GSL' then
  darray = function(n)
    return new(n, 1)
  end
else
  -- Table with slots 1..n, every slot initialised to 0.
  darray = function(n)
    local zeros = {}
    for slot = 1, n do
      zeros[slot] = 0
    end
    return zeros
  end
end
--- Fixed-step 4th-order Runge-Kutta integrator with a step-doubling error
-- estimate (the structure mirrors GSL's rk4 stepper).
-- All vectors are indexed 1..dim.
rk4 = {}

-- Vector helpers used by rk4.step / rk4.apply.
-- Each one keeps a hand-unrolled path for dim <= 4 (the code path this
-- benchmark measures) and falls back to a plain loop for larger dim.

-- dst[i] = src[i] for i = 1..dim
local function vcopy(dst, src, dim)
  if dim > 4 then
    for i = 1, dim do dst[i] = src[i] end
  else
    if dim >= 1 then dst[1] = src[1] end
    if dim >= 2 then dst[2] = src[2] end
    if dim >= 3 then dst[3] = src[3] end
    if dim >= 4 then dst[4] = src[4] end
  end
end

-- y[i] = y[i] + a * k[i] for i = 1..dim
local function vaxpy(y, a, k, dim)
  if dim > 4 then
    for i = 1, dim do y[i] = y[i] + a * k[i] end
  else
    if dim >= 1 then y[1] = y[1] + a * k[1] end
    if dim >= 2 then y[2] = y[2] + a * k[2] end
    if dim >= 3 then y[3] = y[3] + a * k[3] end
    if dim >= 4 then y[4] = y[4] + a * k[4] end
  end
end

-- One RK stage update for i = 1..dim:
--   y[i]    = y[i]  + wa * k[i]   (accumulate the weighted slope into y)
--   ytmp[i] = y0[i] + wb * k[i]   (point where the next slope is evaluated)
local function vstage(y, wa, ytmp, y0, wb, k, dim)
  if dim > 4 then
    for i = 1, dim do
      y[i] = y[i] + wa * k[i]; ytmp[i] = y0[i] + wb * k[i]
    end
  else
    if dim >= 1 then y[1] = y[1] + wa * k[1]; ytmp[1] = y0[1] + wb * k[1] end
    if dim >= 2 then y[2] = y[2] + wa * k[2]; ytmp[2] = y0[2] + wb * k[2] end
    if dim >= 3 then y[3] = y[3] + wa * k[3]; ytmp[3] = y0[3] + wb * k[3] end
    if dim >= 4 then y[4] = y[4] + wa * k[4]; ytmp[4] = y0[4] + wb * k[4] end
  end
end

-- yerr[i] = 4 * (y[i] - y1[i]) / 15 for i = 1..dim
local function verr(yerr, y, y1, dim)
  if dim > 4 then
    for i = 1, dim do yerr[i] = 4 * (y[i] - y1[i]) / 15 end
  else
    if dim >= 1 then yerr[1] = 4 * (y[1] - y1[1]) / 15 end
    if dim >= 2 then yerr[2] = 4 * (y[2] - y1[2]) / 15 end
    if dim >= 3 then yerr[3] = 4 * (y[3] - y1[3]) / 15 end
    if dim >= 4 then yerr[4] = 4 * (y[4] - y1[4]) / 15 end
  end
end

--- Allocate the work vectors for an n-dimensional integrator.
-- Every vector gets n+1 slots so that indices 1..n are valid for each
-- `darray` backend (the FFI backend's arrays are 0-based).
-- @param n state dimension
-- @return state table with fields k, k1, y0, ytmp, y_onestep and dim
function rk4.new(n)
  return {
    k = darray(n + 1),
    k1 = darray(n + 1),
    y0 = darray(n + 1),
    ytmp = darray(n + 1),
    y_onestep = darray(n + 1),
    dim = n,
  }
end

--- Advance y in place by one classical RK4 step of size h.
-- Expects, as rk4.apply arranges, that on entry y equals state.y0 and that
-- state.k holds f(t, y0). On return state.k holds the last slope evaluated
-- (k4) and state.ytmp the k4 evaluation point.
-- @param y     state vector, updated in place
-- @param state table from rk4.new
-- @param h     step size
-- @param t     time at the start of the step
-- @param sys   table whose field f(t, y, dydt) writes dy/dt into dydt
function rk4.step(y, state, h, t, sys)
  local dim, f = state.dim, sys.f
  local y0, ytmp, k = state.y0, state.ytmp, state.k

  -- k1 stage (k already holds k1 = f(t, y0)).
  vstage(y, h / 6, ytmp, y0, 0.5 * h, k, dim)
  -- k2 stage: slope at the midpoint y0 + h/2 * k1.
  f(t + 0.5 * h, ytmp, k)
  vstage(y, h / 3, ytmp, y0, 0.5 * h, k, dim)
  -- k3 stage: slope at the midpoint y0 + h/2 * k2. The k4 evaluation point
  -- is a FULL step y0 + h * k3 (a half step here degrades the method to
  -- first order).
  f(t + 0.5 * h, ytmp, k)
  vstage(y, h / 3, ytmp, y0, h, k, dim)
  -- k4 stage: slope at the end point.
  f(t + h, ytmp, k)
  vaxpy(y, h / 6, k, dim)
end

--- Advance y by h and estimate the local error by step doubling.
-- The step is computed twice: once as a single step of size h (kept in
-- state.y_onestep) and once as two steps of size h/2 (written to y).
-- On return state.k1 holds y as it was on entry (not read again by this
-- file; kept so a caller can restore y after a rejected step).
-- @param state    table from rk4.new
-- @param t        time at the start of the step
-- @param h        step size
-- @param y        state vector, updated in place with the two-half-step result
-- @param yerr     receives the signed estimate 4 * (y_twosteps - y_onestep) / 15
-- @param dydt_in  optional precomputed f(t, y); evaluated here when nil
-- @param dydt_out optional; receives f(t + h, y) at the new y
-- @param sys      table whose field f(t, y, dydt) writes dy/dt into dydt
function rk4.apply(state, t, h, y, yerr, dydt_in, dydt_out, sys)
  local f, dim = sys.f, state.dim
  local k, k1, y0, y_onestep = state.k, state.k1, state.y0, state.y_onestep

  vcopy(y0, y, dim)
  if dydt_in then
    vcopy(k, dydt_in, dim)
  else
    f(t, y0, k)
  end

  -- Save the first-point derivatives; both step sequences start from them.
  vcopy(k1, k, dim)

  -- One step of size h, result in y_onestep.
  vcopy(y_onestep, y, dim)
  rk4.step(y_onestep, state, h, t, sys)

  -- Two steps of size h/2, result in y.
  vcopy(k, k1, dim)
  rk4.step(y, state, h/2, t, sys)
  f(t + h/2, y, k)       -- derivatives at the midpoint for the second half
  vcopy(k1, y0, dim)     -- keep the original y (see doc comment above)
  vcopy(y0, y, dim)      -- second half step starts from the midpoint
  rk4.step(y, state, h/2, t + h/2, sys)

  if dydt_out then f(t + h, y, dydt_out) end

  -- Error estimation:
  --   yerr = C * 0.5 * (y(twosteps) - y(onestep)) / (2^order - 1)
  -- with order = 4 and C ~ 8 (chosen so that about 90% of samples lie within
  -- the error, assuming a gaussian distribution with prior p(sigma)=1/sigma),
  -- i.e. 4/15 of the difference. The value is signed, not an absolute value.
  verr(yerr, y, y_onestep, dim)
end
--- Right-hand side of the 2-D test system
--   dp/dt = -q - p^2,   dq/dt = 2p - q^3,   with p = y[1], q = y[2].
-- Writes the derivatives into dydt[1] and dydt[2]; t is not used.
function f_ode1(t, y, dydt)
  local p = y[1]
  local q = y[2]
  local dp = -q - p^2
  local dq = 2*p - q^3
  dydt[1], dydt[2] = dp, dq
end
-- Integration interval [t0, t1] and fixed step size h0 used by do_rk.
t0, t1, h0 = 0, 200, 0.001

--- Integrate f_ode1 from t0 to t1 with fixed step h0 and print the final state.
-- @param p0     initial value of y[1]
-- @param q0     initial value of y[2]
-- @param sample print (t, y[1], y[2]) whenever more than `sample` time units
--               have passed since the last print; nil/false disables this
-- @param dim    state dimension for rk4.new; must be an integer >= 2 because
--               f_ode1 reads and writes components 1 and 2 (with the FFI
--               backend a smaller dim would index past the array end)
function do_rk(p0, q0, sample, dim)
  if type(dim) ~= 'number' or dim < 2 or dim % 1 ~= 0 then
    error("do_rk: dim must be an integer >= 2, got " .. tostring(dim), 2)
  end
  if not (h0 > 0) then
    error("do_rk: step size h0 must be positive, got " .. tostring(h0), 2)
  end
  local state = rk4.new(dim)
  local y, dydt, yerr = darray(dim + 1), darray(dim + 1), darray(dim + 1)
  local sys = { f = f_ode1 }
  y[1], y[2] = p0, q0

  -- Step until t reaches t1 (at least one step). Count steps and derive t
  -- from the step index instead of accumulating t = t + h0, so rounding drift
  -- cannot add or drop a step. The small tolerance keeps an interval that is
  -- a multiple of h0 up to rounding (e.g. 200 / 0.001) from gaining a step.
  local nsteps = math.max(1, math.ceil((t1 - t0) / h0 - 1e-6))

  -- First step: no derivative available yet, so apply evaluates f itself.
  rk4.apply(state, t0, h0, y, yerr, nil, dydt, sys)
  local t, tsamp = t0 + h0, t0
  for step = 2, nsteps do
    -- dydt holds f at the current (t, y) from the previous call's dydt_out.
    rk4.apply(state, t, h0, y, yerr, dydt, dydt, sys)
    t = t0 + step * h0
    if sample and t - tsamp > sample then
      print(t, y[1], y[2])
      tsamp = t
    end
  end
  print(t, y[1], y[2])
end
-- Benchmark driver: run the same integration 10 times.
-- The state dimension comes from the 'dim' environment variable (default 2).
-- `sample` is read as a global; it is never assigned in this file, so it is
-- nil (periodic printing disabled) unless it is set before the script runs.
do
  local raw_dim = os.getenv('dim')
  local dim = tonumber(raw_dim or 2)
  if not dim then
    error(("environment variable 'dim' must be a number, got %q"):format(raw_dim))
  end
  -- k is only used by the commented-out variant that varies the initial angle.
  for k = 1, 10 do
    local th = math.pi/4 -- *(k-1)/5
    local p0, q0 = math.cos(th), math.sin(th)
    do_rk(p0, q0, sample, dim)
  end
end

AltStyle によって変換されたページ (->オリジナル) /