-rw-r--r-- | gdt-plot.lua | 35 |
diff --git a/gdt-plot.lua b/gdt-plot.lua index a1308c4d..8d624474 100644 --- a/gdt-plot.lua +++ b/gdt-plot.lua @@ -1,4 +1,5 @@ local concat = table.concat +local select = select local function collate(ls) return concat(ls, ' ') @@ -257,7 +258,41 @@ local function gdt_table_reduce(t_src, jxs, jys, jes) return t end +local function count_args(...) + return select('#', ...) +end + +local function set_elements(X, P, i, ...) + for k = 1, P do + local v = select(k, ...) + X:set(i, k, v) + end +end + +local function gdt_table_linfit(t, f, jy) + local N, M = t:dim() + local name = {} + for k = 1, M do + name[k] = t:get_header(k) + end + local row = {} + for k = 1, M do + row[name[k]] = t:get(1, k) + end + local P = count_args(f(row)) + + local X, Y = matrix.alloc(N, P), matrix.alloc(N, 1) + for i = 1, N do + for k = 1, M do row[name[k]] = t:get(i, k) end + set_elements(X, P, i, f(row)) + Y:set(i, 1, t:get(i, jy)) + end + + return num.linfit(X, Y) +end + gdt.barplot = gdt_table_barplot gdt.plot = gdt_table_lineplot gdt.xyplot = gdt_table_xyplot gdt.reduce = gdt_table_reduce +gdt.lm = gdt_table_linfit |