author | Francesco Abbate <francesco.bbt@gmail.com> | 2013年05月12日 23:17:57 +0200 |
---|---|---|
committer | Francesco Abbate <francesco.bbt@gmail.com> | 2013年05月12日 23:17:57 +0200 |
commit | 338fd9e668b582105def249a3be6a377d0160272 (patch) | |
tree | 65f2c2324b85ecb7722b9a517d04764a3774632f | |
parent | 204fa404431f9c619f0dfc5ae3efc463749fa8ed (diff) | |
download | gsl-shell-338fd9e668b582105def249a3be6a377d0160272.tar.gz |
-rw-r--r-- | expr-print.lua | 12 | ||||
-rw-r--r-- | gdt-expr.lua | 224 | ||||
-rw-r--r-- | gdt-factors.lua | 35 | ||||
-rw-r--r-- | gdt-interp.lua | 18 | ||||
-rw-r--r-- | gdt-lm.lua | 47 |
diff --git a/expr-print.lua b/expr-print.lua index 03b1e185..ebbc6518 100644 --- a/expr-print.lua +++ b/expr-print.lua @@ -67,13 +67,13 @@ local function eval_operator(op, a, b) else error('unkown operation: ' .. op) end end -local function eval(expr, scope, scope_state) +local function eval(expr, scope, ...) if type(expr) == 'number' then return expr elseif type(expr) == 'string' then - return scope.ident(expr, scope_state) + return scope.ident(expr, ...) elseif expr.func then - local arg_value = eval(expr.arg, scope, scope_state) + local arg_value = eval(expr.arg, scope, ...) if arg_value then local f = scope.func(expr) if not f then error('unknown function: ' .. expr.func) end @@ -81,11 +81,11 @@ local function eval(expr, scope, scope_state) end else if #expr == 1 then - local v = eval(expr[1], scope, scope_state) + local v = eval(expr[1], scope, ...) if v then return -v end else - local a = eval(expr[1], scope, scope_state) - local b = eval(expr[2], scope, scope_state) + local a = eval(expr[1], scope, ...) + local b = eval(expr[2], scope, ...) if a and b then return eval_operator(expr.operator, a, b) end diff --git a/gdt-expr.lua b/gdt-expr.lua index 17e77799..f36b36a8 100644 --- a/gdt-expr.lua +++ b/gdt-expr.lua @@ -25,15 +25,17 @@ local function expr_is_unit(e) return e == 1 end -local function add_expr_refs(expr, refs) +local function add_expr_refs(expr, refs, factor_refs) expr_print.references(expr.scalar, refs) if expr.factor then - for k, f in ipairs(expr.factor) do refs[f] = true end + for k, f in ipairs(expr.factor) do + refs[f] = true + factor_refs[f] = true + end end end -local function table_var_resolve(expr, st) - local t, i = st[1], st[2] +local function table_var_resolve(expr, t, i) return t:get(i, expr) end @@ -41,32 +43,28 @@ local function math_func_resolve(expr) return math[expr.func] end -local eval_table_context = { +local table_context = { ident = table_var_resolve, func = math_func_resolve, } local function map_missing_rows(t, expr_list, y_expr_scalar, conditions) - local refs, levels = {}, {} + local refs, factor_refs, levels = {}, {}, {} for k, expr in ipairs(expr_list) do - add_expr_refs(expr, refs) + add_expr_refs(expr, refs, factor_refs) end if y_expr_scalar then expr_print.references(y_expr_scalar, refs) end - for factor_name in pairs(refs) do + for factor_name in pairs(factor_refs) do levels[factor_name] = {} end local N = #t local index_map = {} local map_i, map_len = 1, 0 - local eval_state = {t, 0} for i = 1, N do - -- set the row to be evaluated to "i" - eval_state[2] = i - local row_undef = false for col_name in pairs(refs) do row_undef = row_undef or (not t:get(i, col_name)) @@ -74,12 +72,12 @@ local function map_missing_rows(t, expr_list, y_expr_scalar, conditions) if not row_undef then for _, cond in ipairs(conditions) do - local cx = expr_print.eval(cond, eval_table_context, eval_state) + local cx = expr_print.eval(cond, table_context, t, i) row_undef = row_undef or (cx == 0) end end if not row_undef then - for col_name in pairs(refs) do + for col_name in pairs(factor_refs) do list_add_unique(levels[col_name], t:get(i, col_name)) end end @@ -133,9 +131,99 @@ local function index_map_iter(index_map, ils) end local function annotate_mult(expr_list, levels) + local n = 0 + for _, expr in ipairs(expr_list) do + local mult = expr.factor and level_number(expr.factor, levels) or 1 + expr.mult = mult + n = n + mult + end + return n +end + +local function pred_coeff_name(pred) + local ls = {} + for k, name, level in iter_by_two, pred, -1 do + ls[#ls+1] = string.format("%s:%s", name, level) + end + return table.concat(ls, ' / ') +end + +local function eval_predicates(factors, levels) + local NF = #factors + local factor_levels = {} + local counter = {} + for p = 1, NF do + factor_levels[p] = levels[factors[p]] + counter[p] = 0 + end + + local pred_list = {} + -- the following code cycles through all the factors/levels + -- combinations for the given factor set (subset of ls at + -- index "k") + counter[NF + 1] = 0 + while counter[NF + 1] == 0 do + local pred = {} + for p = 1, NF do + pred[#pred + 1] = factors[p] + pred[#pred + 1] = factor_levels[p][counter[p] + 2] + end + pred_list[#pred_list+1] = pred + + for p = 1, NF + 1 do + local cn = counter[p] + 1 + if p > NF or cn < #factor_levels[p] - 1 then + counter[p] = cn + break + else + counter[p] = 0 + end + end + end + + return pred_list +end + +local function predlist_add_coeff_names(names, pred_list, expr) + local is_unit = expr_is_unit(expr.scalar) + local scalar_repr = expr_print.expr(expr.scalar) + for _, pred in ipairs(pred_list) do + local cname = pred_coeff_name(pred) + names[#names+1] = (is_unit and cname or scalar_repr .. ' * ' .. cname) + end +end + +local function eval_pred_list(t, pred, i) + local match = true + for k, name, level in iter_by_two, pred, -1 do + match = match and (t:get(i, name) == level) + end + return (match and 1 or 0) +end + +local function eval_coeff_names(expr_list, levels) + local names = {} for _, expr in ipairs(expr_list) do - expr.mult = expr.factor and level_number(expr.factor, levels) or 1 + if expr.factor then + local pred_list = eval_predicates(expr.factor, levels) + print('EXPR', expr.factor, 'predicate', pred_list) + predlist_add_coeff_names(names, pred_list, expr) + else + names[#names+1] = expr_print.expr(expr.scalar) + end end + return names +end + +function gdt_expr.prepare_model(t, expr_list, y_expr, conditions) + local index_map, levels = map_missing_rows(t, expr_list, y_expr, conditions or {}) + local model_dim = annotate_mult(expr_list, levels) + local info = { + names = eval_coeff_names(expr_list, levels), + levels = levels, + dim = model_dim, + } + return info, index_map end -- return the model matrix for the given table and expression list. @@ -146,126 +234,54 @@ end -- the function returns X, Y and index_map, respectively: X model matrix, Y column matrix -- and index mapping. This latter given the correspondance -- (table's row index) => (matrix' row index) -function gdt_expr.eval_matrix(t, expr_list, y_expr, conditions, annotate) - -- the "index_map" creates a mapping between matrix indexes and table - -- indexes to take into account missing data in some rows. - local index_map, levels = map_missing_rows(t, expr_list, y_expr, conditions) - - if annotate then - annotate_mult(expr_list, levels) +function gdt_expr.eval_matrix(t, info, expr_list, y_expr, index_map) + if not index_map then + index_map = map_missing_rows(t, expr_list, y_expr, {}) end - local N = #t - local NE = #expr_list - local XM = 0 - for k, e in ipairs(expr_list) do XM = XM + e.mult end + local NE, XM = #expr_list, info.dim - local eval_state = {t, 0} - - local function set_scalar_column(X, expr_scalar, j, names) - names[#names+1] = expr_print.expr(expr_scalar) + local function set_scalar_column(X, expr_scalar, j) for _, i, x_i in index_map_iter, index_map, {-1, 0, 0} do - eval_state[2] = i - local xs = expr_print.eval(expr_scalar, eval_table_context, eval_state) + local xs = expr_print.eval(expr_scalar, table_context, t, i) assert(xs, string.format('missing value in data table at row: %d', i)) X:set(x_i, j, xs) end end - local function eval_pred_list(pred, i) - local match = true - for k, name, level in iter_by_two, pred, -1 do - match = match and (t:get(i, name) == level) - end - return (match and 1 or 0) - end - - local function pred_coeff_name(pred) - local ls = {} - for k, name, level in iter_by_two, pred, -1 do - ls[#ls+1] = string.format("%s:%s", name, level) - end - return table.concat(ls, ' / ') - end - - local function set_contrasts_matrix(X, expr, j, names) - local factors = expr.factor - local NF = #factors - local factor_levels = {} - local counter = {} - for p = 1, NF do - factor_levels[p] = levels[factors[p]] - counter[p] = 0 - end - - local pred_list = {} - -- the following code cycles through all the factors/levels - -- combinations for the given factor set (subset of ls at - -- index "k") - counter[NF + 1] = 0 - while counter[NF + 1] == 0 do - local pred = {} - for p = 1, NF do - pred[#pred + 1] = factors[p] - pred[#pred + 1] = factor_levels[p][counter[p] + 2] - end - local coeff_name = pred_coeff_name(pred) - if expr_is_unit(expr.scalar) then - names[#names+1] = coeff_name - else - local scalar_repr = expr_print.expr(expr.scalar) - names[#names+1] = scalar_repr .. ' * ' .. coeff_name - end - -- add coefficient "pred" - pred_list[#pred_list + 1] = pred - - for p = 1, NF + 1 do - local cn = counter[p] + 1 - if p > NF or cn < #factor_levels[p] - 1 then - counter[p] = cn - break - else - counter[p] = 0 - end - end - end - + local function set_contrasts_matrix(X, expr, j) + local pred_list = eval_predicates(expr.factor, info.levels) for _, i, x_i in index_map_iter, index_map, {-1, 0, 0} do - eval_state[2] = i - local xs = expr_print.eval(expr.scalar, eval_table_context, eval_state) + local xs = expr_print.eval(expr.scalar, table_context, t, i) assert(xs, string.format('missing value in data table at row: %d', i)) for k, pred in ipairs(pred_list) do - local fs = eval_pred_list(pred, i) + local fs = eval_pred_list(t, pred, i) X:set(x_i, j + (k - 1), xs * fs) end end end - local names = {} - -- here NR and XM gives the dimension of the model matrix local NR = index_map_count(index_map) + if NR == 0 then error('invalid data table, no valid rows found') end + local X = matrix.alloc(NR, XM) local Y = y_expr and matrix.alloc(NR, 1) local col_index = 1 - for k = 1, NE do - local expr = expr_list[k] + for _, expr in ipairs(expr_list) do if expr.factor then - set_contrasts_matrix(X, expr, col_index, names) + set_contrasts_matrix(X, expr, col_index) else - set_scalar_column(X, expr.scalar, col_index, names) + set_scalar_column(X, expr.scalar, col_index) end col_index = col_index + expr.mult end if y_expr then - set_scalar_column(Y, y_expr, 1, names) + set_scalar_column(Y, y_expr, 1) end - local info = {levels= levels} - if annotate then info.names = names end - - return X, Y, info, index_map + return X, Y end return gdt_expr diff --git a/gdt-factors.lua b/gdt-factors.lua new file mode 100644 index 00000000..a999f79c --- /dev/null +++ b/gdt-factors.lua @@ -0,0 +1,35 @@ +local AST = require 'expr-actions' + +local function expr_find_factors_rec(t, expr, factors) + if AST.is_number(expr) then + return expr + elseif AST.is_variable(expr) then + local _, var_name, force_enum = AST.is_variable(expr) + if force_enum or t:col_type(var_name) == 'factor' then + factors[#factors+1] = var_name + return 1 + else + return expr + end + elseif expr.operator == '*' then + local a, b = expr[1], expr[2] + local sa1 = expr_find_factors_rec(t, a, factors) + local sa2 = expr_find_factors_rec(t, b, factors) + return AST.infix('*', sa1, sa2) + else + return expr + end +end + +function compute_factors(t, expr_list) + local els = {} + for i, e in ipairs(expr_list) do + local et, factors = {}, {} + et.scalar = expr_find_factors_rec(t, e, factors) + if #factors > 0 then et.factor = factors end + els[i] = et + end + return els +end + +return {compute= compute_factors}
\ No newline at end of file diff --git a/gdt-interp.lua b/gdt-interp.lua index 52c6cbb3..9a4a2721 100644 --- a/gdt-interp.lua +++ b/gdt-interp.lua @@ -1,5 +1,8 @@ -local cgsl = require 'gsl' +local expr_parse = require 'expr-parse' local gdt_expr = require 'gdt-expr' +local gdt_factors = require 'gdt-factors' +local AST = require 'expr-actions' +local cgsl = require 'gsl' local interp_lookup = { linear = cgsl.gsl_interp_linear, @@ -11,20 +14,21 @@ local interp_lookup = { } function gdt.interp(t, expr_formula, interp_type) - local schema = gdt_expr.parse_schema(t, expr_formula) + local schema = expr_parse.schema(expr_formula, AST, false) + local x_exprs = gdt_factors.compute(t, schema.x) local T = interp_lookup[interp_type or "cspline"] if T == nil then error("invalid interpolator type") end - local info = gdt_expr.eval_mult(t, schema.x) - local X, y = gdt_expr.eval_matrix(t, schema.x, info, schema.y.scalar) + local info, index_map = gdt_expr.prepare_model(t, x_exprs, schema.y.scalar) + local X, y = gdt_expr.eval_matrix(t, info, x_exprs, schema.y, index_map) + local n = #y local interp = ffi.gc(cgsl.gsl_interp_alloc(T, n), cgsl.gsl_interp_free) local accel = ffi.gc(cgsl.gsl_interp_accel_alloc(), cgsl.gsl_interp_accel_free) - local x_data, y_data = X.data, y.data - cgsl.gsl_interp_init(interp, x_data, y_data, n) + cgsl.gsl_interp_init(interp, X.data, y.data, n) local function eval(x_req) - return cgsl.gsl_interp_eval(interp, x_data, y_data, x_req, acc) + return cgsl.gsl_interp_eval(interp, X.data, y.data, x_req, acc) end return eval end diff --git a/gdt-lm.lua b/gdt-lm.lua index c5393eb9..38805f01 100644 --- a/gdt-lm.lua +++ b/gdt-lm.lua @@ -1,6 +1,7 @@ local expr_parse = require 'expr-parse' local expr_print = require 'expr-print' local gdt_expr = require 'gdt-expr' +local gdt_factors = require 'gdt-factors' local check = require 'check' local mon = require 'monomial' local AST = require 'expr-actions' @@ -133,11 +134,11 @@ end local FIT = {} function FIT.model(fit, t_alt) - return gdt_expr.eval_matrix(t_alt, fit.x_exprs, fit.info) + return gdt_expr.eval_matrix(t_alt, fit.info, fit.x_exprs) end function FIT.predict(fit, t_alt) - local X = gdt_expr.eval_matrix(t_alt, fit.x_exprs, fit.info) + local X = gdt_expr.eval_matrix(t_alt, fit.info, fit.x_exprs) return X * fit.c end @@ -160,7 +161,7 @@ function FIT.eval(fit, tn) eval_table:set(1, k, tn[name]) end local coeff = fit.c - local sX = gdt_expr.eval_matrix(eval_table, fit.x_exprs, fit.info) + local sX = gdt_expr.eval_matrix(eval_table, fit.info, fit.x_exprs) local sy = 0 for k = 0, #coeff - 1 do sy = sy + sX.data[k] * coeff.data[k] @@ -168,38 +169,6 @@ function FIT.eval(fit, tn) return sy end -local function expr_find_factors_rec(t, expr, factors) - if AST.is_number(expr) then - return expr - elseif AST.is_variable(expr) then - local _, var_name, force_enum = AST.is_variable(expr) - if force_enum or t:col_type(var_name) == 'factor' then - factors[#factors+1] = var_name - return 1 - else - return expr - end - elseif expr.operator == '*' then - local a, b = expr[1], expr[2] - local sa1 = expr_find_factors_rec(t, a, factors) - local sa2 = expr_find_factors_rec(t, b, factors) - return AST.infix('*', sa1, sa2) - else - return expr - end -end - -function gdt_expr.extract_factors(t, expr_list) - local els = {} - for i, e in ipairs(expr_list) do - local et, factors = {}, {} - et.scalar = expr_find_factors_rec(t, e, factors) - if #factors > 0 then et.factor = factors end - els[i] = et - end - return els -end - local function lm(t, model_formula, options) local schema = expr_parse.schema(model_formula, AST, false) @@ -212,15 +181,17 @@ local function lm(t, model_formula, options) print("EXPANDED") for _, e in ipairs(xs) do print(e) end - local x_exprs = gdt_expr.extract_factors(t, xs) + local x_exprs = gdt_factors.compute(t, xs) local y_expr = schema.y --- local info = gdt_expr.eval_mult(t, x_exprs, y_expr) + local info, index_map = gdt_expr.prepare_model(t, x_exprs, y_expr, schema.conds) + + print('info', info, index_map) print('EXPRESSIONS') for _, e in ipairs(x_exprs) do print(e) end - local X, y, info, index_map = gdt_expr.eval_matrix(t, x_exprs, y_expr, schema.conds, true) + local X, y = gdt_expr.eval_matrix(t, info, x_exprs, y_expr, index_map) local fit = compute_fit(X, y, info.names) if options and options.predict then |