-rw-r--r-- | examples/linfit.lua | 2 | ||||
-rw-r--r-- | examples/nlinfit.lua | 51 | ||||
-rw-r--r-- | matrix-init.lua | 1 | ||||
-rw-r--r-- | misc.lua | 39 | ||||
-rw-r--r-- | num/lmfit.lua.in | 4 | ||||
-rw-r--r-- | num/ode-defs.lua.in | 18 | ||||
-rw-r--r-- | num/rk8pd.lua.in | 12 | ||||
-rw-r--r-- | num/rkf45.lua.in | 12 |
diff --git a/examples/linfit.lua b/examples/linfit.lua index 3b023bcf..df67c28c 100644 --- a/examples/linfit.lua +++ b/examples/linfit.lua @@ -16,7 +16,7 @@ function demo1() print('Linear fit coefficients: ') - fit = function(x) return c[1]+c[2]*x end + local fit = function(x) return c[1]+c[2]*x end p = graph.fxplot(fit, x0, x1) p:add(graph.xyline(x, y), 'blue', {{'stroke'}, {'marker', size=5}}) diff --git a/examples/nlinfit.lua b/examples/nlinfit.lua index d5db8c87..5231b3c6 100644 --- a/examples/nlinfit.lua +++ b/examples/nlinfit.lua @@ -54,4 +54,55 @@ function demo1() p:show() end +function demo2() + local n = 50 + local px = matrix.vec {1.55, -1.1, 12.5} + local p0 = matrix.vec {2.5, -1.5, 5.3} + local xs = |i| (i-1)/n + local r = gsl.rng() + + local fmodel = function(p, t, J) + local e, s = exp(p[2] * t), sin(p[3] * t) + if J then + J[1] = e * s + J[2] = t * p[1] * e * s + J[3] = t * p[1] * e * cos(p[3] * t) + end + return p[1] * e * s + end + + local y = matrix.new(n, 1, |i,j| fmodel(px, xs(i)) * (1 + gsl.rnd.gaussian(r, 0.1))) + local x = matrix.new(n, 1, |i,j| xs(i)) + + local function fdf(p, f, J) + for k=1, n do + local ym = fmodel(p, xs(k), J and J[k]) + if f then f[k] = ym - y[k] end + end + end + + local pl = graph.plot('Non-linear fit / A * exp(a t) sin(w t)') + pl:addline(graph.xyline(x, y), 'blue', {{'marker', size= 5, mark="triangle"}}) + + local s = gsl.nlinfit {n= n, p= #p0} + + s:set(fdf, p0) + print(s.x, s.chisq) + + pl:addline(graph.fxline(|x| fmodel(s.x, x), 0, xs(n)), 'red', {{'dash', 7, 3, 3, 3}}) + + for i=1, 10 do + s:iterate() + print('ITER=', i, ': ', s.x, s.chisq) + if s:test(0, 1e-8) then break end + end + + pl:addline(graph.fxline(|x| fmodel(s.x, x), 0, xs(n)), 'red') + pl.pad = true + pl:show() + + return pl +end + echo 'demo1() - Simple non-linear fit example' +echo 'demo2() - Non-linear fir of oscillatory function' diff --git a/matrix-init.lua b/matrix-init.lua index 327e3192..0d91c149 100644 --- a/matrix-init.lua +++ b/matrix-init.lua @@ -248,6 +248,7 @@ local function matrix_tostring_gen(sel) end end local eps = sqrt(sq) * 1e-8 + eps = eps > 0 and eps or 1 lsrow = {} local lmax = 0 @@ -21,48 +21,17 @@ function gsl.ode(spec) local ode = template.load(string.format('num/%s.lua.in', method), spec) - local ode_methods = { - evolve = function(s, f, t) - s._sync = false - return ode.evolve(s._state, f, t) - end, - init = function(s, t0, h0, f, ...) - s._sync = false - return ode.init(s._state, t0, h0, f, ...) - end + local mt = { + __index = {evolve = ode.evolve, init = ode.init} } - - local ODE = {__index= function(s, k) - if k == 't' then return s._state.t end - if k == 'y' then - if not s._sync then - for k=1, s.dim do - s._y[k] = s._state.y[k-1] - end - s._sync = true - end - return s._y - end - return ode_methods[k] - end} - - local dim = spec.N - local solver = {_state = ode.new(), dim= dim, _y = matrix.new(dim,1)} - setmetatable(solver, ODE) - return solver + return setmetatable(ode.new(), mt) end local NLINFIT = { __index = function(t, k) if k == 'chisq' then - local f = t.lm.f - local csq = 0 - local n = matrix.dim(f) - for i=1, n do - csq = csq + f[i]^2 - end - return csq + return t.lm.chisq() else if t.lm[k] then return t.lm[k] end end diff --git a/num/lmfit.lua.in b/num/lmfit.lua.in index b7a997d8..613ebfcd 100644 --- a/num/lmfit.lua.in +++ b/num/lmfit.lua.in @@ -821,6 +821,10 @@ M.test = function(epsabs, epsrel) return test_delta(state_dx, state_x, epsabs, epsrel) end +M.chisq = function() + return cgsl.gsl_blas_dnrm2 (state_f) + end + M.x, M.f = object(state_x), object(state_f) return M diff --git a/num/ode-defs.lua.in b/num/ode-defs.lua.in index 667e1181..d962662d 100644 --- a/num/ode-defs.lua.in +++ b/num/ode-defs.lua.in @@ -34,24 +34,16 @@ # return table.concat(sm, ' + ') # end -local ffi = require 'ffi' - -ffi.cdef[[ - typedef struct { - double t; - double h; - double y[$(N)]; - double dydt[$(N)]; - } ode_state; -]] +local ffi = require 'ffi' local function ode_new() - return ffi.new('ode_state') + local n = $(N) + return {t = 0, h = 1, dim = n, y = matrix.new(n, 1), dydt = matrix.new(n, 1)} end local function ode_init(s, t0, h0, f, $(VL'y')) - $(AL's.y') = $(VL'y') - $(AL's.dydt') = f(t0, $(VL'y')) + $(AL's.y.data') = $(VL'y') + $(AL's.dydt.data') = f(t0, $(VL'y')) s.t = t0 s.h = h0 end diff --git a/num/rk8pd.lua.in b/num/rk8pd.lua.in index 284c05d8..d86ccb0e 100644 --- a/num/rk8pd.lua.in +++ b/num/rk8pd.lua.in @@ -163,16 +163,16 @@ $(include 'num/ode-defs.lua.in') # y_err_only = (a_dydt == 0) local function rk8pd_evolve(s, f, t1) - local t, h = s.t, s.h + local t, h, s_y, s_dydt = s.t, s.h, s.y, s.dydt local hadj, inc local $(VL'y') - local $(VL'k1') = $(AL's.dydt') + local $(VL'k1') = $(AL's_dydt.data') if t + h > t1 then h = t1 - t end while h > 0 do - $(VL'y') = $(AL's.y') + $(VL'y') = $(AL's_y.data') local rmax = 0 do @@ -194,7 +194,7 @@ local function rk8pd_evolve(s, f, t1) # if not y_err_only then local $(VL'dydt') = f(t + h, $(VL'y')) # for i = 0, N-1 do - s.dydt[$(i)] = dydt_$(i) + s_dydt.data[$(i)] = dydt_$(i) # end # end @@ -224,12 +224,12 @@ local function rk8pd_evolve(s, f, t1) # if y_err_only then local $(VL'dydt') = f(t + h, $(VL'y')) # for i = 0, N-1 do - s.dydt[$(i)] = dydt_$(i) + s_dydt.data[$(i)] = dydt_$(i) # end # end # for i = 0, N-1 do - s.y[$(i)] = y_$(i) + s_y.data[$(i)] = y_$(i) # end s.t = t + h s.h = hadj diff --git a/num/rkf45.lua.in b/num/rkf45.lua.in index a0cf592b..42765367 100644 --- a/num/rkf45.lua.in +++ b/num/rkf45.lua.in @@ -56,16 +56,16 @@ $(include 'num/ode-defs.lua.in') # y_err_only = (a_dydt == 0) local function rkf45_evolve(s, f, t1) - local t, h = s.t, s.h + local t, h, s_y, s_dydt = s.t, s.h, s.y, s.dydt local hadj, inc local $(VL'y') - local $(VL'k1') = $(AL's.dydt') + local $(VL'k1') = $(AL's_dydt.data') if t + h > t1 then h = t1 - t end while h > 0 do - $(VL'y') = $(AL's.y') + $(VL'y') = $(AL's_y.data') local rmax = 0 do @@ -87,7 +87,7 @@ local function rkf45_evolve(s, f, t1) # if not y_err_only then local $(VL'dydt') = f(t + h, $(VL'y')) # for i = 0, N-1 do - s.dydt[$(i)] = dydt_$(i) + s_dydt.data[$(i)] = dydt_$(i) # end # end @@ -115,12 +115,12 @@ local function rkf45_evolve(s, f, t1) # if y_err_only then local $(VL'dydt') = f(t + h, $(VL'y')) # for i = 0, N-1 do - s.dydt[$(i)] = dydt_$(i) + s_dydt.data[$(i)] = dydt_$(i) # end # end # for i = 0, N-1 do - s.y[$(i)] = y_$(i) + s_y.data[$(i)] = y_$(i) # end s.t = t + h s.h = hadj |