author | Francesco Abbate <francesco.bbt@gmail.com> | 2013年01月20日 19:21:09 +0100 |
---|---|---|
committer | Francesco Abbate <francesco.bbt@gmail.com> | 2013年01月20日 19:21:09 +0100 |
commit | 623b7ac2a0b18c5df2ed89695df8e2849d681fdd (patch) | |
tree | 1025c8eef45ebde898c65f3841c76fec212380df | |
parent | 0a746cfd9cb89eb5ff6c6576e96e3c291c2ee6b1 (diff) | |
download | gsl-shell-623b7ac2a0b18c5df2ed89695df8e2849d681fdd.tar.gz |
-rw-r--r-- | benchmarks/results.csv | 16 | ||||
-rw-r--r-- | expr-parser.lua (renamed from mini-parser.lua) | 63 | ||||
-rw-r--r-- | expr-print.lua | 52 | ||||
-rw-r--r-- | gdt-lm.lua | 484 |
diff --git a/benchmarks/results.csv b/benchmarks/results.csv new file mode 100644 index 00000000..a750caf7 --- /dev/null +++ b/benchmarks/results.csv @@ -0,0 +1,16 @@ +Test,Source,Time +ODE rk8pd,LuaJIT2 joff,10.408 +ODE rk8pd,C,1.449 +ODE rk8pd,LuaJIT2,0.732 +ODE rkf45,LuaJIT2 joff,22.27 +ODE rkf45,C,2.192 +ODE rkf45,LuaJIT2,0.95 +SF roots,LuaJIT2 joff,18.765 +SF roots,LuaJIT2 FFI,6.437 +SF roots,LuaJIT2,6.531 +VEGAS,LuaJIT2 joff,134.617 +VEGAS,C,2.509 +VEGAS,LuaJIT2,2.914 +QAG,LuaJIT2 joff,6.889 +QAG,C,1.886 +QAG,LuaJIT2,1.107 diff --git a/mini-parser.lua b/expr-parser.lua index 610d2d67..27b193fb 100644 --- a/mini-parser.lua +++ b/expr-parser.lua @@ -62,7 +62,7 @@ function mini_lexer.next_token(lexer) return {type= c} end if lexer:match('[%l%u_]') then - local str = lexer:consume('[%l%u_]%w*') + local str = lexer:consume('[%l%u_][%w_]*') return {type= 'ident', value= str} end if lexer:match('[1-9]') then @@ -155,7 +155,7 @@ local function expr_list(lexer, actions) end local function schema(lexer, actions) - local y = expr_list(lexer, actions) + local y = expr(lexer, actions, 0) expect(lexer, '~') local x = expr_list(lexer, actions) return actions.schema(x, y) @@ -165,61 +165,4 @@ local function parse_expr(lexer, actions) return expr(lexer, actions, 0) end -local AST_create = { - infix = function(sym, a, b) return {operator= sym, a, b} end, - ident = function(id) return {ident= id} end, - prefix = function(sym, a) return {operator= sym, a} end, - number = function(x) return {number= x} end, - exprlist = function(a, ls) if ls then ls[#ls+1] = a else ls = {list= true, a} end; return ls end, - schema = function(x, y) return {schema= true, x= x, y= y} end, -} - -local format, concat = string.format, table.concat -local AST_print - -local function is_ident_simple(s) - return s:match('^[_%l%u]%w*$') -end - -local function AST_print_op(e, prio) - if #e == 1 then - local c, c_prio = AST_print(e[1]) - if c_prio < prio then c = format('(%s)', c) end - return format("%s%s", e.operator, c) - else - local a, a_prio = AST_print(e[1]) - local b, b_prio = AST_print(e[2]) - if a_prio < prio then a = format('(%s)', a) end - if b_prio < prio then b = format('(%s)', b) end - local temp = (prio < 2 and "%s %s %s" or "%s%s%s") - return format(temp, a, e.operator, b) - end -end - -local function AST_print_exprlist(e) - local t = {} - for k = 1, #e do t[k] = AST_print(e[k]) end - return concat(t, ', ') -end - -AST_print = function(e) - if e.schema then - local ys = AST_print_exprlist(e.y) - local xs = AST_print_exprlist(e.x) - return format("%s ~ %s", ys, xs) - elseif e.list then - return AST_print_exprlist(e) - elseif e.ident then - local s = e.ident - if not is_ident_simple(s) then s = format('[%s]', s) end - return s, 3 - elseif e.number then - return e.number, 3 - else - local prio = oper_table[e.operator] - local s = AST_print_op(e, prio) - return s, prio - end -end - -return {lexer = new_lexer, schema= schema, parse = parse_expr, AST= AST_create, print = AST_print} +return {lexer = new_lexer, schema= schema, parse = parse_expr} diff --git a/expr-print.lua b/expr-print.lua new file mode 100644 index 00000000..27e8f75c --- /dev/null +++ b/expr-print.lua @@ -0,0 +1,52 @@ +local format, concat = string.format, table.concat + +local oper_table = {['+'] = 0, ['-'] = 0, ['*'] = 1, ['/'] = 1, ['^'] = 2} + +local ex_print + +local function is_ident_simple(s) + return s:match('^[%l%u_][%w_]*$') +end + +local function op_print(e, prio) + if #e == 1 then + local c, c_prio = ex_print(e[1]) + if c_prio < prio then c = format('(%s)', c) end + return format("%s%s", e.operator, c) + else + local a, a_prio = ex_print(e[1]) + local b, b_prio = ex_print(e[2]) + if a_prio < prio then a = format('(%s)', a) end + if b_prio < prio then b = format('(%s)', b) end + local temp = (prio < 2 and "%s %s %s" or "%s%s%s") + return format(temp, a, e.operator, b) + end +end + +local function exlist_print(e) + local t = {} + for k = 1, #e do t[k] = ex_print(e[k]) end + return concat(t, ', ') +end + +ex_print = function(e) + if type(e) == 'number' then + return e, 3 + elseif type(e) == 'string' then + local s = e + if not is_ident_simple(s) then s = format('[%s]', s) end + return s, 3 + else + local prio = oper_table[e.operator] + local s = op_print(e, prio) + return s, prio + end +end + +local function schema_print(e) + local ys = exlist_print(e.y) + local xs = exlist_print(e.x) + return format("%s ~ %s", ys, xs) +end + +return {schema = schema_print, expr = ex_print, expr_list = exlist_print} diff --git a/gdt-lm.lua b/gdt-lm.lua index 92ecef63..6da57885 100644 --- a/gdt-lm.lua +++ b/gdt-lm.lua @@ -1,183 +1,329 @@ -local cgdt = require 'cgdt' +local mini = require 'expr-parser' +local expr_print = require 'expr-print' -local element_is_number = gdt.element_is_number -local format = string.format -local concat = table.concat local sqrt, abs = math.sqrt, math.abs --- status: 0 => string, 1 => numbers +local FACTOR_CLASS = 0 +local SCALAR_CLASS = 1 + local function find_column_type(t, j) - local n = #t - for i = 1, n do - local x = gdt.get_number_unsafe(t, i, j) - if not x then return 0 end - end - return 1 -end - -local function lm_expr_names(t, expr) - local n, m = t:dim() - - local code_lines = {} - local code = function(line) code_lines[#code_lines+1] = line end - - code([[local _LM = require 'lm-helpers']]) - code([[local enum = function(x) return x end]]) - local var_name = t:headers() - for j = 1, m do - local name = var_name[j] - code(format([[local %s = _LM.var_name(%q)]], name, name)) - end - code(format([[local expr_names = _LM.find_names(%s)]], expr)) - code([[return expr_names]]) - - local code_str = concat(code_lines, '\n') - local f = load(code_str) - local expr_names = f() - return expr_names -end - -local function lm_prepare(t, expr) - local n, m = t:dim() - local column_class = {} - local code_lines = {} - local code = function(line) code_lines[#code_lines+1] = line end - local name = t:headers(j) - code([[local _LM = require 'lm-helpers']]) - code([[local _get = gdt.get]]) - code([[local enum = function(x) return {value = x} end]]) - for j = 1, m do - column_class[j] = find_column_type(t, j) - local value = (column_class[j] == 1 and "0" or format("_LM.factor(\"%s\")", name[j])) - code(format("local %s = %s", name[j], value)) - end - - code(format([[local _y_spec = _LM.eval_test(%s)]], expr)) - - code([[local _eval_func = _LM.eval_func]]) - - code(format("local _eval = gdt.alloc(%d, _y_spec.np)", n)) - - code(format("for _i = 1, %d do", n)) - for j = 1, m do - local line - if column_class[j] == 1 then - line = format(" %s = _get(_t, _i, %d)", name[j], j) - else - line = format(" %s.value = _get(_t, _i, %d)", name[j], j) - end - code(line) - end - code("") - code(format(" _eval_func(_y_spec, _eval, _i, %s)", expr)) - code(format("end", n)) - - return format("return function(_t)\n%s\nreturn _eval, _y_spec\nend", concat(code_lines, "\n")) + local n = #t + for i = 1, n do + local x = gdt.get_number_unsafe(t, i, j) + if not x then return FACTOR_CLASS end + end + return SCALAR_CLASS +end + +local function mult(a, b) + if a == 1 then return b end + if b == 1 then return a end + return {operator= '*', a, b} +end + +local function scalar_infix(sym, a, b) + if sym == '*' then + return mult(a, b) + else + return {operator= sym, a, b} + end +end + +local function factor_infix(sym, a, b) + if not (a or b) then return nil end + if sym ~= '*' then + error('non multiplicative opeation on factors') + end + local c = {} + if a then for i, f in ipairs(a) do c[#c+1] = f end end + if b then for i, f in ipairs(b) do c[#c+1] = f end end + return c +end + +local function infix_action(sym, a, b) + local c = {} + c.scalar = scalar_infix(sym, a.scalar, b.scalar) + c.factor = factor_infix(sym, a.factor, b.factor) + return c +end + +local function prefix_action(sym, a) + if a.factor then error('non multiplicative opeation on factors') end + return {scalar= {operator= sym, a}} +end + +local function lm_actions_gen(t) + local n, m = t:dim() + + local column_class = {} + for j = 1, m do column_class[j] = find_column_type(t, j) end + + local function ident_action(id) + local index = t:col_index(id) + if column_class[index] == FACTOR_CLASS then + return {scalar= 1, factor= {id}} + else + return {scalar= id} + end + end + + return { + infix = infix_action, + ident = ident_action, + prefix = prefix_action, + number = function(x) return {scalar= x} end, + exprlist = function(a, ls) if ls then ls[#ls+1] = a else ls = {a} end; return ls end, + schema = function(x, y) return {x= x, y= y} end, + } end local function add_unique(t, val) - for k, x in ipairs(t) do - if x == val then return 0 end - end - local n = #t + 1 - t[n] = val - return n -end - -local function lm_main(Xt, t, inf) - local N = #t - - local index = {} - local curr_index = 1 - local coeff_name = {} - inf.factors, inf.factor_index = {}, {} - for k = 1, inf.np do - local factor_name = inf.factor_names[k] - index[k] = curr_index - if inf.class[k] == 0 then - local factors, factor_index = {}, {} - for i = 1, N do - local str = Xt:get(i, k) - local ui = add_unique(factors, str) - if ui > 0 then - factor_index[str] = ui - if ui > 1 then coeff_name[curr_index + (ui - 2)] = factor_name .. "/" .. str end - end - end - inf.factors[k] = factors - inf.factor_index[k] = factor_index - curr_index = curr_index + (#factors - 1) - else - coeff_name[curr_index] = factor_name - curr_index = curr_index + 1 - end - end - - local col = curr_index - 1 - - local X = matrix.alloc(N, col) - local X_data = X.data - for i = 1, N do - local idx0 = col * (i - 1) - for k = 1, inf.np do - local j = index[k] - if inf.class[k] == 1 then - X_data[idx0 + j - 1] = gdt.get_number_unsafe(Xt, i, k) - else - local factors = inf.factors[k] - local factor_index = inf.factor_index[k] - local req_f = Xt:get(i, k) - local nf = #factors - 1 - for kf = 1, nf do - X_data[idx0 + (j - 1) + (kf - 1)] = 0 - end - local kfx = factor_index[req_f] - 1 - if kfx > 0 then X_data[idx0 + (j - 1) + (kfx - 1)] = 1 end - end - end - end - - return X, coeff_name -end - -local function lm_model(t, expr) - local code = lm_prepare(t, expr) - local f_gen = assert(load(code)) - local f_code = f_gen() - local Xt, inf = f_code(t) - inf.factor_names = lm_expr_names(t, expr) - return lm_main(Xt, t, inf) + for k, x in ipairs(t) do + if x == val then return k end + end + local n = #t + 1 + t[n] = val + return n +end + +local function level_number(factors, levels) + if not factors then return 0 end + local nb = 1 + for _, factor_name in ipairs(factors) do + nb = nb * #levels[factor_name] + end + return nb +end + +local function enum_levels(factors, levels) + local ls, ks, ms = {}, {}, {} + local n = #factors + for i, name in ipairs(factors) do + ks[i], ms[i] = 0, #levels[name] + end + + -- Start the counter from 1 instead of 0 to omit the first + -- level. It will be implicitely the reference. + ks[n] = (factors.omit_ref_level and 1 or 0) + while true do + local lev = {} + for i, name in ipairs(factors) do + lev[i] = levels[name][ks[i] + 1] + end + ls[#ls + 1] = lev + + for i = n, 0, -1 do + if i == 0 then return ls end + ks[i] = (ks[i] + 1) % ms[i] + if ks[i] > 0 then break end + end + end +end + +local function eval_operator(op, a, b) + if op == '+' then return a + b + elseif op == '-' then return a - b + elseif op == '*' then return a * b + elseif op == '/' then return a / b + elseif op == '^' then return a ^ b + else error('unkown operation: ' .. op) end +end + +local function eval_scalar(t, i, expr) + local tp = type(expr) + if tp == 'string' then + local j = t:col_index(expr) + return t:get(i, j) + elseif tp == 'number' then + return expr + else + if #expr == 1 then + return - eval_scalar(t, i, expr[1]) + else + local a = eval_scalar(t, i, expr[1]) + local b = eval_scalar(t, i, expr[2]) + return eval_operator(expr.operator, a, b) + end + end +end + +local function level_does_match(t, i, factors, req_levels) + for k, factor_name in ipairs(factors) do + local y = t:get(i, t:col_index(factor_name)) + if y ~= req_levels[k] then return 0 end + end + return 1 +end + +local function expr_are_equal(a, b) + if a == b then + return true + elseif type(a) == 'table' and type(b) == 'table' then + if a.operator == b.operator then + return expr_are_equal(a[1], b[1]) and expr_are_equal(a[2], b[2]) + end + end + return false +end + +local function scalar_term_exists(expr_list, s) + for i, expr in ipairs(expr_list) do + if not expr.factor and expr_are_equal(expr.scalar, s) then + return true + end + end + return false +end + +local function print_expr_level(factors, levels) + local t = {} + for i, f in ipairs(factors) do + t[i] = string.format("%s%s", f, levels[i]) + end + return table.concat(t, ':') +end + +local function expr_is_unit(e) + return type(e) == 'number' and e == 1 +end + +local function build_lm_model(t, expr_list) + local N, M = t:dim() + + -- list of unique factors referenced in expr_list + local used_factors = {} + for k, expr in ipairs(expr_list) do + if expr.factor then + for _, f_name in ipairs(expr.factor) do + add_unique(used_factors, f_name) + end + end + end + + -- flag the factors whose scalar part is already used + -- in the model. In these cases the first level of the + -- factor will be omitted from the model matrix. + for k, expr in ipairs(expr_list) do + if expr.factor then + local s = expr.scalar + if scalar_term_exists(expr_list, s) then + expr.factor.omit_ref_level = true + end + end + end + + -- for each unique used factor prepare the levels list and + -- set the column index + local levels, factor_index = {}, {} + for k, name in ipairs(used_factors) do + levels[name] = {} + factor_index[name] = t:col_index(name) + end + + -- find the levels for each of the used factors + local get = t.get + for i = 1, N do + for _, name in ipairs(used_factors) do + local v = get(t, i, factor_index[name]) + add_unique(levels[name], v) + end + end + + local expr_index = {} + local curr_index = 1 + for k, expr in ipairs(expr_list) do + expr_index[k] = curr_index + if expr.factor then + local lnb = level_number(expr.factor, levels) + local inn = expr.factor.omit_ref_level and lnb - 1 or lnb + curr_index = curr_index + inn + else + curr_index = curr_index + 1 + end + end + expr_index[#expr_list + 1] = curr_index + + local model_m = curr_index - 1 + local X = matrix.alloc(N, model_m) + local names = {} + for k, expr in ipairs(expr_list) do + local scalar_repr = expr_print.expr(expr.scalar) + local index_offs = expr_index[k] + local scalar_expr = expr.scalar + if not expr.factor then + local j = index_offs + for i = 1, N do + local xs = eval_scalar(t, i, scalar_expr) + X:set(i, j, xs) + end + names[j] = scalar_repr + else + local expr_levels = enum_levels(expr.factor, levels) + local j0 = index_offs + for i = 1, N do + local xs = eval_scalar(t, i, scalar_expr) + for j, req_lev in ipairs(expr_levels) do + local match = level_does_match(t, i, expr.factor, req_lev) + X:set(i, j0 + (j - 1), match * xs) + end + end + for j, req_lev in ipairs(expr_levels) do + local level_repr = print_expr_level(expr.factor, req_lev) + local nm + if expr_is_unit(expr.scalar) then + nm = level_repr + else + nm = scalar_repr .. ' * ' .. level_repr + end + names[j0 + (j - 1)] = nm + end + end + end + + return X, names +end + +local function eval_y(t, y_expr) + local N = #t + local y = matrix.alloc(N, 1) + for i = 1, N do + local yi = eval_scalar(t, i, y_expr) + y.data[i - 1] = yi + end + return y end local function t_test(xm, s, n, df) - local t = xm / s - local at = abs(t) - local p_value = 2 * (1 - randist.tdist_P(at, df)) - return t, (p_value >= 2e-16 and p_value or '< 2e-16') -end - -local function lm(t, expr) - local a, b = string.match(expr, "%s*([%S]+)%s*~(.+)") - assert(a, "invalid lm expression") - local n, m = t:dim() - local jy = t:col_index(a) - assert(jy, "invalid variable specification in lm expression") - local set = gdt.set - local y = matrix.new(n, 1, |i| t:get(i, jy)) - local X, name = lm_model(t, b) - local c, chisq, cov = num.linfit(X, y) - local coeff = gdt.alloc(#c, {"term", "estimate", "std error", "t value" ,"Pr(>|t|)"}) - for i = 1, #c do - coeff:set(i, 1, name[i]) - local xm, s = c[i], cov:get(i,i) - coeff:set(i, 2, xm) - coeff:set(i, 3, sqrt(s)) - local t, p_value = t_test(xm, sqrt(s), n, n - #c) - coeff:set(i, 4, t) - coeff:set(i, 5, p_value) - end - return {coeff = coeff, c = c, chisq = chisq, cov = cov, X = X} + local t = xm / s + local at = abs(t) + local p_value = 2 * (1 - randist.tdist_P(at, df)) + return t, (p_value >= 2e-16 and p_value or '< 2e-16') +end + +local function linfit(X, y, names) + local n = #y + local c, chisq, cov = num.linfit(X, y) + local coeff = gdt.alloc(#c, {"term", "estimate", "std error", "t value" ,"Pr(>|t|)"}) + for i = 1, #c do + coeff:set(i, 1, names[i]) + local xm, s = c[i], cov:get(i,i) + coeff:set(i, 2, xm) + coeff:set(i, 3, sqrt(s)) + local t, p_value = t_test(xm, sqrt(s), n, n - #c) + coeff:set(i, 4, t) + coeff:set(i, 5, p_value) + end + return {coeff = coeff, c = c, chisq = chisq, cov = cov, X = X} +end + +local function lm(t, model_formula) + local actions = lm_actions_gen(t) + local l = mini.lexer(model_formula) + local schema = mini.schema(l, actions) + local X, names = build_lm_model(t, schema.x) + local y = eval_y(t, schema.y.scalar) + return linfit(X, y, names) end gdt.lm = lm |