lua-users home
lua-l archive

Re: LuaJIT2 performance for number crunching

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


Hi all,
here I am again with my Lua numeric algorithm implementation. I've
managed to write the rkf45 ODE integrator in vector form using the
cblas functions to perform vector arithmentic.
The results is good in term of accuracy, I've tested for the same
example as before with N=2 and I obtain the same results. For the
other side the execution is speed is ~ 100x - 150x slower!
I guess the reason is that, once again, for some reason LuaJIT2 refuse
to compile the code but I don't have any clear idea of why that
happens.
In attachment you will find the algorithm in vector form
rkf45vec.lua.in and the result after template preprocessing,
rkf45vec-out.lua. I include also the good version of the template and
the benchmark code.
Please note that you cannot run it with plain luajit2 because this
latter isn't linked with cblas as gsl shell. I guess that this problem
can be easily solved by loading the cblas library but I can give more
help if needed.
I've given a look at the trace and it seems that the root of the
problem is the cblas function that LuaJIT2 doesn't like:
[TRACE --- rkf45vec-out.lua:78 -- NYI: unsupported C function type at
rkf45vec-out.lua:83]
the function incriminated is cblas_daxpy. But I don't really know.
I hope that Mike can save me yet another time! :-)
Francesco

Attachment: rkf45vec.lua.in
Description: Binary data

local abs, max, min = math.abs, math.max, math.min
local ffi = require "ffi"
local vecsize = 2 * ffi.sizeof('double')
ffi.cdef[[
 typedef struct {
 double t;
 double h;
 double y[2];
 double dydt[2];
 } odevec_state;
 typedef struct {
 double y0[2];
 double ytmp[2];
 double k1[2];
 double k2[2];
 double k3[2];
 double k4[2];
 double k5[2];
 double k6[2];
 } ode_workspace;
 void cblas_daxpy (const int N, const double ALPHA,
		 const double * X, const int INCX,
		 double * Y, const int INCY);
 int cblas_idamax (const int N, const double * X, const int INCX);
 void cblas_dscal (const int N, const double ALPHA, double * X, const int INCX);
]]
local function ode_new()
 return ffi.new('odevec_state')
end
local function ode_init(s, t0, h0, f, y)
 ffi.copy(s.y, y, vecsize)
 f(t0, s.y, s.dydt)
 s.t = t0
 s.h = h0
end
local function hadjust(rmax, h)
 local S = 0.9
 if rmax > 1.1 then
 local r = S / rmax^(1/5)
 r = max(0.2, r)
 return r * h, -1
 elseif rmax < 0.5 then
 local r = S / rmax^(1/(5+1))
 r = max(1, min(r, 5))
 return r * h, 1
 end
 return h, 0
end
local ws = ffi.new('ode_workspace')
local function rkf45_evolve(s, f, t1)
 local t, h = s.t, s.h
 local hadj, inc
 ffi.copy (ws.y0, s.y, vecsize)
 ffi.copy (ws.k1, s.dydt, vecsize)
 if t + h > t1 then h = t1 - t end
 while h > 0 do
 local rmax = 0
 do
	 ffi.copy (ws.ytmp, s.y, vecsize)
	 ffi.C.cblas_daxpy (2, h * 0.25, ws.k1, 1, ws.ytmp, 1)
	 -- k2 step
	 f(t + 0.25 * h, ws.ytmp, ws.k2)
	 ffi.copy (ws.ytmp, s.y, vecsize)
	 ffi.C.cblas_daxpy (2, h * 0.09375, ws.k1, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * 0.28125, ws.k2, 1, ws.ytmp, 1)
	 -- k3 step
	 f(t + 0.375 * h, ws.ytmp, ws.k3)
	 ffi.copy (ws.ytmp, s.y, vecsize)
	 ffi.C.cblas_daxpy (2, h * 0.87938097405553, ws.k1, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * -3.2771961766045, ws.k2, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * 3.3208921256259, ws.k3, 1, ws.ytmp, 1)
	 -- k4 step
	 f(t + 0.92307692307692 * h, ws.ytmp, ws.k4)
	 ffi.copy (ws.ytmp, s.y, vecsize)
	 ffi.C.cblas_daxpy (2, h * 2.0324074074074, ws.k1, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * -8, ws.k2, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * 7.1734892787524, ws.k3, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * -0.20589668615984, ws.k4, 1, ws.ytmp, 1)
	 -- k5 step
	 f(t + 1 * h, ws.ytmp, ws.k5)
	 ffi.copy (ws.ytmp, s.y, vecsize)
	 ffi.C.cblas_daxpy (2, h * -0.2962962962963, ws.k1, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * 2, ws.k2, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * -1.3816764132554, ws.k3, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * 0.45297270955166, ws.k4, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * -0.275, ws.k5, 1, ws.ytmp, 1)
	 -- k6 step and final sum
	 -- since k2 is no more used we could use k2 to store k6
	 f(t + 0.5 * h, ws.ytmp, ws.k6)
	 ffi.C.cblas_daxpy (2, h * 0.11851851851852, ws.k1, 1, s.y, 1)
	 ffi.C.cblas_daxpy (2, h * 0.51898635477583, ws.k3, 1, s.y, 1)
	 ffi.C.cblas_daxpy (2, h * 0.50613149034202, ws.k4, 1, s.y, 1)
	 ffi.C.cblas_daxpy (2, h * -0.18, ws.k5, 1, s.y, 1)
	 ffi.C.cblas_daxpy (2, h * 0.036363636363636, ws.k6, 1, s.y, 1)
 
 local yerr, r, d0
 yerr = h * (0.0027777777777778 * ws.k1[0] + -0.029941520467836 * ws.k3[0] + -0.029199893673578 * ws.k4[0] + 0.02 * ws.k5[0] + 0.036363636363636 * ws.k6[0])
 d0 = 0 * (1 * abs(s.y[0])) + 1e-06
 r = abs(yerr) / abs(d0)
 rmax = max(r, rmax)
 yerr = h * (0.0027777777777778 * ws.k1[1] + -0.029941520467836 * ws.k3[1] + -0.029199893673578 * ws.k4[1] + 0.02 * ws.k5[1] + 0.036363636363636 * ws.k6[1])
 d0 = 0 * (1 * abs(s.y[1])) + 1e-06
 r = abs(yerr) / abs(d0)
 rmax = max(r, rmax)
 end
 hadj, inc = hadjust(rmax, h)
 if inc >= 0 then break end
 ffi.copy(s.y, ws.y0, vecsize)
 h = hadj
 end
 f(t + h, s.y, s.dydt)
 s.t = t + h
 s.h = hadj
 return h
end
return {new= ode_new, init= ode_init, evolve= rkf45_evolve}
--
-- A Lua preprocessor for template code specialization.
-- Adapted by Steve Donovan, based on original code of Rici Lake.
--
local M = {}
-------------------------------------------------------------------------------
local function preprocess(chunk, name, defs)
 local function parseDollarParen(pieces, chunk, s, e)
 local append, format = table.insert, string.format
 local s = 1
 for term, executed, e in chunk:gmatch("()$(%b())()") do
	 append(pieces,
		format("%q..(%s or '')..", chunk:sub(s, term - 1), executed))
	 s = e
 end
 append(pieces, format("%q", chunk:sub(s)))
 end
 local function parseHashLines(chunk)
 local append = table.insert
 local pieces, s, args = chunk:find("^\n*#ARGS%s*(%b())[ \t]*\n")
 if not args or find(args, "^%(%s*%)$") then
	 pieces, s = {"return function(_put) ", n = 1}, s or 1
 else
	 pieces = {"return function(_put, ", args:sub(2), n = 2}
 end
 while true do
	 local ss, e, lua = chunk:find("^#+([^\n]*\n?)", s)
	 if not e then
	 ss, e, lua = chunk:find("\n#+([^\n]*\n?)", s)
	 append(pieces, "_put(")
	 parseDollarParen(pieces, chunk:sub(s, ss))
	 append(pieces, ")")
	 if not e then break end
	 end
	 append(pieces, lua)
	 s = e + 1
 end
 append(pieces, " end")
 return table.concat(pieces)
 end
 local ppenv
 if defs._self then
 ppenv = defs._self
 else
 ppenv = {string= string, table= table, template= M}
 for k, v in pairs(defs) do ppenv[k] = v end
 ppenv._self = ppenv
 local include = function(filename)
			 return M.process(filename, ppenv)
		 end
 setfenv(include, ppenv)
 ppenv.include = include
 end
 local code = parseHashLines(chunk)
 local fcode = loadstring(code, name)
 if fcode then
 setfenv(fcode, ppenv)
 return fcode()
 end
end
local function read_file(filename)
 local f = io.open(filename)
 local content = f:read('*a')
 f:close()
 return content
end
local function process(filename, defs)
 local template = read_file(filename)
 local codegen = preprocess(template, 'ode_codegen', defs)
 local code = {}
 local add = function(s) code[#code+1] = s end
 codegen(add)
 return table.concat(code)
end
local function require(filename)
 local f = loadstring(process(filename .. '.lua.in', {}), 'ode_out')
 if not f then error 'error loading ODE module' end
 return f()
end
local function load(filename, defs)
 local f = loadstring(process(filename, defs), 'ode_out')
 if not f then error 'error loading ODE module' end
 return f()
end
M.process = process
M.require = require
M.load = load
return M
local template = require 'template'
local ffi = require "ffi"
local ode_spec = {N = 2, eps_abs = 1e-6, eps_rel = 0, a_y = 1, a_dydt = 0}
local ode = template.load('rkf45vec.lua.in', ode_spec)
function f_vanderpol_gen(mu)
 return function(t, y, f) 
	 f[0] = y[1]
	 f[1] = -y[0] + mu * y[1] * (1-y[0]^2)
	 end
end
local f = f_vanderpol_gen(10.0)
local s = ode.new()
local y = ffi.new('double[2]', {1, 0})
local t0, t1, h0 = 0, 20000, 0.01
local init, evol = ode.init, ode.evolve
for k=1, 10 do
 init(s, t0, h0, f, y)
 while s.t < t1 do
 evol(s, f, t1)
 end
 print(s.t, s.y[0], s.y[1])
end

AltStyle によって変換されたページ (->オリジナル) /