WIP: fix problems with gdt.lm implementation - gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
diff options
context:
space:
mode:
author Francesco Abbate <francesco.bbt@gmail.com> 2013-05-12 23:17:57 +0200
committer Francesco Abbate <francesco.bbt@gmail.com> 2013-05-12 23:17:57 +0200
commit 338fd9e668b582105def249a3be6a377d0160272 (patch)
tree 65f2c2324b85ecb7722b9a517d04764a3774632f
parent 204fa404431f9c619f0dfc5ae3efc463749fa8ed (diff)
download gsl-shell-338fd9e668b582105def249a3be6a377d0160272.tar.gz
WIP: fix problems with gdt.lm implementation
Use variadic functions for stateless expr evaluation. Do not store levels for numeric variables. Separate computation of levels and names from model matrix computations.
Diffstat
-rw-r--r-- expr-print.lua 12
-rw-r--r-- gdt-expr.lua 224
-rw-r--r-- gdt-factors.lua 35
-rw-r--r-- gdt-interp.lua 18
-rw-r--r-- gdt-lm.lua 47
5 files changed, 181 insertions, 155 deletions
diff --git a/expr-print.lua b/expr-print.lua
index 03b1e185..ebbc6518 100644
--- a/expr-print.lua
+++ b/expr-print.lua
@@ -67,13 +67,13 @@ local function eval_operator(op, a, b)
else error('unkown operation: ' .. op) end
end
-local function eval(expr, scope, scope_state)
+local function eval(expr, scope, ...)
if type(expr) == 'number' then
return expr
elseif type(expr) == 'string' then
- return scope.ident(expr, scope_state)
+ return scope.ident(expr, ...)
elseif expr.func then
- local arg_value = eval(expr.arg, scope, scope_state)
+ local arg_value = eval(expr.arg, scope, ...)
if arg_value then
local f = scope.func(expr)
if not f then error('unknown function: ' .. expr.func) end
@@ -81,11 +81,11 @@ local function eval(expr, scope, scope_state)
end
else
if #expr == 1 then
- local v = eval(expr[1], scope, scope_state)
+ local v = eval(expr[1], scope, ...)
if v then return -v end
else
- local a = eval(expr[1], scope, scope_state)
- local b = eval(expr[2], scope, scope_state)
+ local a = eval(expr[1], scope, ...)
+ local b = eval(expr[2], scope, ...)
if a and b then
return eval_operator(expr.operator, a, b)
end
diff --git a/gdt-expr.lua b/gdt-expr.lua
index 17e77799..f36b36a8 100644
--- a/gdt-expr.lua
+++ b/gdt-expr.lua
@@ -25,15 +25,17 @@ local function expr_is_unit(e)
return e == 1
end
-local function add_expr_refs(expr, refs)
+local function add_expr_refs(expr, refs, factor_refs)
expr_print.references(expr.scalar, refs)
if expr.factor then
- for k, f in ipairs(expr.factor) do refs[f] = true end
+ for k, f in ipairs(expr.factor) do
+ refs[f] = true
+ factor_refs[f] = true
+ end
end
end
-local function table_var_resolve(expr, st)
- local t, i = st[1], st[2]
+local function table_var_resolve(expr, t, i)
return t:get(i, expr)
end
@@ -41,32 +43,28 @@ local function math_func_resolve(expr)
return math[expr.func]
end
-local eval_table_context = {
+local table_context = {
ident = table_var_resolve,
func = math_func_resolve,
}
local function map_missing_rows(t, expr_list, y_expr_scalar, conditions)
- local refs, levels = {}, {}
+ local refs, factor_refs, levels = {}, {}, {}
for k, expr in ipairs(expr_list) do
- add_expr_refs(expr, refs)
+ add_expr_refs(expr, refs, factor_refs)
end
if y_expr_scalar then
expr_print.references(y_expr_scalar, refs)
end
- for factor_name in pairs(refs) do
+ for factor_name in pairs(factor_refs) do
levels[factor_name] = {}
end
local N = #t
local index_map = {}
local map_i, map_len = 1, 0
- local eval_state = {t, 0}
for i = 1, N do
- -- set the row to be evaluated to "i"
- eval_state[2] = i
-
local row_undef = false
for col_name in pairs(refs) do
row_undef = row_undef or (not t:get(i, col_name))
@@ -74,12 +72,12 @@ local function map_missing_rows(t, expr_list, y_expr_scalar, conditions)
if not row_undef then
for _, cond in ipairs(conditions) do
- local cx = expr_print.eval(cond, eval_table_context, eval_state)
+ local cx = expr_print.eval(cond, table_context, t, i)
row_undef = row_undef or (cx == 0)
end
end
if not row_undef then
- for col_name in pairs(refs) do
+ for col_name in pairs(factor_refs) do
list_add_unique(levels[col_name], t:get(i, col_name))
end
end
@@ -133,9 +131,99 @@ local function index_map_iter(index_map, ils)
end
local function annotate_mult(expr_list, levels)
+ local n = 0
+ for _, expr in ipairs(expr_list) do
+ local mult = expr.factor and level_number(expr.factor, levels) or 1
+ expr.mult = mult
+ n = n + mult
+ end
+ return n
+end
+
+local function pred_coeff_name(pred)
+ local ls = {}
+ for k, name, level in iter_by_two, pred, -1 do
+ ls[#ls+1] = string.format("%s:%s", name, level)
+ end
+ return table.concat(ls, ' / ')
+end
+
+local function eval_predicates(factors, levels)
+ local NF = #factors
+ local factor_levels = {}
+ local counter = {}
+ for p = 1, NF do
+ factor_levels[p] = levels[factors[p]]
+ counter[p] = 0
+ end
+
+ local pred_list = {}
+ -- the following code cycles through all the factors/levels
+ -- combinations for the given factor set (subset of ls at
+ -- index "k")
+ counter[NF + 1] = 0
+ while counter[NF + 1] == 0 do
+ local pred = {}
+ for p = 1, NF do
+ pred[#pred + 1] = factors[p]
+ pred[#pred + 1] = factor_levels[p][counter[p] + 2]
+ end
+ pred_list[#pred_list+1] = pred
+
+ for p = 1, NF + 1 do
+ local cn = counter[p] + 1
+ if p > NF or cn < #factor_levels[p] - 1 then
+ counter[p] = cn
+ break
+ else
+ counter[p] = 0
+ end
+ end
+ end
+
+ return pred_list
+end
+
+local function predlist_add_coeff_names(names, pred_list, expr)
+ local is_unit = expr_is_unit(expr.scalar)
+ local scalar_repr = expr_print.expr(expr.scalar)
+ for _, pred in ipairs(pred_list) do
+ local cname = pred_coeff_name(pred)
+ names[#names+1] = (is_unit and cname or scalar_repr .. ' * ' .. cname)
+ end
+end
+
+local function eval_pred_list(t, pred, i)
+ local match = true
+ for k, name, level in iter_by_two, pred, -1 do
+ match = match and (t:get(i, name) == level)
+ end
+ return (match and 1 or 0)
+end
+
+local function eval_coeff_names(expr_list, levels)
+ local names = {}
for _, expr in ipairs(expr_list) do
- expr.mult = expr.factor and level_number(expr.factor, levels) or 1
+ if expr.factor then
+ local pred_list = eval_predicates(expr.factor, levels)
+ print('EXPR', expr.factor, 'predicate', pred_list)
+ predlist_add_coeff_names(names, pred_list, expr)
+ else
+ names[#names+1] = expr_print.expr(expr.scalar)
+ end
end
+ return names
+end
+
+function gdt_expr.prepare_model(t, expr_list, y_expr, conditions)
+ local index_map, levels = map_missing_rows(t, expr_list, y_expr, conditions or {})
+ local model_dim = annotate_mult(expr_list, levels)
+ local info = {
+ names = eval_coeff_names(expr_list, levels),
+ levels = levels,
+ dim = model_dim,
+ }
+ return info, index_map
end
-- return the model matrix for the given table and expression list.
@@ -146,126 +234,54 @@ end
-- the function returns X, Y and index_map, respectively: X model matrix, Y column matrix
-- and index mapping. This latter given the correspondance
-- (table's row index) => (matrix' row index)
-function gdt_expr.eval_matrix(t, expr_list, y_expr, conditions, annotate)
- -- the "index_map" creates a mapping between matrix indexes and table
- -- indexes to take into account missing data in some rows.
- local index_map, levels = map_missing_rows(t, expr_list, y_expr, conditions)
-
- if annotate then
- annotate_mult(expr_list, levels)
+function gdt_expr.eval_matrix(t, info, expr_list, y_expr, index_map)
+ if not index_map then
+ index_map = map_missing_rows(t, expr_list, y_expr, {})
end
- local N = #t
- local NE = #expr_list
- local XM = 0
- for k, e in ipairs(expr_list) do XM = XM + e.mult end
+ local NE, XM = #expr_list, info.dim
- local eval_state = {t, 0}
-
- local function set_scalar_column(X, expr_scalar, j, names)
- names[#names+1] = expr_print.expr(expr_scalar)
+ local function set_scalar_column(X, expr_scalar, j)
for _, i, x_i in index_map_iter, index_map, {-1, 0, 0} do
- eval_state[2] = i
- local xs = expr_print.eval(expr_scalar, eval_table_context, eval_state)
+ local xs = expr_print.eval(expr_scalar, table_context, t, i)
assert(xs, string.format('missing value in data table at row: %d', i))
X:set(x_i, j, xs)
end
end
- local function eval_pred_list(pred, i)
- local match = true
- for k, name, level in iter_by_two, pred, -1 do
- match = match and (t:get(i, name) == level)
- end
- return (match and 1 or 0)
- end
-
- local function pred_coeff_name(pred)
- local ls = {}
- for k, name, level in iter_by_two, pred, -1 do
- ls[#ls+1] = string.format("%s:%s", name, level)
- end
- return table.concat(ls, ' / ')
- end
-
- local function set_contrasts_matrix(X, expr, j, names)
- local factors = expr.factor
- local NF = #factors
- local factor_levels = {}
- local counter = {}
- for p = 1, NF do
- factor_levels[p] = levels[factors[p]]
- counter[p] = 0
- end
-
- local pred_list = {}
- -- the following code cycles through all the factors/levels
- -- combinations for the given factor set (subset of ls at
- -- index "k")
- counter[NF + 1] = 0
- while counter[NF + 1] == 0 do
- local pred = {}
- for p = 1, NF do
- pred[#pred + 1] = factors[p]
- pred[#pred + 1] = factor_levels[p][counter[p] + 2]
- end
- local coeff_name = pred_coeff_name(pred)
- if expr_is_unit(expr.scalar) then
- names[#names+1] = coeff_name
- else
- local scalar_repr = expr_print.expr(expr.scalar)
- names[#names+1] = scalar_repr .. ' * ' .. coeff_name
- end
- -- add coefficient "pred"
- pred_list[#pred_list + 1] = pred
-
- for p = 1, NF + 1 do
- local cn = counter[p] + 1
- if p > NF or cn < #factor_levels[p] - 1 then
- counter[p] = cn
- break
- else
- counter[p] = 0
- end
- end
- end
-
+ local function set_contrasts_matrix(X, expr, j)
+ local pred_list = eval_predicates(expr.factor, info.levels)
for _, i, x_i in index_map_iter, index_map, {-1, 0, 0} do
- eval_state[2] = i
- local xs = expr_print.eval(expr.scalar, eval_table_context, eval_state)
+ local xs = expr_print.eval(expr.scalar, table_context, t, i)
assert(xs, string.format('missing value in data table at row: %d', i))
for k, pred in ipairs(pred_list) do
- local fs = eval_pred_list(pred, i)
+ local fs = eval_pred_list(t, pred, i)
X:set(x_i, j + (k - 1), xs * fs)
end
end
end
- local names = {}
-
-- here NR and XM gives the dimension of the model matrix
local NR = index_map_count(index_map)
+ if NR == 0 then error('invalid data table, no valid rows found') end
+
local X = matrix.alloc(NR, XM)
local Y = y_expr and matrix.alloc(NR, 1)
local col_index = 1
- for k = 1, NE do
- local expr = expr_list[k]
+ for _, expr in ipairs(expr_list) do
if expr.factor then
- set_contrasts_matrix(X, expr, col_index, names)
+ set_contrasts_matrix(X, expr, col_index)
else
- set_scalar_column(X, expr.scalar, col_index, names)
+ set_scalar_column(X, expr.scalar, col_index)
end
col_index = col_index + expr.mult
end
if y_expr then
- set_scalar_column(Y, y_expr, 1, names)
+ set_scalar_column(Y, y_expr, 1)
end
- local info = {levels= levels}
- if annotate then info.names = names end
-
- return X, Y, info, index_map
+ return X, Y
end
return gdt_expr
diff --git a/gdt-factors.lua b/gdt-factors.lua
new file mode 100644
index 00000000..a999f79c
--- /dev/null
+++ b/gdt-factors.lua
@@ -0,0 +1,35 @@
+local AST = require 'expr-actions'
+
+local function expr_find_factors_rec(t, expr, factors)
+ if AST.is_number(expr) then
+ return expr
+ elseif AST.is_variable(expr) then
+ local _, var_name, force_enum = AST.is_variable(expr)
+ if force_enum or t:col_type(var_name) == 'factor' then
+ factors[#factors+1] = var_name
+ return 1
+ else
+ return expr
+ end
+ elseif expr.operator == '*' then
+ local a, b = expr[1], expr[2]
+ local sa1 = expr_find_factors_rec(t, a, factors)
+ local sa2 = expr_find_factors_rec(t, b, factors)
+ return AST.infix('*', sa1, sa2)
+ else
+ return expr
+ end
+end
+
+function compute_factors(t, expr_list)
+ local els = {}
+ for i, e in ipairs(expr_list) do
+ local et, factors = {}, {}
+ et.scalar = expr_find_factors_rec(t, e, factors)
+ if #factors > 0 then et.factor = factors end
+ els[i] = et
+ end
+ return els
+end
+
+return {compute= compute_factors} \ No newline at end of file
diff --git a/gdt-interp.lua b/gdt-interp.lua
index 52c6cbb3..9a4a2721 100644
--- a/gdt-interp.lua
+++ b/gdt-interp.lua
@@ -1,5 +1,8 @@
-local cgsl = require 'gsl'
+local expr_parse = require 'expr-parse'
local gdt_expr = require 'gdt-expr'
+local gdt_factors = require 'gdt-factors'
+local AST = require 'expr-actions'
+local cgsl = require 'gsl'
local interp_lookup = {
linear = cgsl.gsl_interp_linear,
@@ -11,20 +14,21 @@ local interp_lookup = {
}
function gdt.interp(t, expr_formula, interp_type)
- local schema = gdt_expr.parse_schema(t, expr_formula)
+ local schema = expr_parse.schema(expr_formula, AST, false)
+ local x_exprs = gdt_factors.compute(t, schema.x)
local T = interp_lookup[interp_type or "cspline"]
if T == nil then error("invalid interpolator type") end
- local info = gdt_expr.eval_mult(t, schema.x)
- local X, y = gdt_expr.eval_matrix(t, schema.x, info, schema.y.scalar)
+ local info, index_map = gdt_expr.prepare_model(t, x_exprs, schema.y.scalar)
+ local X, y = gdt_expr.eval_matrix(t, info, x_exprs, schema.y, index_map)
+
local n = #y
local interp = ffi.gc(cgsl.gsl_interp_alloc(T, n), cgsl.gsl_interp_free)
local accel = ffi.gc(cgsl.gsl_interp_accel_alloc(), cgsl.gsl_interp_accel_free)
- local x_data, y_data = X.data, y.data
- cgsl.gsl_interp_init(interp, x_data, y_data, n)
+ cgsl.gsl_interp_init(interp, X.data, y.data, n)
local function eval(x_req)
- return cgsl.gsl_interp_eval(interp, x_data, y_data, x_req, acc)
+ return cgsl.gsl_interp_eval(interp, X.data, y.data, x_req, acc)
end
return eval
end
diff --git a/gdt-lm.lua b/gdt-lm.lua
index c5393eb9..38805f01 100644
--- a/gdt-lm.lua
+++ b/gdt-lm.lua
@@ -1,6 +1,7 @@
local expr_parse = require 'expr-parse'
local expr_print = require 'expr-print'
local gdt_expr = require 'gdt-expr'
+local gdt_factors = require 'gdt-factors'
local check = require 'check'
local mon = require 'monomial'
local AST = require 'expr-actions'
@@ -133,11 +134,11 @@ end
local FIT = {}
function FIT.model(fit, t_alt)
- return gdt_expr.eval_matrix(t_alt, fit.x_exprs, fit.info)
+ return gdt_expr.eval_matrix(t_alt, fit.info, fit.x_exprs)
end
function FIT.predict(fit, t_alt)
- local X = gdt_expr.eval_matrix(t_alt, fit.x_exprs, fit.info)
+ local X = gdt_expr.eval_matrix(t_alt, fit.info, fit.x_exprs)
return X * fit.c
end
@@ -160,7 +161,7 @@ function FIT.eval(fit, tn)
eval_table:set(1, k, tn[name])
end
local coeff = fit.c
- local sX = gdt_expr.eval_matrix(eval_table, fit.x_exprs, fit.info)
+ local sX = gdt_expr.eval_matrix(eval_table, fit.info, fit.x_exprs)
local sy = 0
for k = 0, #coeff - 1 do
sy = sy + sX.data[k] * coeff.data[k]
@@ -168,38 +169,6 @@ function FIT.eval(fit, tn)
return sy
end
-local function expr_find_factors_rec(t, expr, factors)
- if AST.is_number(expr) then
- return expr
- elseif AST.is_variable(expr) then
- local _, var_name, force_enum = AST.is_variable(expr)
- if force_enum or t:col_type(var_name) == 'factor' then
- factors[#factors+1] = var_name
- return 1
- else
- return expr
- end
- elseif expr.operator == '*' then
- local a, b = expr[1], expr[2]
- local sa1 = expr_find_factors_rec(t, a, factors)
- local sa2 = expr_find_factors_rec(t, b, factors)
- return AST.infix('*', sa1, sa2)
- else
- return expr
- end
-end
-
-function gdt_expr.extract_factors(t, expr_list)
- local els = {}
- for i, e in ipairs(expr_list) do
- local et, factors = {}, {}
- et.scalar = expr_find_factors_rec(t, e, factors)
- if #factors > 0 then et.factor = factors end
- els[i] = et
- end
- return els
-end
-
local function lm(t, model_formula, options)
local schema = expr_parse.schema(model_formula, AST, false)
@@ -212,15 +181,17 @@ local function lm(t, model_formula, options)
print("EXPANDED")
for _, e in ipairs(xs) do print(e) end
- local x_exprs = gdt_expr.extract_factors(t, xs)
+ local x_exprs = gdt_factors.compute(t, xs)
local y_expr = schema.y
--- local info = gdt_expr.eval_mult(t, x_exprs, y_expr)
+ local info, index_map = gdt_expr.prepare_model(t, x_exprs, y_expr, schema.conds)
+
+ print('info', info, index_map)
print('EXPRESSIONS')
for _, e in ipairs(x_exprs) do print(e) end
- local X, y, info, index_map = gdt_expr.eval_matrix(t, x_exprs, y_expr, schema.conds, true)
+ local X, y = gdt_expr.eval_matrix(t, info, x_exprs, y_expr, index_map)
local fit = compute_fit(X, y, info.names)
if options and options.predict then
generated by cgit v1.2.3 (git 2.39.1) at 2025-10-04 17:30:18 +0000

AltStyle によって変換されたページ (->オリジナル) /