-rw-r--r-- | gdt-lm.lua | 28 |
diff --git a/gdt-lm.lua b/gdt-lm.lua index 307f7fd5..9e3d74a4 100644 --- a/gdt-lm.lua +++ b/gdt-lm.lua @@ -16,6 +16,28 @@ local function find_column_type(t, j) return 1 end +local function lm_expr_names(t, expr) + local n, m = t:dim() + + local code_lines = {} + local code = function(line) code_lines[#code_lines+1] = line end + + code([[local _LM = require 'lm-helpers']]) + code([[local enum = function(x) return x end]]) + local var_name = t:headers() + for j = 1, m do + local name = var_name[j] + code(format([[local %s = _LM.var_name(%q)]], name, name)) + end + code(format([[local expr_names = _LM.find_names(%s)]], expr)) + code([[return expr_names]]) + + local code_str = concat(code_lines, '\n') + local f = load(code_str) + local expr_names = f() + return expr_names +end + local function lm_prepare(t, expr) local n, m = t:dim() local column_class = {} @@ -72,6 +94,7 @@ local function lm_main(Xt, t, inf) local coeff_name = {} inf.factors, inf.factor_index = {}, {} for k = 1, inf.np do + local factor_name = inf.factor_names[k] index[k] = curr_index if inf.class[k] == 0 then local factors, factor_index = {}, {} @@ -80,14 +103,14 @@ local function lm_main(Xt, t, inf) local ui = add_unique(factors, str) if ui > 0 then factor_index[str] = ui - if ui > 1 then coeff_name[curr_index + (ui - 2)] = str end + if ui > 1 then coeff_name[curr_index + (ui - 2)] = factor_name .. "/" .. str end end end inf.factors[k] = factors inf.factor_index[k] = factor_index curr_index = curr_index + (#factors - 1) else - coeff_name[curr_index] = string.char(string.byte('a') + curr_index - 1) + coeff_name[curr_index] = factor_name curr_index = curr_index + 1 end end @@ -123,6 +146,7 @@ local function lm_model(t, expr) local code = lm_prepare(t, expr) local f_code = load(code)() local Xt, inf = f_code(t) + inf.factor_names = lm_expr_names(t, expr) return lm_main(Xt, t, inf) end |