complete rewrite of gdt.lm using mini expr parser - gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
diff options
context:
space:
mode:
author Francesco Abbate <francesco.bbt@gmail.com> 2013-01-20 19:21:09 +0100
committer Francesco Abbate <francesco.bbt@gmail.com> 2013-01-20 19:21:09 +0100
commit 623b7ac2a0b18c5df2ed89695df8e2849d681fdd (patch)
tree 1025c8eef45ebde898c65f3841c76fec212380df
parent 0a746cfd9cb89eb5ff6c6576e96e3c291c2ee6b1 (diff)
download gsl-shell-623b7ac2a0b18c5df2ed89695df8e2849d681fdd.tar.gz
complete rewrite of gdt.lm using mini expr parser
Diffstat
-rw-r--r-- benchmarks/results.csv 16
-rw-r--r-- expr-parser.lua (renamed from mini-parser.lua) 63
-rw-r--r-- expr-print.lua 52
-rw-r--r-- gdt-lm.lua 484
4 files changed, 386 insertions, 229 deletions
diff --git a/benchmarks/results.csv b/benchmarks/results.csv
new file mode 100644
index 00000000..a750caf7
--- /dev/null
+++ b/benchmarks/results.csv
@@ -0,0 +1,16 @@
+Test,Source,Time
+ODE rk8pd,LuaJIT2 joff,10.408
+ODE rk8pd,C,1.449
+ODE rk8pd,LuaJIT2,0.732
+ODE rkf45,LuaJIT2 joff,22.27
+ODE rkf45,C,2.192
+ODE rkf45,LuaJIT2,0.95
+SF roots,LuaJIT2 joff,18.765
+SF roots,LuaJIT2 FFI,6.437
+SF roots,LuaJIT2,6.531
+VEGAS,LuaJIT2 joff,134.617
+VEGAS,C,2.509
+VEGAS,LuaJIT2,2.914
+QAG,LuaJIT2 joff,6.889
+QAG,C,1.886
+QAG,LuaJIT2,1.107
diff --git a/mini-parser.lua b/expr-parser.lua
index 610d2d67..27b193fb 100644
--- a/mini-parser.lua
+++ b/expr-parser.lua
@@ -62,7 +62,7 @@ function mini_lexer.next_token(lexer)
return {type= c}
end
if lexer:match('[%l%u_]') then
- local str = lexer:consume('[%l%u_]%w*')
+ local str = lexer:consume('[%l%u_][%w_]*')
return {type= 'ident', value= str}
end
if lexer:match('[1-9]') then
@@ -155,7 +155,7 @@ local function expr_list(lexer, actions)
end
local function schema(lexer, actions)
- local y = expr_list(lexer, actions)
+ local y = expr(lexer, actions, 0)
expect(lexer, '~')
local x = expr_list(lexer, actions)
return actions.schema(x, y)
@@ -165,61 +165,4 @@ local function parse_expr(lexer, actions)
return expr(lexer, actions, 0)
end
-local AST_create = {
- infix = function(sym, a, b) return {operator= sym, a, b} end,
- ident = function(id) return {ident= id} end,
- prefix = function(sym, a) return {operator= sym, a} end,
- number = function(x) return {number= x} end,
- exprlist = function(a, ls) if ls then ls[#ls+1] = a else ls = {list= true, a} end; return ls end,
- schema = function(x, y) return {schema= true, x= x, y= y} end,
-}
-
-local format, concat = string.format, table.concat
-local AST_print
-
-local function is_ident_simple(s)
- return s:match('^[_%l%u]%w*$')
-end
-
-local function AST_print_op(e, prio)
- if #e == 1 then
- local c, c_prio = AST_print(e[1])
- if c_prio < prio then c = format('(%s)', c) end
- return format("%s%s", e.operator, c)
- else
- local a, a_prio = AST_print(e[1])
- local b, b_prio = AST_print(e[2])
- if a_prio < prio then a = format('(%s)', a) end
- if b_prio < prio then b = format('(%s)', b) end
- local temp = (prio < 2 and "%s %s %s" or "%s%s%s")
- return format(temp, a, e.operator, b)
- end
-end
-
-local function AST_print_exprlist(e)
- local t = {}
- for k = 1, #e do t[k] = AST_print(e[k]) end
- return concat(t, ', ')
-end
-
-AST_print = function(e)
- if e.schema then
- local ys = AST_print_exprlist(e.y)
- local xs = AST_print_exprlist(e.x)
- return format("%s ~ %s", ys, xs)
- elseif e.list then
- return AST_print_exprlist(e)
- elseif e.ident then
- local s = e.ident
- if not is_ident_simple(s) then s = format('[%s]', s) end
- return s, 3
- elseif e.number then
- return e.number, 3
- else
- local prio = oper_table[e.operator]
- local s = AST_print_op(e, prio)
- return s, prio
- end
-end
-
-return {lexer = new_lexer, schema= schema, parse = parse_expr, AST= AST_create, print = AST_print}
+return {lexer = new_lexer, schema= schema, parse = parse_expr}
diff --git a/expr-print.lua b/expr-print.lua
new file mode 100644
index 00000000..27e8f75c
--- /dev/null
+++ b/expr-print.lua
@@ -0,0 +1,52 @@
+local format, concat = string.format, table.concat
+
+local oper_table = {['+'] = 0, ['-'] = 0, ['*'] = 1, ['/'] = 1, ['^'] = 2}
+
+local ex_print
+
+local function is_ident_simple(s)
+ return s:match('^[%l%u_][%w_]*$')
+end
+
+local function op_print(e, prio)
+ if #e == 1 then
+ local c, c_prio = ex_print(e[1])
+ if c_prio < prio then c = format('(%s)', c) end
+ return format("%s%s", e.operator, c)
+ else
+ local a, a_prio = ex_print(e[1])
+ local b, b_prio = ex_print(e[2])
+ if a_prio < prio then a = format('(%s)', a) end
+ if b_prio < prio then b = format('(%s)', b) end
+ local temp = (prio < 2 and "%s %s %s" or "%s%s%s")
+ return format(temp, a, e.operator, b)
+ end
+end
+
+local function exlist_print(e)
+ local t = {}
+ for k = 1, #e do t[k] = ex_print(e[k]) end
+ return concat(t, ', ')
+end
+
+ex_print = function(e)
+ if type(e) == 'number' then
+ return e, 3
+ elseif type(e) == 'string' then
+ local s = e
+ if not is_ident_simple(s) then s = format('[%s]', s) end
+ return s, 3
+ else
+ local prio = oper_table[e.operator]
+ local s = op_print(e, prio)
+ return s, prio
+ end
+end
+
+local function schema_print(e)
+ local ys = exlist_print(e.y)
+ local xs = exlist_print(e.x)
+ return format("%s ~ %s", ys, xs)
+end
+
+return {schema = schema_print, expr = ex_print, expr_list = exlist_print}
diff --git a/gdt-lm.lua b/gdt-lm.lua
index 92ecef63..6da57885 100644
--- a/gdt-lm.lua
+++ b/gdt-lm.lua
@@ -1,183 +1,329 @@
-local cgdt = require 'cgdt'
+local mini = require 'expr-parser'
+local expr_print = require 'expr-print'
-local element_is_number = gdt.element_is_number
-local format = string.format
-local concat = table.concat
local sqrt, abs = math.sqrt, math.abs
--- status: 0 => string, 1 => numbers
+local FACTOR_CLASS = 0
+local SCALAR_CLASS = 1
+
local function find_column_type(t, j)
- local n = #t
- for i = 1, n do
- local x = gdt.get_number_unsafe(t, i, j)
- if not x then return 0 end
- end
- return 1
-end
-
-local function lm_expr_names(t, expr)
- local n, m = t:dim()
-
- local code_lines = {}
- local code = function(line) code_lines[#code_lines+1] = line end
-
- code([[local _LM = require 'lm-helpers']])
- code([[local enum = function(x) return x end]])
- local var_name = t:headers()
- for j = 1, m do
- local name = var_name[j]
- code(format([[local %s = _LM.var_name(%q)]], name, name))
- end
- code(format([[local expr_names = _LM.find_names(%s)]], expr))
- code([[return expr_names]])
-
- local code_str = concat(code_lines, '\n')
- local f = load(code_str)
- local expr_names = f()
- return expr_names
-end
-
-local function lm_prepare(t, expr)
- local n, m = t:dim()
- local column_class = {}
- local code_lines = {}
- local code = function(line) code_lines[#code_lines+1] = line end
- local name = t:headers(j)
- code([[local _LM = require 'lm-helpers']])
- code([[local _get = gdt.get]])
- code([[local enum = function(x) return {value = x} end]])
- for j = 1, m do
- column_class[j] = find_column_type(t, j)
- local value = (column_class[j] == 1 and "0" or format("_LM.factor(\"%s\")", name[j]))
- code(format("local %s = %s", name[j], value))
- end
-
- code(format([[local _y_spec = _LM.eval_test(%s)]], expr))
-
- code([[local _eval_func = _LM.eval_func]])
-
- code(format("local _eval = gdt.alloc(%d, _y_spec.np)", n))
-
- code(format("for _i = 1, %d do", n))
- for j = 1, m do
- local line
- if column_class[j] == 1 then
- line = format(" %s = _get(_t, _i, %d)", name[j], j)
- else
- line = format(" %s.value = _get(_t, _i, %d)", name[j], j)
- end
- code(line)
- end
- code("")
- code(format(" _eval_func(_y_spec, _eval, _i, %s)", expr))
- code(format("end", n))
-
- return format("return function(_t)\n%s\nreturn _eval, _y_spec\nend", concat(code_lines, "\n"))
+ local n = #t
+ for i = 1, n do
+ local x = gdt.get_number_unsafe(t, i, j)
+ if not x then return FACTOR_CLASS end
+ end
+ return SCALAR_CLASS
+end
+
+local function mult(a, b)
+ if a == 1 then return b end
+ if b == 1 then return a end
+ return {operator= '*', a, b}
+end
+
+local function scalar_infix(sym, a, b)
+ if sym == '*' then
+ return mult(a, b)
+ else
+ return {operator= sym, a, b}
+ end
+end
+
+local function factor_infix(sym, a, b)
+ if not (a or b) then return nil end
+ if sym ~= '*' then
+ error('non multiplicative opeation on factors')
+ end
+ local c = {}
+ if a then for i, f in ipairs(a) do c[#c+1] = f end end
+ if b then for i, f in ipairs(b) do c[#c+1] = f end end
+ return c
+end
+
+local function infix_action(sym, a, b)
+ local c = {}
+ c.scalar = scalar_infix(sym, a.scalar, b.scalar)
+ c.factor = factor_infix(sym, a.factor, b.factor)
+ return c
+end
+
+local function prefix_action(sym, a)
+ if a.factor then error('non multiplicative opeation on factors') end
+ return {scalar= {operator= sym, a}}
+end
+
+local function lm_actions_gen(t)
+ local n, m = t:dim()
+
+ local column_class = {}
+ for j = 1, m do column_class[j] = find_column_type(t, j) end
+
+ local function ident_action(id)
+ local index = t:col_index(id)
+ if column_class[index] == FACTOR_CLASS then
+ return {scalar= 1, factor= {id}}
+ else
+ return {scalar= id}
+ end
+ end
+
+ return {
+ infix = infix_action,
+ ident = ident_action,
+ prefix = prefix_action,
+ number = function(x) return {scalar= x} end,
+ exprlist = function(a, ls) if ls then ls[#ls+1] = a else ls = {a} end; return ls end,
+ schema = function(x, y) return {x= x, y= y} end,
+ }
end
local function add_unique(t, val)
- for k, x in ipairs(t) do
- if x == val then return 0 end
- end
- local n = #t + 1
- t[n] = val
- return n
-end
-
-local function lm_main(Xt, t, inf)
- local N = #t
-
- local index = {}
- local curr_index = 1
- local coeff_name = {}
- inf.factors, inf.factor_index = {}, {}
- for k = 1, inf.np do
- local factor_name = inf.factor_names[k]
- index[k] = curr_index
- if inf.class[k] == 0 then
- local factors, factor_index = {}, {}
- for i = 1, N do
- local str = Xt:get(i, k)
- local ui = add_unique(factors, str)
- if ui > 0 then
- factor_index[str] = ui
- if ui > 1 then coeff_name[curr_index + (ui - 2)] = factor_name .. "/" .. str end
- end
- end
- inf.factors[k] = factors
- inf.factor_index[k] = factor_index
- curr_index = curr_index + (#factors - 1)
- else
- coeff_name[curr_index] = factor_name
- curr_index = curr_index + 1
- end
- end
-
- local col = curr_index - 1
-
- local X = matrix.alloc(N, col)
- local X_data = X.data
- for i = 1, N do
- local idx0 = col * (i - 1)
- for k = 1, inf.np do
- local j = index[k]
- if inf.class[k] == 1 then
- X_data[idx0 + j - 1] = gdt.get_number_unsafe(Xt, i, k)
- else
- local factors = inf.factors[k]
- local factor_index = inf.factor_index[k]
- local req_f = Xt:get(i, k)
- local nf = #factors - 1
- for kf = 1, nf do
- X_data[idx0 + (j - 1) + (kf - 1)] = 0
- end
- local kfx = factor_index[req_f] - 1
- if kfx > 0 then X_data[idx0 + (j - 1) + (kfx - 1)] = 1 end
- end
- end
- end
-
- return X, coeff_name
-end
-
-local function lm_model(t, expr)
- local code = lm_prepare(t, expr)
- local f_gen = assert(load(code))
- local f_code = f_gen()
- local Xt, inf = f_code(t)
- inf.factor_names = lm_expr_names(t, expr)
- return lm_main(Xt, t, inf)
+ for k, x in ipairs(t) do
+ if x == val then return k end
+ end
+ local n = #t + 1
+ t[n] = val
+ return n
+end
+
+local function level_number(factors, levels)
+ if not factors then return 0 end
+ local nb = 1
+ for _, factor_name in ipairs(factors) do
+ nb = nb * #levels[factor_name]
+ end
+ return nb
+end
+
+local function enum_levels(factors, levels)
+ local ls, ks, ms = {}, {}, {}
+ local n = #factors
+ for i, name in ipairs(factors) do
+ ks[i], ms[i] = 0, #levels[name]
+ end
+
+ -- Start the counter from 1 instead of 0 to omit the first
+ -- level. It will implicitly be the reference.
+ ks[n] = (factors.omit_ref_level and 1 or 0)
+ while true do
+ local lev = {}
+ for i, name in ipairs(factors) do
+ lev[i] = levels[name][ks[i] + 1]
+ end
+ ls[#ls + 1] = lev
+
+ for i = n, 0, -1 do
+ if i == 0 then return ls end
+ ks[i] = (ks[i] + 1) % ms[i]
+ if ks[i] > 0 then break end
+ end
+ end
+end
+
+local function eval_operator(op, a, b)
+ if op == '+' then return a + b
+ elseif op == '-' then return a - b
+ elseif op == '*' then return a * b
+ elseif op == '/' then return a / b
+ elseif op == '^' then return a ^ b
+ else error('unkown operation: ' .. op) end
+end
+
+local function eval_scalar(t, i, expr)
+ local tp = type(expr)
+ if tp == 'string' then
+ local j = t:col_index(expr)
+ return t:get(i, j)
+ elseif tp == 'number' then
+ return expr
+ else
+ if #expr == 1 then
+ return - eval_scalar(t, i, expr[1])
+ else
+ local a = eval_scalar(t, i, expr[1])
+ local b = eval_scalar(t, i, expr[2])
+ return eval_operator(expr.operator, a, b)
+ end
+ end
+end
+
+local function level_does_match(t, i, factors, req_levels)
+ for k, factor_name in ipairs(factors) do
+ local y = t:get(i, t:col_index(factor_name))
+ if y ~= req_levels[k] then return 0 end
+ end
+ return 1
+end
+
+local function expr_are_equal(a, b)
+ if a == b then
+ return true
+ elseif type(a) == 'table' and type(b) == 'table' then
+ if a.operator == b.operator then
+ return expr_are_equal(a[1], b[1]) and expr_are_equal(a[2], b[2])
+ end
+ end
+ return false
+end
+
+local function scalar_term_exists(expr_list, s)
+ for i, expr in ipairs(expr_list) do
+ if not expr.factor and expr_are_equal(expr.scalar, s) then
+ return true
+ end
+ end
+ return false
+end
+
+local function print_expr_level(factors, levels)
+ local t = {}
+ for i, f in ipairs(factors) do
+ t[i] = string.format("%s%s", f, levels[i])
+ end
+ return table.concat(t, ':')
+end
+
+local function expr_is_unit(e)
+ return type(e) == 'number' and e == 1
+end
+
+local function build_lm_model(t, expr_list)
+ local N, M = t:dim()
+
+ -- list of unique factors referenced in expr_list
+ local used_factors = {}
+ for k, expr in ipairs(expr_list) do
+ if expr.factor then
+ for _, f_name in ipairs(expr.factor) do
+ add_unique(used_factors, f_name)
+ end
+ end
+ end
+
+ -- flag the factors whose scalar part is already used
+ -- in the model. In these cases the first level of the
+ -- factor will be omitted from the model matrix.
+ for k, expr in ipairs(expr_list) do
+ if expr.factor then
+ local s = expr.scalar
+ if scalar_term_exists(expr_list, s) then
+ expr.factor.omit_ref_level = true
+ end
+ end
+ end
+
+ -- for each unique used factor prepare the levels list and
+ -- set the column index
+ local levels, factor_index = {}, {}
+ for k, name in ipairs(used_factors) do
+ levels[name] = {}
+ factor_index[name] = t:col_index(name)
+ end
+
+ -- find the levels for each of the used factors
+ local get = t.get
+ for i = 1, N do
+ for _, name in ipairs(used_factors) do
+ local v = get(t, i, factor_index[name])
+ add_unique(levels[name], v)
+ end
+ end
+
+ local expr_index = {}
+ local curr_index = 1
+ for k, expr in ipairs(expr_list) do
+ expr_index[k] = curr_index
+ if expr.factor then
+ local lnb = level_number(expr.factor, levels)
+ local inn = expr.factor.omit_ref_level and lnb - 1 or lnb
+ curr_index = curr_index + inn
+ else
+ curr_index = curr_index + 1
+ end
+ end
+ expr_index[#expr_list + 1] = curr_index
+
+ local model_m = curr_index - 1
+ local X = matrix.alloc(N, model_m)
+ local names = {}
+ for k, expr in ipairs(expr_list) do
+ local scalar_repr = expr_print.expr(expr.scalar)
+ local index_offs = expr_index[k]
+ local scalar_expr = expr.scalar
+ if not expr.factor then
+ local j = index_offs
+ for i = 1, N do
+ local xs = eval_scalar(t, i, scalar_expr)
+ X:set(i, j, xs)
+ end
+ names[j] = scalar_repr
+ else
+ local expr_levels = enum_levels(expr.factor, levels)
+ local j0 = index_offs
+ for i = 1, N do
+ local xs = eval_scalar(t, i, scalar_expr)
+ for j, req_lev in ipairs(expr_levels) do
+ local match = level_does_match(t, i, expr.factor, req_lev)
+ X:set(i, j0 + (j - 1), match * xs)
+ end
+ end
+ for j, req_lev in ipairs(expr_levels) do
+ local level_repr = print_expr_level(expr.factor, req_lev)
+ local nm
+ if expr_is_unit(expr.scalar) then
+ nm = level_repr
+ else
+ nm = scalar_repr .. ' * ' .. level_repr
+ end
+ names[j0 + (j - 1)] = nm
+ end
+ end
+ end
+
+ return X, names
+end
+
+local function eval_y(t, y_expr)
+ local N = #t
+ local y = matrix.alloc(N, 1)
+ for i = 1, N do
+ local yi = eval_scalar(t, i, y_expr)
+ y.data[i - 1] = yi
+ end
+ return y
end
local function t_test(xm, s, n, df)
- local t = xm / s
- local at = abs(t)
- local p_value = 2 * (1 - randist.tdist_P(at, df))
- return t, (p_value >= 2e-16 and p_value or '< 2e-16')
-end
-
-local function lm(t, expr)
- local a, b = string.match(expr, "%s*([%S]+)%s*~(.+)")
- assert(a, "invalid lm expression")
- local n, m = t:dim()
- local jy = t:col_index(a)
- assert(jy, "invalid variable specification in lm expression")
- local set = gdt.set
- local y = matrix.new(n, 1, |i| t:get(i, jy))
- local X, name = lm_model(t, b)
- local c, chisq, cov = num.linfit(X, y)
- local coeff = gdt.alloc(#c, {"term", "estimate", "std error", "t value" ,"Pr(>|t|)"})
- for i = 1, #c do
- coeff:set(i, 1, name[i])
- local xm, s = c[i], cov:get(i,i)
- coeff:set(i, 2, xm)
- coeff:set(i, 3, sqrt(s))
- local t, p_value = t_test(xm, sqrt(s), n, n - #c)
- coeff:set(i, 4, t)
- coeff:set(i, 5, p_value)
- end
- return {coeff = coeff, c = c, chisq = chisq, cov = cov, X = X}
+ local t = xm / s
+ local at = abs(t)
+ local p_value = 2 * (1 - randist.tdist_P(at, df))
+ return t, (p_value >= 2e-16 and p_value or '< 2e-16')
+end
+
+local function linfit(X, y, names)
+ local n = #y
+ local c, chisq, cov = num.linfit(X, y)
+ local coeff = gdt.alloc(#c, {"term", "estimate", "std error", "t value" ,"Pr(>|t|)"})
+ for i = 1, #c do
+ coeff:set(i, 1, names[i])
+ local xm, s = c[i], cov:get(i,i)
+ coeff:set(i, 2, xm)
+ coeff:set(i, 3, sqrt(s))
+ local t, p_value = t_test(xm, sqrt(s), n, n - #c)
+ coeff:set(i, 4, t)
+ coeff:set(i, 5, p_value)
+ end
+ return {coeff = coeff, c = c, chisq = chisq, cov = cov, X = X}
+end
+
+local function lm(t, model_formula)
+ local actions = lm_actions_gen(t)
+ local l = mini.lexer(model_formula)
+ local schema = mini.schema(l, actions)
+ local X, names = build_lm_model(t, schema.x)
+ local y = eval_y(t, schema.y.scalar)
+ return linfit(X, y, names)
end
gdt.lm = lm
generated by cgit v1.2.3 (git 2.39.1) at 2025-09-20 01:20:30 +0000

AltStyle によって変換されたページ (->オリジナル) /