-rw-r--r-- | expr-print.lua | 6 | ||||
-rw-r--r-- | gdt-lm.lua | 18 |
diff --git a/expr-print.lua b/expr-print.lua index 0520d776..a14d400c 100644 --- a/expr-print.lua +++ b/expr-print.lua @@ -1,4 +1,5 @@ local expr_lexer = require 'expr-lexer' +local AST = require 'expr-actions' local format, concat = string.format, table.concat @@ -108,8 +109,9 @@ end local function ref_list_rec(expr, list) if type(expr) == 'number' then return - elseif type(expr) == 'string' then - list[expr] = true + elseif AST.is_variable(expr) then + local _, var_name = AST.is_variable(expr) + list[var_name] = true elseif expr.literal then return elseif expr.func then diff --git a/gdt-lm.lua b/gdt-lm.lua index 31e26fe4..f14bdf46 100644 --- a/gdt-lm.lua +++ b/gdt-lm.lua @@ -169,9 +169,27 @@ function FIT.eval(fit, tn) return sy end +local function check_var_references(t, schema) + local refs = {} + for _, expr in ipairs(schema.x) do + expr_print.references(expr, refs) + end + expr_print.references(schema.y, refs) + for _, expr in ipairs(schema.conds) do + expr_print.references(expr, refs) + end + for var_name in pairs(refs) do + if not t:col_index(var_name) then + error('invalid reference to column name \"'..var_name.."\"", 3) + end + end +end + local function lm(t, model_formula, options) local schema = expr_parse.schema(model_formula, AST, false) + check_var_references(t, schema) + local expand = not options or (options.expand == nil or options.expand) local xs = expand and expand_exprs(schema.x) or schema.x local x_exprs = gdt_factors.compute(t, xs) |