gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
path: root/gdt-lm.lua
diff options
context:
space:
mode:
Diffstat (limited to 'gdt-lm.lua')
-rw-r--r--gdt-lm.lua 28
1 files changed, 26 insertions, 2 deletions
diff --git a/gdt-lm.lua b/gdt-lm.lua
index 307f7fd5..9e3d74a4 100644
--- a/gdt-lm.lua
+++ b/gdt-lm.lua
@@ -16,6 +16,28 @@ local function find_column_type(t, j)
return 1
end
+local function lm_expr_names(t, expr)
+ local n, m = t:dim()
+
+ local code_lines = {}
+ local code = function(line) code_lines[#code_lines+1] = line end
+
+ code([[local _LM = require 'lm-helpers']])
+ code([[local enum = function(x) return x end]])
+ local var_name = t:headers()
+ for j = 1, m do
+ local name = var_name[j]
+ code(format([[local %s = _LM.var_name(%q)]], name, name))
+ end
+ code(format([[local expr_names = _LM.find_names(%s)]], expr))
+ code([[return expr_names]])
+
+ local code_str = concat(code_lines, '\n')
+ local f = load(code_str)
+ local expr_names = f()
+ return expr_names
+end
+
local function lm_prepare(t, expr)
local n, m = t:dim()
local column_class = {}
@@ -72,6 +94,7 @@ local function lm_main(Xt, t, inf)
local coeff_name = {}
inf.factors, inf.factor_index = {}, {}
for k = 1, inf.np do
+ local factor_name = inf.factor_names[k]
index[k] = curr_index
if inf.class[k] == 0 then
local factors, factor_index = {}, {}
@@ -80,14 +103,14 @@ local function lm_main(Xt, t, inf)
local ui = add_unique(factors, str)
if ui > 0 then
factor_index[str] = ui
- if ui > 1 then coeff_name[curr_index + (ui - 2)] = str end
+ if ui > 1 then coeff_name[curr_index + (ui - 2)] = factor_name .. "/" .. str end
end
end
inf.factors[k] = factors
inf.factor_index[k] = factor_index
curr_index = curr_index + (#factors - 1)
else
- coeff_name[curr_index] = string.char(string.byte('a') + curr_index - 1)
+ coeff_name[curr_index] = factor_name
curr_index = curr_index + 1
end
end
@@ -123,6 +146,7 @@ local function lm_model(t, expr)
local code = lm_prepare(t, expr)
local f_code = load(code)()
local Xt, inf = f_code(t)
+ inf.factor_names = lm_expr_names(t, expr)
return lm_main(Xt, t, inf)
end
generated by cgit v1.2.3 (git 2.39.1) at 2025年09月17日 08:50:22 +0000

AltStyle によって変換されたページ (->オリジナル) /