author | Francesco Abbate <francesco.bbt@gmail.com> | 2013年01月24日 18:07:14 +0100 |
---|---|---|
committer | Francesco Abbate <francesco.bbt@gmail.com> | 2013年01月24日 18:14:42 +0100 |
commit | 0020f49a042221eb53ae88f5ba773e9b0db09e9a (patch) | |
tree | 7f0ae680003ecfa3694673e81eebfbe0b3901878 /gdt-plot.lua | |
parent | 35421c4a09fb4e148fd9ce4422a7ce45f836e61d (diff) | |
download | gsl-shell-0020f49a042221eb53ae88f5ba773e9b0db09e9a.tar.gz |
-rw-r--r-- | gdt-plot.lua | 275 |
diff --git a/gdt-plot.lua b/gdt-plot.lua index 15888639..f774c901 100644 --- a/gdt-plot.lua +++ b/gdt-plot.lua @@ -1,3 +1,6 @@ +local mini = require 'expr-parser' +local expr_print = require 'expr-print' + local concat = table.concat local select, unpack = select, unpack local sqrt = math.sqrt @@ -8,17 +11,6 @@ local function collate(ls, sep) return concat(ls, sep or ' ') end -local function treat_column_refs(t, js) - if type(js) ~= 'table' then js = {js} end - for i = 1, #js do - local v = js[i] - if type(v) == 'string' then - js[i] = t:col_index(v) - end - end - return js -end - -- recursive algorithm to computer the standard deviation from -- wikipedia: http://en.wikipedia.org/wiki/Standard_deviation. -- Welford, BP. "Note on a Method for Calculating Corrected Sums of @@ -55,33 +47,6 @@ local stat_lookup = { count = {f = function(accu, x, n) return n end}, } -local function treat_column_funcrefs(t, js) - if type(js) ~= 'table' then js = {js} end - for i = 1, #js do - local v = js[i] - local stat, name, fullname - if type(v) == 'string' then - fullname = v - stat, name = string.match(v, '(%a+)%((%w+)%)') - if not stat then - stat, name = 'mean', v - end - else - stat, name = 'mean', t:get_header(v) - end - local s = stat_lookup[stat] - assert(s, "invalid parameter requested") - js[i] = { - f = s.f, - f0 = s.f0, - fini = s.fini, - name = fullname, - index = t:col_index(name) - } - end - return js -end - local function compare_list(a, b) local n = #a for k = 1, n do @@ -126,28 +91,33 @@ local function vec2d_incr(r, i, j) return v + 1 end -local function treat_all_column_refs(t, jxs, jys, jes) - jxs = treat_column_refs(t, jxs) - jys = treat_column_refs(t, jys) - jes = treat_column_refs(t, jes) - return jxs, jys, jes +local function eval_scalar_gen(t) + local i + local id_res = function(expr) return t:get(i, expr.index) end + local func_res = function(expr) return math[expr.func] end + local set = function(ix) i = ix end + return set, {ident= id_res, func= func_res} end local function rect_funcbin(t, jxs, jys, jes) + local eval_set, eval_scope = eval_scalar_gen(t) + local eval = expr_print.eval + local n = #t local val, count = {}, {} local enums, labels = {}, {} local fini_table = {} for i = 1, n do + eval_set(i) local c = collate_factors(t, i, jxs) for p = 1, #jys do local jp = jys[p] - local jy, fy, fini = jp.index, jp.f, jp.fini + local fy, fini = jp.f, jp.fini local f0 = jp.f0 and jp.f0() or 0 local e = collate_factors(t, i, jes) e[#e+1] = jp.name - local v = t:get(i, jy) + local v = eval(jp.expr, eval_scope) if v then local ie = add_unique(enums, e) local ix = add_unique(labels, c) @@ -173,17 +143,62 @@ local function rect_funcbin(t, jxs, jys, jes) return labels, enums, val end -local function gdt_table_barplot(t, jxs, jys, jes, opt) - local show_plot = true - if opt then show_plot = (opt.show ~= false) end +local function infix_ast(sym, a, b) + return {operator= sym, a, b} +end - jxs = treat_column_refs(t, jxs) - jys = treat_column_funcrefs(t, jys) - jes = treat_column_refs(t, jes) +local function prefix_ast(sym, a) + return {operator= sym, a} +end - local rect, webcolor = graph.rect, graph.webcolor - local labels, enums, val = rect_funcbin(t, jxs, jys, jes) +local function func_eval_ast(func_name, arg_expr) + return {func= func_name, arg= arg_expr} +end + +local function itself(x) return x end + +local function plot_actions_gen(t) + + local function ident_ast(id) + local i = t:col_index(id) + if not i then error('unknown column name: '..id) end + return {name= id, index= i} + end + + return { + infix = infix_ast, + prefix = prefix_ast, + ident = ident_ast, + enum = itself, + func_eval = func_eval_ast, + number = itself, + exprlist = function(a, ls) if ls then ls[#ls+1] = a else ls = {a} end; return ls end, + schema = function(x, y, enums) return {x= x, y= y, enums= enums} end, + } +end + +local stat_lookup = { + mean = {f = function(accu, x, n) return (accu * (n-1) + x) / n end}, + stddev = {f = f_stddev, f0 = || {0, 0, 0}, fini = f_stddev_fini}, + stddevp = {f = f_stddev, f0 = || {0, 0, 0}, fini = f_stddevp_fini}, + var = {f = f_stddev, f0 = || {0, 0, 0}, fini = f_var_fini}, + sum = {f = function(accu, x, n) return accu + x end}, + count = {f = function(accu, x, n) return n end}, +} + +local function get_stat(expr) + if expr.func and stat_lookup[expr.func]then + return expr.func, expr.arg + else + return 'mean', expr + end +end +local rect, webcolor, path = graph.rect, graph.webcolor, graph.path + +local barplot = {} + +function barplot.create(labels, enums, val) local plt = graph.plot() local pad = 0.1 local dx = (1 - 2*pad) / #enums @@ -200,20 +215,19 @@ local function gdt_table_barplot(t, jxs, jys, jes, opt) cat[2*p-1] = p - 0.5 cat[2*p] = collate(lab) end + return plt, cat +end - plt:set_categories('x', cat) - plt.xlab_angle = math.pi/4 - +function barplot.legend(plt, labels, enums) if #enums > 1 then for k = 1, #enums do plt:legend(collate(enums[k], '/'), webcolor(k), 'square') end end - - if show_plot then plt:show() end - return plt end +local lineplot = {} + local function legend_symbol(sym, dx, dy) if sym == 'square' then return graph.rect(5+dx, 5+dy, 15+dx, 15+dy) @@ -234,20 +248,10 @@ local function add_legend(lg, k, symspec, color, text) end end -local function gdt_table_lineplot(t, jxs, jys, jes, opt) - local show_plot = true - if opt then show_plot = (opt.show ~= false) end - - jxs = treat_column_refs(t, jxs) - jys = treat_column_funcrefs(t, jys) - jes = treat_column_refs(t, jes) - - local path, webcolor = graph.path, graph.webcolor - local labels, enums, val = rect_funcbin(t, jxs, jys, jes) - - local plt, lg = graph.plot(), graph.plot() +function lineplot.create(labels, enums, val) + local plt = graph.plot() plt.pad, plt.clip = true, false - lg.units, lg.clip = false, false + for q, en in ipairs(enums) do local ln = path() local path_method = ln.move_to @@ -262,45 +266,109 @@ local function gdt_table_lineplot(t, jxs, jys, jes, opt) end plt:add(ln, webcolor(q), {{'stroke', width=line_width}}) plt:add(ln, webcolor(q), {{'marker', size=8, mark=q}}) + end + + local cat = {} + for p, lab in ipairs(labels) do + cat[2*p-1] = p - 0.5 + cat[2*p] = collate(lab) + end + + return plt, cat +end - if #enums > 1 then +function lineplot.legend(plt, labels, enums) + if #enums > 1 then + local lg = graph.plot() + lg.units, lg.clip = false, false + for q, en in ipairs(enums) do local label = collate(en) add_legend(lg, q, 'line', webcolor(q), label) add_legend(lg, q, q, webcolor(q)) end + plt:set_legend(lg) end +end - plt:set_legend(lg) +local function idents_get_column_indexes(t, exprs) + local jxs = {} + for i, expr in ipairs(exprs) do + if not expr.name then error('invalid enumeration factor') end + jxs[i] = t:col_index(expr.name) + end + return jxs +end - local cat = {} - for p, lab in ipairs(labels) do - cat[2*p-1] = p - 0.5 - cat[2*p] = collate(lab) +local function stat_expr_get_functions(exprs) + local jys = {} + for i, expr in ipairs(exprs) do + local stat_name, yexpr = get_stat(expr) + local s = stat_lookup[stat_name] + jys[i] = { + f = s.f, + f0 = s.f0, + fini = s.fini, + name = expr_print.expr(expr), + expr = yexpr, + } end + return jys +end + +local function expr_get_functions(exprs) + local jys = {} + for i, expr in ipairs(exprs) do + jys[i] = { + name = expr_print.expr(expr), + expr = expr, + } + end + return jys +end + +local function schema_from_plot_descr(plot_descr, t) + local l = mini.lexer(plot_descr) + local actions = plot_actions_gen(t) + return mini.gschema(l, actions) +end + +local function gdt_table_category_plot(plotter, t, plot_descr, opt) + local show_plot = true + if opt then show_plot = (opt.show ~= false) end + + local schema = schema_from_plot_descr(plot_descr, t) + local jxs = idents_get_column_indexes(t, schema.x) + local jys = stat_expr_get_functions(schema.y) + local jes = idents_get_column_indexes(t, schema.enums) + + local labels, enums, val = rect_funcbin(t, jxs, jys, jes) + + local plt, cat = plotter.create(labels, enums, val) plt:set_categories('x', cat) plt.xlab_angle = math.pi/4 - if show_plot then plt:show() end + plotter.legend(plt, labels, enums) + if show_plot then plt:show() end return plt end -local function gdt_table_xyplot(t, jx, jys, jes, opt) +local function gdt_table_xyplot(t, plot_descr, opt) local show_plot = true if opt then show_plot = (opt.show ~= false) end - local path, webcolor = graph.path, graph.webcolor - local use_lines = opt and opt.lines - local use_markers = true - if opt then - if opt.markers == false then use_markers = false end - end + local use_markers = opt and (opt.markers ~= false) or true - jx = type(jx) == 'number' and jx or t:col_index(jx) - jes = treat_column_refs(t, jes) - jys = treat_column_refs(t, jys) + local schema = schema_from_plot_descr(plot_descr, t) + local jxs = expr_get_functions(schema.x) + local jys = expr_get_functions(schema.y) + local jes = idents_get_column_indexes(t, schema.enums) + local jx = jxs[1] + + local eval_set, eval_scope = eval_scalar_gen(t) + local eval = expr_print.eval local enums = {} local n = #t @@ -314,14 +382,15 @@ local function gdt_table_xyplot(t, jx, jys, jes, opt) lg.units, lg.clip = false, false local mult = #enums * #jys for p = 1, #jys do - local name = t:get_header(jys[p]) + local name = jys[p].name for q, enum in ipairs(enums) do local ln = path() local path_method = ln.move_to for i = 1, n do + eval_set(i) local e = collate_factors(t, i, jes) if compare_list(enum, e) then - local x, y = t:get(i, jx), t:get(i, jys[p]) + local x, y = eval(jx.expr, eval_scope), eval(jys[p].expr, eval_scope) if x and y then path_method(ln, x, y) path_method = ln.line_to @@ -355,10 +424,11 @@ local function gdt_table_xyplot(t, jx, jys, jes, opt) return plt end -local function gdt_table_reduce(t_src, jxs, jys, jes) - jxs = treat_column_refs(t_src, jxs) - jys = treat_column_funcrefs(t_src, jys) - jes = treat_column_refs(t_src, jes) +local function gdt_table_reduce(t_src, schema_descr) + local schema = schema_from_plot_descr(schema_descr, t_src) + local jxs = idents_get_column_indexes(t_src, schema.x) + local jys = stat_expr_get_functions(schema.y) + local jes = idents_get_column_indexes(t_src, schema.enums) local labels, enums, val = rect_funcbin(t_src, jxs, jys, jes) @@ -385,18 +455,7 @@ local function gdt_table_reduce(t_src, jxs, jys, jes) return t end -local function count_args(...) - return select('#', ...) -end - -local function set_elements(X, P, i, ...) - for k = 1, P do - local v = select(k, ...) - X:set(i, k, v) - end -end - -gdt.barplot = gdt_table_barplot -gdt.plot = gdt_table_lineplot +gdt.barplot = function(t, spec, opt) return gdt_table_category_plot(barplot, t, spec, opt) end +gdt.plot = function(t, spec, opt) return gdt_table_category_plot(lineplot, t, spec, opt) end gdt.xyplot = gdt_table_xyplot gdt.reduce = gdt_table_reduce |