gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat
-rw-r--r--gdt-lm.lua 73
1 files changed, 51 insertions, 22 deletions
diff --git a/gdt-lm.lua b/gdt-lm.lua
index 1094a44c..f0af2074 100644
--- a/gdt-lm.lua
+++ b/gdt-lm.lua
@@ -202,6 +202,7 @@ local function eval_lm_matrix(t, expr_list, y_expr)
local X = matrix.alloc(N, XM)
local Y = y_expr and matrix.alloc(N, 1)
+ local index_map, index_map_start, index_map_len = {}, 1, 0
local row_index = 1
for i = 1, N do
eval_set(i)
@@ -209,14 +210,13 @@ local function eval_lm_matrix(t, expr_list, y_expr)
local col_index = 1
for k = 1, NE do
local expr = expr_list[k]
- local scalar_expr = expr.scalar
- local xs = eval_scalar(scalar_expr, eval_scope)
- row_undef = row_undef or (not xs)
- if xs then
+ local xs = eval_scalar(expr.scalar, eval_scope)
+ local is_undef = (not xs) or (expr.factor and not factors_defined(t, i, expr.factor))
+ row_undef = row_undef or is_undef
+ if not is_undef then
if not expr.factor then
X:set(row_index, col_index, xs)
else
- row_undef = row_undef or (not factors_defined(t, i, expr.factor))
local j0 = col_index
for j, req_lev in ipairs(expr.levels) do
local match = level_does_match(t, i, expr.factor, req_lev)
@@ -226,17 +226,29 @@ local function eval_lm_matrix(t, expr_list, y_expr)
end
col_index = col_index + expr.mult
end
- if not row_undef then
- if y_expr then
- local y_val = eval_scalar(y_expr, eval_scope)
- if y_val then
- Y:set(row_index, 1, y_val)
- row_index = row_index + 1
- end
- else
- row_index = row_index + 1
- end
+
+ if y_expr and not row_undef then
+ local y_val = eval_scalar(y_expr, eval_scope)
+ row_undef = (not y_val)
+ if y_val then Y:set(row_index, 1, y_val) end
end
+
+ if row_undef then
+ local kk = #index_map
+ index_map[kk+1] = index_map_start
+ index_map[kk+2] = index_map_len
+ index_map_start = i + 1
+ index_map_len = 0
+ else
+ row_index = row_index + 1
+ index_map_len = index_map_len + 1
+ end
+ end
+
+ if index_map_len > 0 then
+ local kk = #index_map
+ index_map[kk+1] = index_map_start
+ index_map[kk+2] = index_map_len
end
local nb_rows = row_index - 1
@@ -246,7 +258,7 @@ local function eval_lm_matrix(t, expr_list, y_expr)
X.size1 = nb_rows
if y_expr then
Y.size1 = nb_rows
- return X, Y
+ return X, Y, index_map
end
return X
@@ -377,19 +389,36 @@ local function fit_compute_Rsquare(fit, t)
return SE, R2, R2_adj
end
-local function lm(t, model_formula)
+local function fit_add_predicted(t, param_name, X, fit, index_map)
+ local cname = string.format("%s (PREDICTED)", param_name)
+ t:col_append(cname)
+ local cindex = t:col_index(cname)
+
+ local y_pred = X * fit.c
+ local jy = 1
+ for k = 1, #index_map - 1, 2 do
+ local idx, len = index_map[k], index_map[k+1]
+ for j = 0, len - 1 do
+ t:set(idx + j, cindex, y_pred[jy + j])
+ end
+ jy = jy + len
+ end
+end
+
+local function lm(t, model_formula, options)
local actions = lm_actions_gen(t)
local l = mini.lexer(model_formula)
local schema = mini.schema(l, actions)
- local function matrix_eval(t_alt)
- return eval_lm_matrix(t_alt, schema.x, schema.y.scalar)
- end
-
local names = build_lm_model(t, schema.x, schema.y.scalar)
- local X, y = matrix_eval(t)
+ local X, y, index_map = eval_lm_matrix(t, schema.x, schema.y.scalar, true)
local fit = compute_fit(X, y, names)
+ if options and options.predict then
+ local y_name = expr_print.expr(schema.y.scalar)
+ fit_add_predicted(t, y_name, X, fit, index_map)
+ end
+
fit.schema = schema
function fit.model(t_alt)
generated by cgit v1.2.3 (git 2.25.1) at 2025年09月17日 03:37:44 +0000

AltStyle によって変換されたページ (->オリジナル) /