-rw-r--r-- | gdt-lm.lua | 28 | ||||
-rw-r--r-- | lm-expr.lua | 26 | ||||
-rw-r--r-- | lm-helpers.lua | 28 |
diff --git a/gdt-lm.lua b/gdt-lm.lua index 307f7fd5..9e3d74a4 100644 --- a/gdt-lm.lua +++ b/gdt-lm.lua @@ -16,6 +16,28 @@ local function find_column_type(t, j) return 1 end +local function lm_expr_names(t, expr) + local n, m = t:dim() + + local code_lines = {} + local code = function(line) code_lines[#code_lines+1] = line end + + code([[local _LM = require 'lm-helpers']]) + code([[local enum = function(x) return x end]]) + local var_name = t:headers() + for j = 1, m do + local name = var_name[j] + code(format([[local %s = _LM.var_name(%q)]], name, name)) + end + code(format([[local expr_names = _LM.find_names(%s)]], expr)) + code([[return expr_names]]) + + local code_str = concat(code_lines, '\n') + local f = load(code_str) + local expr_names = f() + return expr_names +end + local function lm_prepare(t, expr) local n, m = t:dim() local column_class = {} @@ -72,6 +94,7 @@ local function lm_main(Xt, t, inf) local coeff_name = {} inf.factors, inf.factor_index = {}, {} for k = 1, inf.np do + local factor_name = inf.factor_names[k] index[k] = curr_index if inf.class[k] == 0 then local factors, factor_index = {}, {} @@ -80,14 +103,14 @@ local function lm_main(Xt, t, inf) local ui = add_unique(factors, str) if ui > 0 then factor_index[str] = ui - if ui > 1 then coeff_name[curr_index + (ui - 2)] = str end + if ui > 1 then coeff_name[curr_index + (ui - 2)] = factor_name .. "/" .. str end end end inf.factors[k] = factors inf.factor_index[k] = factor_index curr_index = curr_index + (#factors - 1) else - coeff_name[curr_index] = string.char(string.byte('a') + curr_index - 1) + coeff_name[curr_index] = factor_name curr_index = curr_index + 1 end end @@ -123,6 +146,7 @@ local function lm_model(t, expr) local code = lm_prepare(t, expr) local f_code = load(code)() local Xt, inf = f_code(t) + inf.factor_names = lm_expr_names(t, expr) return lm_main(Xt, t, inf) end diff --git a/lm-expr.lua b/lm-expr.lua new file mode 100644 index 00000000..3c59d7c3 --- /dev/null +++ b/lm-expr.lua @@ -0,0 +1,26 @@ +local format = string.format + +local function as_char(v, prio) + if type(v) == 'table' and v.name then + return (v.prio < prio and format("(%s)", v.name) or v.name) + end + return tostring(v) +end + +local var_name + +local var_name_mt = { + __add = function(a, b) return var_name(format("%s + %s", as_char(a, 0), as_char(b, 0)), 0) end, + __mul = function(a, b) return var_name(format("%s * %s", as_char(a, 2), as_char(b, 2)), 2) end, + __sub = function(a, b) return var_name(format("%s - %s", as_char(a, 0), as_char(b, 0)), 0) end, + __div = function(a, b) return var_name(format("%s / %s", as_char(a, 2), as_char(b, 2)), 2) end, + __pow = function(a, b) return var_name(format("%s^%s", as_char(a, 10), as_char(b, 10)), 10) end, + __unm = function(a) return var_name(format("-%s", as_char(a, 1)), 1) end, +} + +var_name = function (name, prio) + local t = {prio= prio or 10, name= name} + return setmetatable(t, var_name_mt) +end + +return var_name diff --git a/lm-helpers.lua b/lm-helpers.lua index 48a5685e..098971fa 100644 --- a/lm-helpers.lua +++ b/lm-helpers.lua @@ -1,5 +1,7 @@ local LM = {} +local var_name = require 'lm-expr' + local factor_mt local function mul_factor(a, b) @@ -34,4 +36,30 @@ function LM.eval_func(inf, pt, i, ...) end end +LM.var_name = var_name + +local function expr_to_name(expr) + if type(expr) == 'table' then + return expr.name + else + local base = '(average)' + local minus = expr < 0 and '- ' or '' + if expr == 1 or expr == -1 then + return string.format("%s%s", minus, base) + else + return string.format("%s%s / %g", minus, base, math.abs(expr)) + end + end +end + +function LM.find_names(...) + local n = select("#", ...) + local names = {} + for k = 1, n do + local expr = select(k, ...) + names[k] = expr_to_name(expr) + end + return names +end + return LM |