gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat
-rw-r--r--gdt-lm.lua 28
-rw-r--r--lm-expr.lua 26
-rw-r--r--lm-helpers.lua 28
3 files changed, 80 insertions, 2 deletions
diff --git a/gdt-lm.lua b/gdt-lm.lua
index 307f7fd5..9e3d74a4 100644
--- a/gdt-lm.lua
+++ b/gdt-lm.lua
@@ -16,6 +16,28 @@ local function find_column_type(t, j)
return 1
end
+local function lm_expr_names(t, expr)
+ local n, m = t:dim()
+
+ local code_lines = {}
+ local code = function(line) code_lines[#code_lines+1] = line end
+
+ code([[local _LM = require 'lm-helpers']])
+ code([[local enum = function(x) return x end]])
+ local var_name = t:headers()
+ for j = 1, m do
+ local name = var_name[j]
+ code(format([[local %s = _LM.var_name(%q)]], name, name))
+ end
+ code(format([[local expr_names = _LM.find_names(%s)]], expr))
+ code([[return expr_names]])
+
+ local code_str = concat(code_lines, '\n')
+ local f = load(code_str)
+ local expr_names = f()
+ return expr_names
+end
+
local function lm_prepare(t, expr)
local n, m = t:dim()
local column_class = {}
@@ -72,6 +94,7 @@ local function lm_main(Xt, t, inf)
local coeff_name = {}
inf.factors, inf.factor_index = {}, {}
for k = 1, inf.np do
+ local factor_name = inf.factor_names[k]
index[k] = curr_index
if inf.class[k] == 0 then
local factors, factor_index = {}, {}
@@ -80,14 +103,14 @@ local function lm_main(Xt, t, inf)
local ui = add_unique(factors, str)
if ui > 0 then
factor_index[str] = ui
- if ui > 1 then coeff_name[curr_index + (ui - 2)] = str end
+ if ui > 1 then coeff_name[curr_index + (ui - 2)] = factor_name .. "/" .. str end
end
end
inf.factors[k] = factors
inf.factor_index[k] = factor_index
curr_index = curr_index + (#factors - 1)
else
- coeff_name[curr_index] = string.char(string.byte('a') + curr_index - 1)
+ coeff_name[curr_index] = factor_name
curr_index = curr_index + 1
end
end
@@ -123,6 +146,7 @@ local function lm_model(t, expr)
local code = lm_prepare(t, expr)
local f_code = load(code)()
local Xt, inf = f_code(t)
+ inf.factor_names = lm_expr_names(t, expr)
return lm_main(Xt, t, inf)
end
diff --git a/lm-expr.lua b/lm-expr.lua
new file mode 100644
index 00000000..3c59d7c3
--- /dev/null
+++ b/lm-expr.lua
@@ -0,0 +1,26 @@
+local format = string.format
+
+local function as_char(v, prio)
+ if type(v) == 'table' and v.name then
+ return (v.prio < prio and format("(%s)", v.name) or v.name)
+ end
+ return tostring(v)
+end
+
+local var_name
+
+local var_name_mt = {
+ __add = function(a, b) return var_name(format("%s + %s", as_char(a, 0), as_char(b, 0)), 0) end,
+ __mul = function(a, b) return var_name(format("%s * %s", as_char(a, 2), as_char(b, 2)), 2) end,
+ __sub = function(a, b) return var_name(format("%s - %s", as_char(a, 0), as_char(b, 0)), 0) end,
+ __div = function(a, b) return var_name(format("%s / %s", as_char(a, 2), as_char(b, 2)), 2) end,
+ __pow = function(a, b) return var_name(format("%s^%s", as_char(a, 10), as_char(b, 10)), 10) end,
+ __unm = function(a) return var_name(format("-%s", as_char(a, 1)), 1) end,
+}
+
+var_name = function (name, prio)
+ local t = {prio= prio or 10, name= name}
+ return setmetatable(t, var_name_mt)
+end
+
+return var_name
diff --git a/lm-helpers.lua b/lm-helpers.lua
index 48a5685e..098971fa 100644
--- a/lm-helpers.lua
+++ b/lm-helpers.lua
@@ -1,5 +1,7 @@
local LM = {}
+local var_name = require 'lm-expr'
+
local factor_mt
local function mul_factor(a, b)
@@ -34,4 +36,30 @@ function LM.eval_func(inf, pt, i, ...)
end
end
+LM.var_name = var_name
+
+local function expr_to_name(expr)
+ if type(expr) == 'table' then
+ return expr.name
+ else
+ local base = '(average)'
+ local minus = expr < 0 and '- ' or ''
+ if expr == 1 or expr == -1 then
+ return string.format("%s%s", minus, base)
+ else
+ return string.format("%s%s / %g", minus, base, math.abs(expr))
+ end
+ end
+end
+
+function LM.find_names(...)
+ local n = select("#", ...)
+ local names = {}
+ for k = 1, n do
+ local expr = select(k, ...)
+ names[k] = expr_to_name(expr)
+ end
+ return names
+end
+
return LM
generated by cgit v1.2.3 (git 2.39.1) at 2025年09月12日 18:12:42 +0000

AltStyle によって変換されたページ (->オリジナル) /