gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat
-rw-r--r--gdt-lm.lua 150
-rw-r--r--gdt-plot.lua 15
-rw-r--r--gdt.lua 11
-rw-r--r--lm-helpers.lua 37
4 files changed, 198 insertions, 15 deletions
diff --git a/gdt-lm.lua b/gdt-lm.lua
new file mode 100644
index 00000000..dd821372
--- /dev/null
+++ b/gdt-lm.lua
@@ -0,0 +1,150 @@
+
+local cgdt = require 'cgdt'
+
+local element_is_number = gdt.element_is_number
+local format = string.format
+local concat = table.concat
+
+-- status: 0 => string, 1 => numbers
+local function find_column_type(t, j)
+ local n = #t
+ for i = 1, n do
+ local x = gdt.get_number_unsafe(t, i, j)
+ if not x then return 0 end
+ end
+ return 1
+end
+
+local function lm_prepare(t, expr)
+ local n, m = t:dim()
+ local column_class = {}
+ local code_lines = {}
+ local code = function(line) code_lines[#code_lines+1] = line end
+ local name = t:headers(j)
+ code([[local _LM = require 'lm-helpers']])
+ code([[local select = select]])
+ code([[local _get, _set = gdt.get, gdt.set]])
+ code([[local enum = function(x) return {value = x} end]])
+ for j = 1, m do
+ column_class[j] = find_column_type(t, j)
+ local value = (column_class[j] == 1 and "0" or format("_LM.factor(\"%s\")", name[j]))
+ code(format("local %s = %s", name[j], value))
+ end
+
+ code(format([[local _y_spec = _LM.eval_test(%s)]], expr))
+
+ code([[local _eval_func = _LM.eval_func]])
+
+ code(format("local _eval = gdt.new(%d, _y_spec.np)", n))
+
+ code(format("for _i = 1, %d do", n))
+ for j = 1, m do
+ local line
+ if column_class[j] == 1 then
+ line = format(" %s = _get(_t, _i, %d)", name[j], j)
+ else
+ line = format(" %s.value = _get(_t, _i, %d)", name[j], j)
+ end
+ code(line)
+ end
+ code("")
+ code(format(" _eval_func(_y_spec, _eval, _i, %s)", expr))
+ code(format("end", n))
+
+ return format("return function(_t)\n%s\nreturn _eval, _y_spec\nend", concat(code_lines, "\n"))
+end
+
+local function add_unique(t, val)
+ for k, x in ipairs(t) do
+ if x == val then return 0 end
+ end
+ local n = #t + 1
+ t[n] = val
+ return n
+end
+
+local function lm_main(Xt, t, inf)
+ local N = #t
+
+ local index = {}
+ local curr_index = 1
+ local coeff_name = {}
+ inf.factors, inf.factor_index = {}, {}
+ for k = 1, inf.np do
+ index[k] = curr_index
+ if inf.class[k] == 0 then
+ local factors, factor_index = {}, {}
+ for i = 1, N do
+ local str = Xt:get(i, k)
+ local ui = add_unique(factors, str)
+ if ui > 0 then
+ factor_index[str] = ui
+ if ui > 1 then coeff_name[curr_index + (ui - 2)] = str end
+ end
+ end
+ inf.factors[k] = factors
+ inf.factor_index[k] = factor_index
+ curr_index = curr_index + (#factors - 1)
+ else
+ coeff_name[curr_index] = string.char(string.byte('a') + curr_index - 1)
+ curr_index = curr_index + 1
+ end
+ end
+
+ local col = curr_index - 1
+
+ local X = matrix.alloc(N, col)
+ local X_data = X.data
+ for i = 1, N do
+ local idx0 = col * (i - 1)
+ for k = 1, inf.np do
+ local j = index[k]
+ if inf.class[k] == 1 then
+ X_data[idx0 + j - 1] = gdt.get_number_unsafe(Xt, i, k)
+ else
+ local factors = inf.factors[k]
+ local factor_index = inf.factor_index[k]
+ local req_f = Xt:get(i, k)
+ local nf = #factors - 1
+ for kf = 1, nf do
+ X_data[idx0 + (j - 1) + (kf - 1)] = 0
+ end
+ local kfx = factor_index[req_f] - 1
+ if kfx > 0 then X_data[idx0 + (j - 1) + (kfx - 1)] = 1 end
+ end
+ end
+ end
+
+ return X, coeff_name
+end
+
+local function lm_model(t, expr)
+ local code = lm_prepare(t, expr)
+ local f_code = load(code)()
+ local Xt, inf = f_code(t)
+ return lm_main(Xt, t, inf)
+end
+
+local function lm(t, expr)
+ local a, b = string.match(expr, "%s*([%S]+)%s*~(.+)")
+ assert(a, "invalid lm expression")
+ local n, m = t:dim()
+ local jy = t:col_index(a)
+ assert(jy, "invalid variable specification in lm expression")
+ local sqrt = math.sqrt
+ local y = matrix.new(n, 1, |i| t:get(i, jy))
+ local X, name = lm_model(t, b)
+ local c, chisq, cov = num.linfit(X, y)
+ local coeff = gdt.new(#c, 3)
+ coeff:set_header(1, "name")
+ coeff:set_header(2, "value")
+ coeff:set_header(3, "stddev")
+ for i = 1, #c do
+ coeff:set(i, 1, name[i])
+ coeff:set(i, 2, c[i])
+ coeff:set(i, 3, sqrt(cov:get(i,i)))
+ end
+ return {coeff = coeff, c = c, chisq = chisq, cov = cov, X = X}
+end
+
+gdt.lm = lm
diff --git a/gdt-plot.lua b/gdt-plot.lua
index 55ad3f59..68aec567 100644
--- a/gdt-plot.lua
+++ b/gdt-plot.lua
@@ -388,22 +388,7 @@ local function set_elements(X, P, i, ...)
end
end
-local function gdt_table_linfit(t, f, jy)
- local N, M = t:dim()
- local row = t:cursor()
- local P = count_args(f(row))
-
- local X, Y = matrix.alloc(N, P), matrix.alloc(N, 1)
- for i, row in t:rows() do
- set_elements(X, P, i, f(row, i))
- Y:set(i, 1, t:get(i, jy))
- end
-
- return num.linfit(X, Y)
-end
-
gdt.barplot = gdt_table_barplot
gdt.plot = gdt_table_lineplot
gdt.xyplot = gdt_table_xyplot
gdt.reduce = gdt_table_reduce
-gdt.lm = gdt_table_linfit
diff --git a/gdt.lua b/gdt.lua
index a53198fd..d6c769b8 100644
--- a/gdt.lua
+++ b/gdt.lua
@@ -11,6 +11,10 @@ local gdt_table_cursor = ffi.typeof("gdt_table_cursor")
local TAG_STRING = tonumber(cgdt.TAG_STRING)
local TAG_NUMBER = tonumber(cgdt.TAG_NUMBER)
+local function element_is_number(e)
+ return (e.word.hi <= TAG_NUMBER)
+end
+
local function gdt_element(t, e)
local val
if e.word.hi <= TAG_NUMBER then
@@ -30,6 +34,11 @@ local function gdt_table_get(t, i, j)
return gdt_element(t, e)
end
+local function gdt_table_get_number_unsafe(t, i, j)
+ local e = cgdt.gdt_table_get(t, i - 1, j - 1)
+ if e.word.hi <= TAG_NUMBER then return e.number end
+end
+
local function gdt_table_set(t, i, j, val)
assert(i > 0 and i <= t.size1, 'invalid row index')
assert(j > 0 and j <= t.size2, 'invalid column index')
@@ -280,6 +289,8 @@ gdt = {
get = gdt_table_get,
set = gdt_table_set,
filter = gdt_table_filter,
+
+ get_number_unsafe = gdt_table_get_number_unsafe,
}
return gdt
diff --git a/lm-helpers.lua b/lm-helpers.lua
new file mode 100644
index 00000000..48a5685e
--- /dev/null
+++ b/lm-helpers.lua
@@ -0,0 +1,37 @@
+local LM = {}
+
+local factor_mt
+
+local function mul_factor(a, b)
+ local c = {value= a.value .. ":" .. b.value}
+ return setmetatable(c, factor_mt)
+end
+
+factor_mt = {
+ __mul = mul_factor,
+}
+
+function LM.factor(name)
+ local t = {name= name, value = ""}
+ return setmetatable(t, factor_mt)
+end
+
+function LM.eval_test(...)
+ local inf = {np = select('#', ...)}
+ inf.class = {}
+ for k= 1, inf.np do
+ local v = select(k, ...)
+ inf.class[k] = (type(v) == 'number' and 1 or 0)
+ end
+ return inf
+end
+
+function LM.eval_func(inf, pt, i, ...)
+ for k = 1, inf.np do
+ local x = select(k, ...)
+ local value = (inf.class[k] == 1 and x or x.value)
+ gdt.set(pt, i, k, value)
+ end
+end
+
+return LM
generated by cgit v1.2.3 (git 2.39.1) at 2025年09月21日 05:10:22 +0000

AltStyle によって変換されたページ (->オリジナル) /