author | francesco <francesco.bbt@gmail.com> | 2011年03月15日 15:11:19 +0100 |
---|---|---|
committer | francesco <francesco.bbt@gmail.com> | 2011年03月15日 15:11:19 +0100 |
commit | bc4bb3cfa6b625349f594fefc6b53b8da8148ba4 (patch) | |
tree | c3e6032ef5ee937043e6c04b7618aa2319e3c67c | |
parent | 31901995023dcc5777326f1deb8a1b5e39f59b5a (diff) | |
download | gsl-shell-bc4bb3cfa6b625349f594fefc6b53b8da8148ba4.tar.gz |
-rw-r--r-- | benchmarks/lmfit/lm_test_enso.lua (renamed from lm_test_enso.lua) | 0 | ||||
-rw-r--r-- | benchmarks/lmfit/lm_test_thurber.lua (renamed from lm_test_thurber.lua) | 0 | ||||
-rw-r--r-- | examples/nlinfit-jit.lua | 56 | ||||
-rw-r--r-- | lua-lmtest.lua | 37 | ||||
-rw-r--r-- | misc.lua | 86 |
diff --git a/lm_test_enso.lua b/benchmarks/lmfit/lm_test_enso.lua index 761eb717..761eb717 100644 --- a/lm_test_enso.lua +++ b/benchmarks/lmfit/lm_test_enso.lua diff --git a/lm_test_thurber.lua b/benchmarks/lmfit/lm_test_thurber.lua index 461b6516..461b6516 100644 --- a/lm_test_thurber.lua +++ b/benchmarks/lmfit/lm_test_thurber.lua diff --git a/examples/nlinfit-jit.lua b/examples/nlinfit-jit.lua new file mode 100644 index 00000000..01685cd3 --- /dev/null +++ b/examples/nlinfit-jit.lua @@ -0,0 +1,56 @@ + +local sin, cos, exp, sqrt = math.sin, math.cos, math.exp, math.sqrt +local pi = math.pi + +function demo1() + local n = 40 + + local yrf, sigrf + + local fdf = function(x, f, J) + for i=1, n do + local A, lambda, b = x[1], x[2], x[3] + local t, y, sig = i-1, yrf[i], sigrf[i] + local e = exp(- lambda * t) + if f then f[i] = (A*e+b - y)/sig end + if J then + J:set(i, 1, e / sig) + J:set(i, 2, - t * A * e / sig) + J:set(i, 3, 1 / sig) + end + end + end + + local model = function(x, t) + local A, lambda, b = x[1], x[2], x[3] + return A * exp(- lambda * t) + b + end + + local xref = gsl.vector {5, 0.1, 1} + + local r = gsl.rng('mt19937') + r:set(0) + + yrf = gsl.new(n, 1, function(i) return model(xref, i-1) + gsl.rnd.gaussian(r, 0.1) end) + sigrf = gsl.new(n, 1, function() return 0.1 end) + + local s = gsl.nlinfit {n= n, p= 3} + + s:set(fdf, gsl.vector {1, 0, 0}) + print(gsl.tr(s.x), s.chisq) + + for i=1, 10 do + s:iterate() + print('ITER=', i, ': ', gsl.tr(s.x), s.chisq) + if s:test(0, 1e-8) then break end + end + + local p = graph.plot('Non-linear fit example') + local pts = graph.ipath(gsl.sequence(function(i) return i-1, yrf[i] end, n)) + local fitln = graph.fxline(function(t) return model(s.x, t) end, 0, n-1) + p:addline(pts, 'blue', {{'marker', size=4}}) + p:addline(fitln) + p:show() +end + +gsl.echo 'demo1() - Simple non-linear fit example' diff --git a/lua-lmtest.lua b/lua-lmtest.lua deleted file mode 100644 index 46c8944d..00000000 --- a/lua-lmtest.lua +++ /dev/null @@ -1,37 +0,0 @@ - -local sin, cos, exp, sqrt = math.sin, math.cos, math.exp, math.sqrt - -local n = 40 - -local yrf, sigrf - -local fdf = function(x, f, J) - for i=1, n do - local A, lambda, b = x[1], x[2], x[3] - local t, y, sig = i-1, yrf[i], sigrf[i] - local e = exp(- lambda * t) - if f then f[i] = (A*e+b - y)/sig end - if J then - J:set(i, 1, e / sig) - J:set(i, 2, - t * A * e / sig) - J:set(i, 3, 1 / sig) - end - end - end - -local A, lambda, b = 5, 0.1, 1 -local r = gsl.rng('mt19937') -r:set(0) - -yrf = gsl.new(n, 1, function(i) return A * exp(-lambda*(i-1)) + b + gsl.rnd.gaussian(r, 0.1) end) -sigrf = gsl.new(n, 1, function() return 0.1 end) - -local s = gsl.nlinfit {n= n, p= 3} - -s:set(fdf, gsl.vector {1, 0, 0}) -print(gsl.tr(s.x), s.chisq) - -for i=1, 10 do - s:iterate() - print('ITER=', i, ': ', gsl.tr(s.x), s.chisq) -end diff --git a/misc.lua b/misc.lua new file mode 100644 index 00000000..38aa74ac --- /dev/null +++ b/misc.lua @@ -0,0 +1,86 @@ +local gsl = gsl or _G + +local template = require 'template' + +function gsl.ode(spec) + local required = {N= 'number', eps_abs= 'number'} + local defaults = {eps_rel = 0, a_y = 1, a_dydt = 0} + local is_known = {rkf45= true, rk8pd= true} + + for k, tp in pairs(required) do + if type(spec[k]) ~= tp then + error(string.format('parameter %s should be a %s', k, tp)) + end + end + for k, v in pairs(defaults) do + if not spec[k] then spec[k] = v end + end + + local method = spec.method and spec.method or 'rkf45' + if not is_known[method] then error('unknown ode method: ' .. method) end + spec.method = nil + + local ode = template.load(string.format('num/%s.lua.in', method), spec) + + local ode_methods = { + evolve = function(s, f, t) + s._sync = false + return ode.evolve(s._state, f, t) + end, + init = function(s, t0, h0, f, ...) + s._sync = false + return ode.init(s._state, t0, h0, f, ...) + end + } + + local ODE = {__index= function(s, k) + if k == 't' then return s._state.t end + if k == 'y' then + if not s._sync then + for k=1, s.dim do + s._y[k] = s._state.y[k-1] + end + s._sync = true + end + return s._y + end + return ode_methods[k] + end} + + local dim = spec.N + local solver = {_state = ode.new(), dim= dim, _y = gsl.new(dim,1)} + setmetatable(solver, ODE) + + return solver +end + +local NLINFIT = { + __index = function(t, k) + if k == 'chisq' then + local f = t.lm.f + return gsl.prod(f, f)[1] + else + if t.lm[k] then return t.lm[k] end + end + end +} + +function gsl.nlinfit(spec) + if not spec.n then error 'number of points "n" not specified' end + if not spec.p then error 'number of parameters "p" not specified' end + + if spec.n <= 0 or spec.p <= 0 then + error '"n" and "p" shoud be positive integers' + end + + local n, p = spec.n, spec.p + local s = { lm = template.load('num/lmfit.lua.in', {N= n, P= p}) } + + s.set = function(ss, fdf, x0) return ss.lm.set(fdf, x0) end + s.iterate = function(ss) return ss.lm.iterate() end + s.test = function(ss, epsabs, epsrel) return ss.lm.test(epsabs, epsrel) end + + setmetatable(s, NLINFIT) + + return s +end |