lua-users home
lua-l archive

Re: LuaJIT2 performance for number crunching

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


Hi all,
after a good amount of work I've finished writing the ODE
integration routine. The method that I've implemented is the Embedded
Runge-Kutta-Fehlberg (4, 5) method. This method is already much
better than the simple Runge-Kutta method, and the idea is that the
next step would be to implement the Embedded Runge-Kutta
Prince-Dormand (8,9) method, as someone suggested.
The implementation I've done in Lua is virtually identical to the one
given in the GSL library. I've implemented the same methodology to
control the step size and limit the error according to the user
input. The difference is that I don't use vectors; instead everything is
expanded to local variables using a template preprocessor.
To develop the interface I've further refined the Lua preprocessor
that Steve Donovan made based on Rici Lake's original code snippet.
I've changed the implementation to avoid writing to the global
namespace and I've also added a function to include other files during
preprocessing. The resulting file is "template.lua".
In order to test the algorithm both for accuracy and for speed I've taken
a basic GSL example that shows ODE evolution. I've changed the integration
method to rkf45; in the original example it was rk8pd (Runge-Kutta
Prince-Dormand). Then I've increased the integration time and repeated
the whole process 10 times.
The results are just perfect in terms of accuracy. Results produced
with LuaJIT2 are the same as those given by the C code.
On the other side, it seems that there is a small problem because the
performance of LuaJIT2 is in this case below my expectations. Here is
what I've got:
LuaJIT2:
real	0m14.498s
user	0m14.497s
sys	0m0.000s
C code (-O2) with GSL library:
real	0m1.094s
user	0m1.088s
sys	0m0.000s
so the C code in this case is approximately 13 times faster (14.498 s / 1.094 s ≈ 13.3).
I hope I've made a big stupid error in my implementation because my
hope was to have better results :-)
You will find all the files attached, if someone wants to have a
look. The most important one is the preprocessed file,
"rkf45-out.lua". This file is generated from "rkf45.lua.in" and
"ode-defs.lua.in" by using the template module.
Otherwise, if you want to reproduce the example with plain LuaJIT2 you will
need to add the "math." prefix to math functions like sin, cos, etc.
The reason is that GSL shell puts all the mathematical functions in the
global namespace. You can easily test everything by taking the luajit2
branch in the GSL shell git repository.
-- 
Francesco

Attachment: rkf45.lua.in
Description: Binary data

local ffi = require 'ffi'
-- Declare the C layout of the integrator state used by the functions below:
--   t    : current value of the independent variable
--   h    : step size carried over between calls
--   y    : the two state components (indices 0 and 1)
--   dydt : the two derivative components stored alongside y
ffi.cdef[[
 typedef struct {
 double t;
 double h;
 double y[2];
 double dydt[2];
 } ode_state;
]]
--- Allocate a new integrator state.
-- @return a fresh `ode_state` cdata object created through the FFI
local function ode_new()
  local state = ffi.new('ode_state')
  return state
end
--- Fill an integrator state with its starting values.
-- @param s    state object with fields `y`, `dydt`, `t` and `h`
-- @param t0   starting value of the independent variable
-- @param h0   starting step size
-- @param f    derivative function, called as f(t, y_0, y_1) and expected to
--             return the two derivative components
-- @param y_0  first component of the starting state
-- @param y_1  second component of the starting state
local function ode_init(s, t0, h0, f, y_0,y_1)
  -- Store the starting point first, then the derivatives evaluated there.
  local y = s.y
  y[0], y[1] = y_0, y_1

  local dydt = s.dydt
  dydt[0], dydt[1] = f(t0, y_0, y_1)

  s.t = t0
  s.h = h0
end
--- Propose a new step size from the scaled error estimate of the last step.
--
-- The original code called bare `max`/`min`, which only exist as globals in
-- GSL shell; `math.max`/`math.min` work there and in plain Lua/LuaJIT.
--
-- @param rmax  largest ratio |error| / tolerance observed over the components
-- @param h     step size that produced `rmax`
-- @return the proposed step size
-- @return -1 if the step was shrunk, 1 if it may grow, 0 if left unchanged
local function hadjust(rmax, h)
  local max, min = math.max, math.min
  local S = 0.9      -- safety factor applied to the scaling ratio
  local ORDER = 5    -- exponent base used below (presumably the method order)

  if rmax > 1.1 then
    -- Error too large: shrink the step, but never by more than a factor 5.
    local r = S / rmax^(1/ORDER)
    r = max(0.2, r)
    return r * h, -1
  elseif rmax < 0.5 then
    -- Error comfortably small: grow the step, clamped to the range [1, 5].
    local r = S / rmax^(1/(ORDER+1))
    r = max(1, min(r, 5))
    return r * h, 1
  end

  -- Error within the acceptable band: keep the current step.
  return h, 0
end
-- These are the differences of fifth and fourth order coefficients
-- for error estimation
--- Advance the state `s` by one adaptive Runge-Kutta-Fehlberg (4,5) step.
--
-- A trial step of size `h` is computed; if `hadjust` reports that the step
-- must shrink (second result < 0) the trial is discarded and repeated with
-- the smaller step, otherwise it is accepted.
--
-- Changes from the original: `ytmp_0`/`ytmp_1` are now locals (they were
-- accidental globals written on every stage), and `abs`/`max` are taken from
-- the `math` table so the code also runs outside GSL shell.
--
-- @param s   state with fields `t`, `h`, `y[0..1]`, `dydt[0..1]`; updated in place
-- @param f   derivative function, f(t, y_0, y_1) -> dy_0/dt, dy_1/dt
-- @param t1  upper limit for t; the step is shortened so t + h never exceeds it
-- @return the step size actually taken
--
-- NOTE(review): if `h` is not positive on entry (e.g. s.t >= t1) the loop
-- body never runs, so y_0, y_1 and hadj are still nil in the code after the loop.
function rkf45_evolve(s, f, t1)
  local abs, max = math.abs, math.max
  local t, h = s.t, s.h
  local dydt = s.dydt
  local hadj, inc
  local y_0, y_1
  local ytmp_0, ytmp_1
  -- k1 is the derivative at the start of the step, kept in the state.
  local k1_0, k1_1 = dydt[0], dydt[1]

  -- Do not step past the requested end point.
  if t + h > t1 then h = t1 - t end

  while h > 0 do
    y_0, y_1 = s.y[0], s.y[1]

    ytmp_0 = y_0 + 0.25 * h * k1_0
    ytmp_1 = y_1 + 0.25 * h * k1_1

    -- k2 step
    local k2_0, k2_1 = f(t + 0.25 * h, ytmp_0, ytmp_1)
    ytmp_0 = y_0 + h * (0.09375 * k1_0 + 0.28125 * k2_0)
    ytmp_1 = y_1 + h * (0.09375 * k1_1 + 0.28125 * k2_1)

    -- k3 step
    local k3_0, k3_1 = f(t + 0.375 * h, ytmp_0, ytmp_1)
    ytmp_0 = y_0 + h * (0.87938097405553 * k1_0 + -3.2771961766045 * k2_0 + 3.3208921256259 * k3_0)
    ytmp_1 = y_1 + h * (0.87938097405553 * k1_1 + -3.2771961766045 * k2_1 + 3.3208921256259 * k3_1)

    -- k4 step
    local k4_0, k4_1 = f(t + 0.92307692307692 * h, ytmp_0, ytmp_1)
    ytmp_0 = y_0 + h * (2.0324074074074 * k1_0 + -8 * k2_0 + 7.1734892787524 * k3_0 + -0.20589668615984 * k4_0)
    ytmp_1 = y_1 + h * (2.0324074074074 * k1_1 + -8 * k2_1 + 7.1734892787524 * k3_1 + -0.20589668615984 * k4_1)

    -- k5 step
    local k5_0, k5_1 = f(t + 1 * h, ytmp_0, ytmp_1)
    ytmp_0 = y_0 + h * (-0.2962962962963 * k1_0 + 2 * k2_0 + -1.3816764132554 * k3_0 + 0.45297270955166 * k4_0 + -0.275 * k5_0)
    ytmp_1 = y_1 + h * (-0.2962962962963 * k1_1 + 2 * k2_1 + -1.3816764132554 * k3_1 + 0.45297270955166 * k4_1 + -0.275 * k5_1)

    -- k6 step and final sum
    local k6_0, k6_1 = f(t + 0.5 * h, ytmp_0, ytmp_1)
    local di
    di = 0.11851851851852 * k1_0 + 0.51898635477583 * k3_0 + 0.50613149034202 * k4_0 + -0.18 * k5_0 + 0.036363636363636 * k6_0
    y_0 = y_0 + h * di
    di = 0.11851851851852 * k1_1 + 0.51898635477583 * k3_1 + 0.50613149034202 * k4_1 + -0.18 * k5_1 + 0.036363636363636 * k6_1
    y_1 = y_1 + h * di

    -- Error estimate per component, scaled by the tolerance d0; rmax is the
    -- worst ratio. The literal 0, 1 and 1e-06 factors are kept exactly as
    -- they appear in the generated code.
    local yerr, r, d0
    local rmax = 0
    yerr = h * (0.0027777777777778 * k1_0 + -0.029941520467836 * k3_0 + -0.029199893673578 * k4_0 + 0.02 * k5_0 + 0.036363636363636 * k6_0)
    d0 = 0 * (1 * abs(y_0)) + 1e-06
    r = abs(yerr) / abs(d0)
    rmax = max(r, rmax)
    yerr = h * (0.0027777777777778 * k1_1 + -0.029941520467836 * k3_1 + -0.029199893673578 * k4_1 + 0.02 * k5_1 + 0.036363636363636 * k6_1)
    d0 = 0 * (1 * abs(y_1)) + 1e-06
    r = abs(yerr) / abs(d0)
    rmax = max(r, rmax)

    hadj, inc = hadjust(rmax, h)
    -- Accept the step unless the controller asked for a smaller one.
    if inc >= 0 then break end

    h = hadj
  end

  -- Commit the accepted step: derivative at the new point, new state, new
  -- time, and the step size suggested for the next call.
  dydt[0], dydt[1] = f(t + h, y_0, y_1)
  s.y[0], s.y[1] = y_0, y_1
  s.t = t + h
  s.h = hadj
  return h
end
-- Module table: exposes the state constructor (`new`), the initialiser
-- (`init`) and the single-step RKF45 integrator (`evolve`).
return {new= ode_new, init= ode_init, evolve= rkf45_evolve}
-- Driver script: build the specialised integrator from its template.
local template = require 'template'
-- Values substituted into the template. Presumably N is the system dimension
-- and eps_abs/eps_rel/a_y/a_dydt are the error-control settings; confirm
-- against 'rkf45.lua.in'.
local ode_spec = {N = 2, eps_abs = 1e-6, eps_rel = 0, a_y = 1, a_dydt = 0}
-- `compile` expands the template and loads the result as a Lua chunk ...
local codegen = template.compile('rkf45.lua.in', ode_spec)
-- ... and running that chunk yields the module table used below as `ode`.
local ode = codegen()
--- Build the right-hand side of a Van der Pol style system for a given `mu`.
-- @param mu  coefficient multiplying the non-linear term
-- @return a function (t, x, y) -> dx/dt, dy/dt; `t` is accepted but not used
function f_vanderpol_gen(mu)
  local function rhs(t, x, y)
    local dxdt = y
    local dydt = -x + mu * y * (1-x^2)
    return dxdt, dydt
  end
  return rhs
end
-- Benchmark: integrate the system with mu = 10 from t = 0 to t = 20000,
-- ten times over, printing the final time and state after each run.
local rhs = f_vanderpol_gen(10.0)
local state = ode.new()
local x0, y0 = 1, 0
local t_begin, t_end, h_first = 0, 20000, 0.01
local ode_init, ode_evolve = ode.init, ode.evolve

for _ = 1, 10 do
  -- Restart from the same initial conditions on every repetition.
  ode_init(state, t_begin, h_first, rhs, x0, y0)
  while state.t < t_end do
    ode_evolve(state, rhs, t_end)
  end
  print(state.t, state.y[0], state.y[1])
end
#include <stdio.h>
#include <gsl/gsl_errno.h>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_odeiv.h>
/* Right-hand side of the system integrated by main():
 *   f[0] = y[1]
 *   f[1] = -y[0] - mu * y[1] * (y[0]^2 - 1)
 * `params` points to the double holding mu; `t` is not used.
 * Always returns GSL_SUCCESS. */
int
func (double t, const double y[], double f[],
      void *params)
{
  const double *mu_ptr = (const double *) params;
  const double mu = *mu_ptr;

  (void) t;

  f[0] = y[1];
  f[1] = -y[0] - mu*y[1]*(y[0]*y[0] - 1);

  return GSL_SUCCESS;
}
/* Jacobian of func(): fills the 2x2 matrix stored in `dfdy` with the
 * partial derivatives of f with respect to y, and sets both entries of
 * `dfdt` to zero. `params` points to the double holding mu; `t` is not
 * used. Always returns GSL_SUCCESS. */
int
jac (double t, const double y[], double *dfdy,
     double dfdt[], void *params)
{
  const double mu = *(double *) params;
  /* View the flat dfdy buffer as a 2x2 matrix so it can be filled by index. */
  gsl_matrix_view jac_view = gsl_matrix_view_array (dfdy, 2, 2);
  gsl_matrix *J = &jac_view.matrix;

  (void) t;

  /* Row 0: derivatives of f[0] = y[1]. */
  gsl_matrix_set (J, 0, 0, 0.0);
  gsl_matrix_set (J, 0, 1, 1.0);

  /* Row 1: derivatives of f[1]. */
  gsl_matrix_set (J, 1, 0, -2.0*mu*y[0]*y[1] - 1.0);
  gsl_matrix_set (J, 1, 1, -mu*(y[0]*y[0] - 1.0));

  dfdt[0] = 0.0;
  dfdt[1] = 0.0;

  return GSL_SUCCESS;
}
/* Benchmark driver: integrates the system defined by func()/jac() with
 * mu = 10 from t = 0 to t = 20000 using the rkf45 stepper, ten times over,
 * printing the final t, y[0], y[1] after each run. */
int
main (void)
{
  const gsl_odeiv_step_type *step_type = gsl_odeiv_step_rkf45;
  int run = 0;

  while (run < 10)
    {
      /* Fresh stepper, controller and evolver for every repetition. */
      gsl_odeiv_step *stepper = gsl_odeiv_step_alloc (step_type, 2);
      gsl_odeiv_control *control = gsl_odeiv_control_y_new (1e-6, 0.0);
      gsl_odeiv_evolve *evolver = gsl_odeiv_evolve_alloc (2);

      double mu = 10;
      gsl_odeiv_system sys = {func, jac, 2, &mu};

      double t = 0.0, t1 = 20000.0;
      double h = 1e-6;
      double y[2] = { 1.0, 0.0 };

      /* Step until t1 is reached or the evolver reports a failure. */
      while (t < t1)
        {
          if (gsl_odeiv_evolve_apply (evolver, control, stepper,
                                      &sys,
                                      &t, t1,
                                      &h, y) != GSL_SUCCESS)
            break;
        }

      printf ("%g %g %g\n", t, y[0], y[1]);

      gsl_odeiv_evolve_free (evolver);
      gsl_odeiv_control_free (control);
      gsl_odeiv_step_free (stepper);

      run++;
    }

  return 0;
}

Attachment: ode-defs.lua.in
Description: Binary data

--
-- A Lua preprocessor for template code specialization.
-- Adapted by Steve Donovan, based on original code of Rici Lake.
--
local M = {}
-------------------------------------------------------------------------------
--- Translate template text into a generator function.
--
-- Template syntax handled here:
--   * a line starting with `#` is copied (without the `#`) as Lua code;
--   * any other text is emitted through `_put`, with each `$(expr)` replaced
--     by the value of `expr` (or the empty string when it is nil/false);
--   * an optional leading `#ARGS (a, b, ...)` line declares extra parameters
--     of the generator function.
--
-- Fixes over the original: the `#ARGS` check called an undefined global
-- `find` (now `args:find`), and a template that expands to invalid Lua now
-- raises an error carrying the compiler message instead of silently
-- returning nil.
--
-- @param chunk  template source text
-- @param name   chunk name handed to `loadstring` (shows up in error messages)
-- @param defs   table of definitions visible to the template code; when it
--               carries a `_self` field that table is reused as the environment
-- @return the generator, a function(_put, ...) that calls `_put` with each
--         piece of output text
local function preprocess(chunk, name, defs)
  local append, format = table.insert, string.format

  -- Append to `pieces` a Lua string expression that reproduces `text` with
  -- every `$(expr)` spliced in as `(expr or '')`.
  local function parseDollarParen(pieces, text)
    local pos = 1
    for term, executed, stop in text:gmatch("()$(%b())()") do
      append(pieces,
             format("%q..(%s or '')..", text:sub(pos, term - 1), executed))
      pos = stop
    end
    append(pieces, format("%q", text:sub(pos)))
  end

  -- Build the source of the generator function from the template text.
  local function parseHashLines(text)
    local pieces, s, args = text:find("^\n*#ARGS%s*(%b())[ \t]*\n")
    if not args or args:find("^%(%s*%)$") then
      -- No #ARGS line, or an empty argument list.
      pieces, s = {"return function(_put) "}, s or 1
    else
      -- args is "(a, b)"; dropping its "(" leaves "a, b)" to close the list.
      pieces = {"return function(_put, ", args:sub(2)}
    end
    while true do
      -- A `#` line right at the current position is plain Lua code.
      local ss, e, lua = text:find("^#+([^\n]*\n?)", s)
      if not e then
        -- Otherwise emit the literal text up to the next `#` line (or to the
        -- end of the template when there is none).
        ss, e, lua = text:find("\n#+([^\n]*\n?)", s)
        append(pieces, "_put(")
        parseDollarParen(pieces, text:sub(s, ss))
        append(pieces, ")")
        if not e then break end
      end
      append(pieces, lua)
      s = e + 1
    end
    append(pieces, " end")
    return table.concat(pieces)
  end

  -- Environment in which the template code runs; it never touches _G.
  local ppenv
  if defs._self then
    ppenv = defs._self
  else
    ppenv = {string= string, table= table, template= M}
    for k, v in pairs(defs) do ppenv[k] = v end
    ppenv._self = ppenv
    -- `include(filename)` expands another template in the same environment.
    local include = function(filename)
      return M.process(filename, ppenv)
    end
    setfenv(include, ppenv)
    ppenv.include = include
  end

  local code = parseHashLines(chunk)
  local fcode, err = loadstring(code, name)
  if not fcode then
    error(format("template '%s' expands to invalid Lua: %s",
                 tostring(name), tostring(err)), 2)
  end
  setfenv(fcode, ppenv)
  return fcode()
end
--- Read a whole file into a string.
-- @param filename  path of the file to read
-- @return the file contents
-- Raises an error that includes the system message (and therefore the path)
-- when the file cannot be opened; the original indexed a nil handle instead.
local function read_file(filename)
  local f, err = io.open(filename, 'r')
  if not f then
    error("cannot open template file: " .. tostring(err), 2)
  end
  local content = f:read('*a')
  f:close()
  return content
end
--- Expand the template stored in `filename` and return the generated text.
-- @param filename  path of the template file
-- @param defs      definitions made available to the template code
-- @return the expanded template as a single string
local function process(filename, defs)
  local source = read_file(filename)
  local generate = preprocess(source, 'ode_codegen', defs)

  -- Collect every piece the generator emits, then join them once.
  local out = {}
  local function collect(piece)
    out[#out + 1] = piece
  end
  generate(collect)

  return table.concat(out)
end
--- Expand a template file and load the result as a Lua chunk.
-- @param filename  path of the template file
-- @param defs      definitions made available to the template code
-- @return whatever `loadstring` returns for the expanded text: the compiled
--         chunk, or nil plus an error message
local function compile(filename, defs)
  local expanded = process(filename, defs)
  return loadstring(expanded, 'ode_out')
end
-- Public interface of the template module:
--   process(filename, defs) -> expanded template text
--   compile(filename, defs) -> result of loadstring on the expanded text
M.process = process
M.compile = compile
return M

AltStyle によって変換されたページ (->オリジナル) /