-rw-r--r-- | misc.lua | 31 | ||||
-rw-r--r-- | num/ode-defs.lua.in | 18 | ||||
-rw-r--r-- | num/rk8pd.lua.in | 12 | ||||
-rw-r--r-- | num/rkf45.lua.in | 12 |
@@ -21,36 +21,11 @@ function gsl.ode(spec) local ode = template.load(string.format('num/%s.lua.in', method), spec) - local ode_methods = { - evolve = function(s, f, t) - s._sync = false - return ode.evolve(s._state, f, t) - end, - init = function(s, t0, h0, f, ...) - s._sync = false - return ode.init(s._state, t0, h0, f, ...) - end + local mt = { + __index = {evolve = ode.evolve, init = ode.init} } - - local ODE = {__index= function(s, k) - if k == 't' then return s._state.t end - if k == 'y' then - if not s._sync then - for k=1, s.dim do - s._y[k] = s._state.y[k-1] - end - s._sync = true - end - return s._y - end - return ode_methods[k] - end} - - local dim = spec.N - local solver = {_state = ode.new(), dim= dim, _y = matrix.new(dim,1)} - setmetatable(solver, ODE) - return solver + return setmetatable(ode.new(), mt) end local NLINFIT = { diff --git a/num/ode-defs.lua.in b/num/ode-defs.lua.in index 667e1181..d962662d 100644 --- a/num/ode-defs.lua.in +++ b/num/ode-defs.lua.in @@ -34,24 +34,16 @@ # return table.concat(sm, ' + ') # end -local ffi = require 'ffi' - -ffi.cdef[[ - typedef struct { - double t; - double h; - double y[$(N)]; - double dydt[$(N)]; - } ode_state; -]] +local ffi = require 'ffi' local function ode_new() - return ffi.new('ode_state') + local n = $(N) + return {t = 0, h = 1, dim = n, y = matrix.new(n, 1), dydt = matrix.new(n, 1)} end local function ode_init(s, t0, h0, f, $(VL'y')) - $(AL's.y') = $(VL'y') - $(AL's.dydt') = f(t0, $(VL'y')) + $(AL's.y.data') = $(VL'y') + $(AL's.dydt.data') = f(t0, $(VL'y')) s.t = t0 s.h = h0 end diff --git a/num/rk8pd.lua.in b/num/rk8pd.lua.in index 284c05d8..d86ccb0e 100644 --- a/num/rk8pd.lua.in +++ b/num/rk8pd.lua.in @@ -163,16 +163,16 @@ $(include 'num/ode-defs.lua.in') # y_err_only = (a_dydt == 0) local function rk8pd_evolve(s, f, t1) - local t, h = s.t, s.h + local t, h, s_y, s_dydt = s.t, s.h, s.y, s.dydt local hadj, inc local $(VL'y') - local $(VL'k1') = $(AL's.dydt') + local $(VL'k1') = $(AL's_dydt.data') if t + h > t1 then h = t1 - t end while h > 0 do - $(VL'y') = $(AL's.y') + $(VL'y') = $(AL's_y.data') local rmax = 0 do @@ -194,7 +194,7 @@ local function rk8pd_evolve(s, f, t1) # if not y_err_only then local $(VL'dydt') = f(t + h, $(VL'y')) # for i = 0, N-1 do - s.dydt[$(i)] = dydt_$(i) + s_dydt.data[$(i)] = dydt_$(i) # end # end @@ -224,12 +224,12 @@ local function rk8pd_evolve(s, f, t1) # if y_err_only then local $(VL'dydt') = f(t + h, $(VL'y')) # for i = 0, N-1 do - s.dydt[$(i)] = dydt_$(i) + s_dydt.data[$(i)] = dydt_$(i) # end # end # for i = 0, N-1 do - s.y[$(i)] = y_$(i) + s_y.data[$(i)] = y_$(i) # end s.t = t + h s.h = hadj diff --git a/num/rkf45.lua.in b/num/rkf45.lua.in index a0cf592b..42765367 100644 --- a/num/rkf45.lua.in +++ b/num/rkf45.lua.in @@ -56,16 +56,16 @@ $(include 'num/ode-defs.lua.in') # y_err_only = (a_dydt == 0) local function rkf45_evolve(s, f, t1) - local t, h = s.t, s.h + local t, h, s_y, s_dydt = s.t, s.h, s.y, s.dydt local hadj, inc local $(VL'y') - local $(VL'k1') = $(AL's.dydt') + local $(VL'k1') = $(AL's_dydt.data') if t + h > t1 then h = t1 - t end while h > 0 do - $(VL'y') = $(AL's.y') + $(VL'y') = $(AL's_y.data') local rmax = 0 do @@ -87,7 +87,7 @@ local function rkf45_evolve(s, f, t1) # if not y_err_only then local $(VL'dydt') = f(t + h, $(VL'y')) # for i = 0, N-1 do - s.dydt[$(i)] = dydt_$(i) + s_dydt.data[$(i)] = dydt_$(i) # end # end @@ -115,12 +115,12 @@ local function rkf45_evolve(s, f, t1) # if y_err_only then local $(VL'dydt') = f(t + h, $(VL'y')) # for i = 0, N-1 do - s.dydt[$(i)] = dydt_$(i) + s_dydt.data[$(i)] = dydt_$(i) # end # end # for i = 0, N-1 do - s.y[$(i)] = y_$(i) + s_y.data[$(i)] = y_$(i) # end s.t = t + h s.h = hadj |