author | Francesco Abbate <francesco.bbt@gmail.com> | 2012年12月18日 17:13:45 +0100 |
---|---|---|
committer | Francesco Abbate <francesco.bbt@gmail.com> | 2012年12月18日 17:13:45 +0100 |
commit | a98ae9e87bfba7d40c33218f7a5523fa236875de (patch) | |
tree | 41478e7b40e57ab445b0bd9b6319a86c933579fa /gdt-plot.lua | |
parent | 41522ebdfb5bd64a439a45ddc77889554da93215 (diff) | |
download | gsl-shell-a98ae9e87bfba7d40c33218f7a5523fa236875de.tar.gz |
-rw-r--r-- | gdt-plot.lua | 58 |
diff --git a/gdt-plot.lua b/gdt-plot.lua index a2fb9815..4800c0ad 100644 --- a/gdt-plot.lua +++ b/gdt-plot.lua @@ -16,6 +16,32 @@ local function treat_column_refs(t, js) return js end +local stat_lookup = { + mean = {f = function(accu, x, n) return (accu * (n-1) + x) / n end}, + sum = {f = function(accu, x, n) return accu + x end}, + count = {f = function(accu, x, n) return n end}, +} + +local function treat_column_funcrefs(t, js) + if type(js) ~= 'table' then js = {js} end + for i = 1, #js do + local v = js[i] + local stat, name + if type(v) == 'string' then + stat, name = string.match(v, '(%a+)%((%w+)%)') + if not stat then + stat, name = 'mean', v + end + else + stat, name = 'mean', t:get_header(v) + end + local s = stat_lookup[stat] + assert(s, "invalid parameter requested") + js[i] = {f= s.f, f0 = s.f0 or 0, index = t:col_index(name)} + end + return js +end + local function compare_list(a, b) local n = #a for k = 1, n do @@ -87,6 +113,30 @@ local function rect_bin(t, jxs, jys, jes) return labels, enums, val end +local function rect_funcbin(t, jxs, jys, jes) + local n = #t + local val, count = {}, {} + local enums, labels = {}, {} + for i = 1, n do + local c = collate_factors(t, i, jxs) + for p = 1, #jys do + local jy, fy, fy0 = jys[p].index, jys[p].f, jys[p].f0 + local e = collate_factors(t, i, jes) + if #jys > 1 then + e[#e+1] = t:get_header(jys[p]) + end + local ie = add_unique(enums, e) + local ix = add_unique(labels, c) + local cc = vec2d_incr(count, ix, ie) + local v_accu = vec2d_get(val, ix, ie) or fy0 + local v = t:get(i, jy) + vec2d_set(val, ix, ie, fy(v_accu, v, cc)) + end + end + + return labels, enums, val +end + local function gdt_table_barplot(t, jxs, jys, jes) jxs, jys, jes = treat_all_column_refs(t, jxs, jys, jes) @@ -256,11 +306,11 @@ local function gdt_table_xyplot(t, jx, jys, jes, opt) end local function gdt_table_reduce(t_src, jxs, jys, jes) - jxs = treat_column_refs(t_src, jxs) - jys = treat_column_refs(t_src, jys) - jes = treat_column_refs(t_src, jes) + jxs = treat_column_refs(t, jxs) + jys = treat_column_funcrefs(t, jys) + jes = treat_column_refs(t, jes) - local labels, enums, val = rect_bin(t_src, jxs, jys, jes) + local labels, enums, val = rect_funcbin(t_src, jxs, jys, jes) local n, p, q = #labels, #enums, #labels[1] local t = gdt.new(n, q + p) |