-rw-r--r-- | gdt-lm.lua | 73 |
diff --git a/gdt-lm.lua b/gdt-lm.lua index 1094a44c..f0af2074 100644 --- a/gdt-lm.lua +++ b/gdt-lm.lua @@ -202,6 +202,7 @@ local function eval_lm_matrix(t, expr_list, y_expr) local X = matrix.alloc(N, XM) local Y = y_expr and matrix.alloc(N, 1) + local index_map, index_map_start, index_map_len = {}, 1, 0 local row_index = 1 for i = 1, N do eval_set(i) @@ -209,14 +210,13 @@ local function eval_lm_matrix(t, expr_list, y_expr) local col_index = 1 for k = 1, NE do local expr = expr_list[k] - local scalar_expr = expr.scalar - local xs = eval_scalar(scalar_expr, eval_scope) - row_undef = row_undef or (not xs) - if xs then + local xs = eval_scalar(expr.scalar, eval_scope) + local is_undef = (not xs) or (expr.factor and not factors_defined(t, i, expr.factor)) + row_undef = row_undef or is_undef + if not is_undef then if not expr.factor then X:set(row_index, col_index, xs) else - row_undef = row_undef or (not factors_defined(t, i, expr.factor)) local j0 = col_index for j, req_lev in ipairs(expr.levels) do local match = level_does_match(t, i, expr.factor, req_lev) @@ -226,17 +226,29 @@ local function eval_lm_matrix(t, expr_list, y_expr) end col_index = col_index + expr.mult end - if not row_undef then - if y_expr then - local y_val = eval_scalar(y_expr, eval_scope) - if y_val then - Y:set(row_index, 1, y_val) - row_index = row_index + 1 - end - else - row_index = row_index + 1 - end + + if y_expr and not row_undef then + local y_val = eval_scalar(y_expr, eval_scope) + row_undef = (not y_val) + if y_val then Y:set(row_index, 1, y_val) end end + + if row_undef then + local kk = #index_map + index_map[kk+1] = index_map_start + index_map[kk+2] = index_map_len + index_map_start = i + 1 + index_map_len = 0 + else + row_index = row_index + 1 + index_map_len = index_map_len + 1 + end + end + + if index_map_len > 0 then + local kk = #index_map + index_map[kk+1] = index_map_start + index_map[kk+2] = index_map_len end local nb_rows = row_index - 1 @@ -246,7 +258,7 @@ local function eval_lm_matrix(t, expr_list, y_expr) X.size1 = nb_rows if y_expr then Y.size1 = nb_rows - return X, Y + return X, Y, index_map end return X @@ -377,19 +389,36 @@ local function fit_compute_Rsquare(fit, t) return SE, R2, R2_adj end -local function lm(t, model_formula) +local function fit_add_predicted(t, param_name, X, fit, index_map) + local cname = string.format("%s (PREDICTED)", param_name) + t:col_append(cname) + local cindex = t:col_index(cname) + + local y_pred = X * fit.c + local jy = 1 + for k = 1, #index_map - 1, 2 do + local idx, len = index_map[k], index_map[k+1] + for j = 0, len - 1 do + t:set(idx + j, cindex, y_pred[jy + j]) + end + jy = jy + len + end +end + +local function lm(t, model_formula, options) local actions = lm_actions_gen(t) local l = mini.lexer(model_formula) local schema = mini.schema(l, actions) - local function matrix_eval(t_alt) - return eval_lm_matrix(t_alt, schema.x, schema.y.scalar) - end - local names = build_lm_model(t, schema.x, schema.y.scalar) - local X, y = matrix_eval(t) + local X, y, index_map = eval_lm_matrix(t, schema.x, schema.y.scalar, true) local fit = compute_fit(X, y, names) + if options and options.predict then + local y_name = expr_print.expr(schema.y.scalar) + fit_add_predicted(t, y_name, X, fit, index_map) + end + fit.schema = schema function fit.model(t_alt) |