gdt-lm.lua - gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
path: root/gdt-lm.lua
blob: dabbf71c08a9f88614840cb30a72253ea065c4a1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
local expr_parse = require 'expr-parse'
local expr_print = require 'expr-print'
local gdt_expr = require 'gdt-expr'
local gdt_factors = require 'gdt-factors'
local check = require 'check'
local mon = require 'monomial'
local AST = require 'expr-actions'
local linfit_rank = require 'linfit_rank'
local sqrt, abs = math.sqrt, math.abs
local ipairs = ipairs
-- Rebuild an expression tree from the monomial "m".
-- "context" maps each symbol name found in the monomial to the
-- expression it stands for (see expr_to_monomial, which fills it in).
-- Returns a plain number when the monomial has no symbolic terms,
-- otherwise a tree of '*' / '^' operator nodes.
local function monomial_to_expr(m, context)
    local product
    for _, symbol, exponent in mon.terms(m) do
        local base = context[symbol]
        local factor
        if exponent == 1 and base then
            factor = base
        else
            factor = {operator='^', base, exponent}
        end
        -- Each new factor is placed on the left of the running product.
        if product then
            product = {operator='*', factor, product}
        else
            product = factor
        end
    end

    local coeff = m[1]
    if coeff == 1 then
        -- A unit coefficient is dropped; a bare "1" is returned when
        -- there is no symbolic part at all.
        return product or 1
    end
    if product then
        return {operator='*', coeff, product}
    end
    return coeff
end
-- Convert the expression "expr" into a monomial.
-- Numbers become a monomial holding only a coefficient, products and
-- integer powers are handled recursively through mon.mult / mon.power.
-- Any other expression is treated as an opaque symbol: its printed
-- form is used as the symbol name and the original expression is
-- recorded in "context" under that name.
local function expr_to_monomial(expr, context)
    if AST.is_number(expr) then
        return {expr}
    end

    local op = expr.operator
    if op == '*' then
        local lhs = expr_to_monomial(expr[1], context)
        local rhs = expr_to_monomial(expr[2], context)
        -- mon.mult is called for its effect on "lhs", which is returned.
        mon.mult(lhs, rhs)
        return lhs
    end

    if op == '^' and check.is_integer(expr[2]) then
        local powered = expr_to_monomial(expr[1], context)
        mon.power(powered, expr[2])
        return powered
    end

    local key = expr_print.expr(expr)
    context[key] = expr
    return mon.symbol(key)
end
-- Two-sided t-test for an estimate "xm" with standard error "s" and
-- "df" degrees of freedom. The argument "n" is accepted but not used.
-- Returns the t statistic and the p-value; when the p-value is below
-- 2e-16 (or is not a number) the string '< 2e-16' is returned instead.
local function t_test(xm, s, n, df)
    local t_stat = xm / s
    local upper_cdf = randist.tdist_P(abs(t_stat), df)
    local p = 2 * (1 - upper_cdf)
    if p >= 2e-16 then
        return t_stat, p
    end
    return t_stat, '< 2e-16'
end
-- Tell whether the value "x" appears among the elements 1..#ls of
-- the list "ls". Returns true or false.
local function list_exists(ls, x)
    local idx, count = 1, #ls
    while idx <= count do
        if ls[idx] == x then
            return true
        end
        idx = idx + 1
    end
    return false
end
-- Perform the linear fit of "y" against the model matrix "X".
-- "names" gives the term name of each column of X.
-- Returns a table with the coefficients summary table (coeff), the
-- values returned by linfit_rank (c, chisq, cov), the dimensions of X
-- (n, p) and the rank, computed as p minus the number of entries in
-- the list of removed terms returned by linfit_rank.
local function compute_fit(X, y, names)
    local n, p = matrix.dim(X)
    local c, chisq, cov, remov = linfit_rank(X, y)
    local rank = p - #remov
    local df = n - rank

    local coeff = gdt.alloc(p, {"term", "estimate", "std error", "t value" ,"Pr(>|t|)"})
    for i = 1, p do
        coeff:set(i, 1, names[i])

        -- For a term listed in "remov" the statistics stay nil and
        -- nil is what gets stored in the summary table.
        local estimate, std_err, t_value, p_value
        if not list_exists(remov, i) then
            estimate = c[i]
            std_err = sqrt(cov:get(i,i))
            t_value, p_value = t_test(estimate, std_err, n, df)
        end

        coeff:set(i, 2, estimate)
        coeff:set(i, 3, std_err)
        coeff:set(i, 4, t_value)
        coeff:set(i, 5, p_value)
    end

    return {
        coeff = coeff,
        c = c,
        chisq = chisq,
        cov = cov,
        n = n,
        p = p,
        rank = rank,
    }
end
-- Compute goodness-of-fit figures for "fit" given the model matrix
-- "X" and the observations "y" (accessed as a one-column matrix).
-- Returns three values: the standard error sqrt(SS_res / (n - p)),
-- R2 = 1 - SS_res / SS_tot and the adjusted R2, where n is fit.n and
-- p is fit.rank.
local function fit_compute_Rsquare(fit, X, y)
    local n, p = fit.n, fit.rank
    local y_pred = X * fit.c
    local count = #y

    local total = 0
    for k = 1, count do
        total = total + y:get(k, 1)
    end
    local y_mean = total / count

    -- ss_res: squared residuals against the prediction;
    -- ss_tot: squared deviations from the mean of the observations.
    local ss_tot, ss_res = 0, 0
    for k = 1, count do
        local yk = y:get(k, 1)
        ss_res = ss_res + (yk - y_pred:get(k, 1))^2
        ss_tot = ss_tot + (yk - y_mean)^2
    end

    local R2 = 1 - ss_res/ss_tot
    local R2_adj = R2 - (1 - R2) * p / (n - p - 1)
    local SE = sqrt(ss_res / (n - p))
    return SE, R2, R2_adj
end
-- Store the values predicted by the fit into the table "t", in a
-- column named "<param_name> (PREDICTED)". The column is appended
-- when it does not exist yet.
-- "index_map" is read as a flat sequence of (first row, length)
-- pairs: each pair names a run of consecutive rows of "t" that
-- receives the next "length" predicted values.
local function fit_add_predicted(t, param_name, X, fit, index_map)
    local cname = string.format("%s (PREDICTED)", param_name)
    if not t:col_index(cname) then
        t:col_append(cname)
    end
    local col = t:col_index(cname)

    local y_pred = X * fit.c
    -- Number of predicted values already copied into the table.
    local consumed = 0
    for k = 1, #index_map - 1, 2 do
        local first, count = index_map[k], index_map[k+1]
        for j = 1, count do
            t:set(first + j - 1, col, y_pred[consumed + j])
        end
        consumed = consumed + count
    end
end
-- Tell whether the list "ls" already holds a monomial equal to "e",
-- as decided by mon.equal. Returns true or false.
local function monomial_exists(ls, e)
    for pos = 1, #ls do
        local candidate = ls[pos]
        if mon.equal(candidate, e) then
            return true
        end
    end
    return false
end
-- Expand a list of expressions into the list of distinct terms they
-- generate.
-- Each expression is converted to a monomial and every monomial
-- returned by mon.combine for it is considered in turn; a monomial
-- is kept only if no equal monomial (per mon.equal) was kept before.
-- Returns the list of expressions corresponding to the monomials
-- kept, in the order they were first encountered.
local function expand_exprs(expr_list)
    -- mls is the flat list of the monomials already included in the
    -- expansion and els is the list of the corresponding expressions:
    -- els[i] is the expression built from mls[i].
    local mls, els = {}, {}
    for _, e in ipairs(expr_list) do
        -- The context is private to each input expression: it maps the
        -- symbol names created by expr_to_monomial back to expressions.
        local context = {}
        local m = expr_to_monomial(e, context)
        for _, mexp in ipairs(mon.combine(m)) do
            if not monomial_exists(mls, mexp) then
                -- The expression is built only for monomials that are
                -- actually retained, duplicates cost nothing more than
                -- the lookup.
                local n = #mls + 1
                mls[n] = mexp
                els[n] = monomial_to_expr(mexp, context)
            end
        end
    end
    return els
end
-- Table of the methods available on fit objects (model, predict,
-- summary, show, eval). It is reached through the __index field of
-- the metatable FIT_MT.
local FIT = {}
-- Evaluate the model terms of the fit on the table "t_alt" and return
-- whatever gdt_expr.eval_matrix returns for them.
function FIT.model(fit, t_alt)
    local info, terms = fit.info, fit.x_exprs
    return gdt_expr.eval_matrix(t_alt, info, terms)
end
-- Return the values predicted by the fit for the table "t_alt": the
-- product of the model matrix evaluated on "t_alt" by the fitted
-- coefficients fit.c.
function FIT.predict(fit, t_alt)
    local design = gdt_expr.eval_matrix(t_alt, fit.info, fit.x_exprs)
    local coefficients = fit.c
    return design * coefficients
end
-- Print a summary of the fit: the coefficients table, a warning when
-- the rank is lower than the number of terms, and a final line with
-- the standard error, R2 and adjusted R2.
function FIT.summary(fit)
    print(fit.coeff)

    local rank, terms = fit.rank, fit.p
    if rank < terms then
        local excluded = terms - rank
        print()
        print('WARNING: model has linearly dependent terms.')
        print(string.format(' %i of the %i coefficients excluded from model.', excluded, terms))
    end

    print()
    local quality = string.format("Standard Error: %g, R2: %g, Adjusted R2: %g", fit.SE, fit.R2, fit.R2_adj)
    print(quality)
end
-- Return a short one-line description of the fit made of its address
-- and of the model formula it was built from.
function FIT.show(fit)
    local formula = fit.model_formula
    return string.format("<fit %p: model: %s>", fit, formula)
end
-- Metatable set by lm() on the fit objects it returns: missing fields
-- are looked up in FIT so that its functions act as methods.
local FIT_MT = {__index = FIT}
-- Evaluate the fitted model for a single entry.
-- "tn" is indexed by column name: for each header of the original
-- table the value tn[name] is copied into the one-row scratch table
-- fit.eval_table, the model terms are evaluated on that row and the
-- sum of each term multiplied by its coefficient is returned.
function FIT.eval(fit, tn)
    local row = fit.eval_table
    for col, name in ipairs(fit.headers) do
        row:set(1, col, tn[name])
    end

    local coeff = fit.c
    local terms = gdt_expr.eval_matrix(row, fit.info, fit.x_exprs)

    -- The "data" fields are indexed from 0, one slot per coefficient.
    local term_data, coeff_data = terms.data, coeff.data
    local result = 0
    for k = 0, #coeff - 1 do
        result = result + term_data[k] * coeff_data[k]
    end
    return result
end
-- Verify that every variable referenced by the model "schema" (in
-- the x terms, in the y expression and in the conditions) is the
-- name of a column of the table "t".
-- Raises an error at level 3 otherwise, so that it is reported at the
-- caller of the function that called this one.
local function check_var_references(t, schema)
    local seen = {}
    local collect = expr_print.references

    for _, x_expr in ipairs(schema.x) do
        collect(x_expr, seen)
    end
    collect(schema.y, seen)
    for _, cond in ipairs(schema.conds) do
        collect(cond, seen)
    end

    for var_name in pairs(seen) do
        local found = t:col_index(var_name)
        if not found then
            error('invalid reference to column name \"'..var_name.."\"", 3)
        end
    end
end
-- Fit a linear model described by "model_formula" on the table "t".
-- Recognized fields of the optional "options" table:
--   expand:  when nil or true the x terms are expanded with
--            expand_exprs, when false they are used as written;
--   predict: when true the predicted values are stored into "t" in a
--            "<y expression> (PREDICTED)" column.
-- Returns the fit object, with FIT_MT as metatable.
local function lm(t, model_formula, options)
    local schema = expr_parse.schema(model_formula, AST, false)
    check_var_references(t, schema)

    -- Expansion is enabled unless explicitly turned off.
    local expand = true
    if options and options.expand ~= nil then
        expand = options.expand
    end

    local terms
    if expand then
        terms = expand_exprs(schema.x)
    else
        terms = schema.x
    end

    local x_exprs = gdt_factors.compute(t, terms)
    local y_expr = schema.y
    local info, index_map = gdt_expr.prepare_model(t, x_exprs, y_expr, schema.conds)
    local X, y = gdt_expr.eval_matrix(t, info, x_exprs, y_expr, index_map)
    local fit = compute_fit(X, y, info.names)

    if options and options.predict then
        fit_add_predicted(t, expr_print.expr(y_expr), X, fit, index_map)
    end

    fit.info = info
    fit.model_formula = model_formula
    fit.schema = schema
    fit.x_exprs = x_exprs

    local SE, R2, R2_adj = fit_compute_Rsquare(fit, X, y)
    fit.SE, fit.R2, fit.R2_adj = SE, R2, R2_adj

    -- One-row scratch table and header list used by FIT.eval.
    fit.eval_table = gdt.alloc(1, t:headers())
    fit.headers = t:headers()

    return setmetatable(fit, FIT_MT)
end
-- Publish the linear model function as a field of the "gdt" table.
gdt.lm = lm
generated by cgit v1.2.3 (git 2.25.1) at 2025年09月10日 20:56:21 +0000

AltStyle によって変換されたページ (->オリジナル) /