Switch to new gdt plot implementation based on expressions - gsl-shell.git - gsl-shell

index : gsl-shell.git
gsl-shell
summary refs log tree commit diff
path: root/gdt-plot.lua
diff options
context:
space:
mode:
authorFrancesco Abbate <francesco.bbt@gmail.com>2013年01月24日 18:07:14 +0100
committerFrancesco Abbate <francesco.bbt@gmail.com>2013年01月24日 18:14:42 +0100
commit0020f49a042221eb53ae88f5ba773e9b0db09e9a (patch)
tree7f0ae680003ecfa3694673e81eebfbe0b3901878 /gdt-plot.lua
parent35421c4a09fb4e148fd9ce4422a7ce45f836e61d (diff)
downloadgsl-shell-0020f49a042221eb53ae88f5ba773e9b0db09e9a.tar.gz
Switch to new gdt plot implementation based on expressions
Diffstat (limited to 'gdt-plot.lua')
-rw-r--r--gdt-plot.lua 275
1 files changed, 167 insertions, 108 deletions
diff --git a/gdt-plot.lua b/gdt-plot.lua
index 15888639..f774c901 100644
--- a/gdt-plot.lua
+++ b/gdt-plot.lua
@@ -1,3 +1,6 @@
+local mini = require 'expr-parser'
+local expr_print = require 'expr-print'
+
local concat = table.concat
local select, unpack = select, unpack
local sqrt = math.sqrt
@@ -8,17 +11,6 @@ local function collate(ls, sep)
return concat(ls, sep or ' ')
end
-local function treat_column_refs(t, js)
- if type(js) ~= 'table' then js = {js} end
- for i = 1, #js do
- local v = js[i]
- if type(v) == 'string' then
- js[i] = t:col_index(v)
- end
- end
- return js
-end
-
-- recursive algorithm to computer the standard deviation from
-- wikipedia: http://en.wikipedia.org/wiki/Standard_deviation.
-- Welford, BP. "Note on a Method for Calculating Corrected Sums of
@@ -55,33 +47,6 @@ local stat_lookup = {
count = {f = function(accu, x, n) return n end},
}
-local function treat_column_funcrefs(t, js)
- if type(js) ~= 'table' then js = {js} end
- for i = 1, #js do
- local v = js[i]
- local stat, name, fullname
- if type(v) == 'string' then
- fullname = v
- stat, name = string.match(v, '(%a+)%((%w+)%)')
- if not stat then
- stat, name = 'mean', v
- end
- else
- stat, name = 'mean', t:get_header(v)
- end
- local s = stat_lookup[stat]
- assert(s, "invalid parameter requested")
- js[i] = {
- f = s.f,
- f0 = s.f0,
- fini = s.fini,
- name = fullname,
- index = t:col_index(name)
- }
- end
- return js
-end
-
local function compare_list(a, b)
local n = #a
for k = 1, n do
@@ -126,28 +91,33 @@ local function vec2d_incr(r, i, j)
return v + 1
end
-local function treat_all_column_refs(t, jxs, jys, jes)
- jxs = treat_column_refs(t, jxs)
- jys = treat_column_refs(t, jys)
- jes = treat_column_refs(t, jes)
- return jxs, jys, jes
+local function eval_scalar_gen(t)
+ local i
+ local id_res = function(expr) return t:get(i, expr.index) end
+ local func_res = function(expr) return math[expr.func] end
+ local set = function(ix) i = ix end
+ return set, {ident= id_res, func= func_res}
end
local function rect_funcbin(t, jxs, jys, jes)
+ local eval_set, eval_scope = eval_scalar_gen(t)
+ local eval = expr_print.eval
+
local n = #t
local val, count = {}, {}
local enums, labels = {}, {}
local fini_table = {}
for i = 1, n do
+ eval_set(i)
local c = collate_factors(t, i, jxs)
for p = 1, #jys do
local jp = jys[p]
- local jy, fy, fini = jp.index, jp.f, jp.fini
+ local fy, fini = jp.f, jp.fini
local f0 = jp.f0 and jp.f0() or 0
local e = collate_factors(t, i, jes)
e[#e+1] = jp.name
- local v = t:get(i, jy)
+ local v = eval(jp.expr, eval_scope)
if v then
local ie = add_unique(enums, e)
local ix = add_unique(labels, c)
@@ -173,17 +143,62 @@ local function rect_funcbin(t, jxs, jys, jes)
return labels, enums, val
end
-local function gdt_table_barplot(t, jxs, jys, jes, opt)
- local show_plot = true
- if opt then show_plot = (opt.show ~= false) end
+local function infix_ast(sym, a, b)
+ return {operator= sym, a, b}
+end
- jxs = treat_column_refs(t, jxs)
- jys = treat_column_funcrefs(t, jys)
- jes = treat_column_refs(t, jes)
+local function prefix_ast(sym, a)
+ return {operator= sym, a}
+end
- local rect, webcolor = graph.rect, graph.webcolor
- local labels, enums, val = rect_funcbin(t, jxs, jys, jes)
+local function func_eval_ast(func_name, arg_expr)
+ return {func= func_name, arg= arg_expr}
+end
+
+local function itself(x) return x end
+
+local function plot_actions_gen(t)
+
+ local function ident_ast(id)
+ local i = t:col_index(id)
+ if not i then error('unknown column name: '..id) end
+ return {name= id, index= i}
+ end
+
+ return {
+ infix = infix_ast,
+ prefix = prefix_ast,
+ ident = ident_ast,
+ enum = itself,
+ func_eval = func_eval_ast,
+ number = itself,
+ exprlist = function(a, ls) if ls then ls[#ls+1] = a else ls = {a} end; return ls end,
+ schema = function(x, y, enums) return {x= x, y= y, enums= enums} end,
+ }
+end
+
+local stat_lookup = {
+ mean = {f = function(accu, x, n) return (accu * (n-1) + x) / n end},
+ stddev = {f = f_stddev, f0 = || {0, 0, 0}, fini = f_stddev_fini},
+ stddevp = {f = f_stddev, f0 = || {0, 0, 0}, fini = f_stddevp_fini},
+ var = {f = f_stddev, f0 = || {0, 0, 0}, fini = f_var_fini},
+ sum = {f = function(accu, x, n) return accu + x end},
+ count = {f = function(accu, x, n) return n end},
+}
+
+local function get_stat(expr)
+ if expr.func and stat_lookup[expr.func]then
+ return expr.func, expr.arg
+ else
+ return 'mean', expr
+ end
+end
+local rect, webcolor, path = graph.rect, graph.webcolor, graph.path
+
+local barplot = {}
+
+function barplot.create(labels, enums, val)
local plt = graph.plot()
local pad = 0.1
local dx = (1 - 2*pad) / #enums
@@ -200,20 +215,19 @@ local function gdt_table_barplot(t, jxs, jys, jes, opt)
cat[2*p-1] = p - 0.5
cat[2*p] = collate(lab)
end
+ return plt, cat
+end
- plt:set_categories('x', cat)
- plt.xlab_angle = math.pi/4
-
+function barplot.legend(plt, labels, enums)
if #enums > 1 then
for k = 1, #enums do
plt:legend(collate(enums[k], '/'), webcolor(k), 'square')
end
end
-
- if show_plot then plt:show() end
- return plt
end
+local lineplot = {}
+
local function legend_symbol(sym, dx, dy)
if sym == 'square' then
return graph.rect(5+dx, 5+dy, 15+dx, 15+dy)
@@ -234,20 +248,10 @@ local function add_legend(lg, k, symspec, color, text)
end
end
-local function gdt_table_lineplot(t, jxs, jys, jes, opt)
- local show_plot = true
- if opt then show_plot = (opt.show ~= false) end
-
- jxs = treat_column_refs(t, jxs)
- jys = treat_column_funcrefs(t, jys)
- jes = treat_column_refs(t, jes)
-
- local path, webcolor = graph.path, graph.webcolor
- local labels, enums, val = rect_funcbin(t, jxs, jys, jes)
-
- local plt, lg = graph.plot(), graph.plot()
+function lineplot.create(labels, enums, val)
+ local plt = graph.plot()
plt.pad, plt.clip = true, false
- lg.units, lg.clip = false, false
+
for q, en in ipairs(enums) do
local ln = path()
local path_method = ln.move_to
@@ -262,45 +266,109 @@ local function gdt_table_lineplot(t, jxs, jys, jes, opt)
end
plt:add(ln, webcolor(q), {{'stroke', width=line_width}})
plt:add(ln, webcolor(q), {{'marker', size=8, mark=q}})
+ end
+
+ local cat = {}
+ for p, lab in ipairs(labels) do
+ cat[2*p-1] = p - 0.5
+ cat[2*p] = collate(lab)
+ end
+
+ return plt, cat
+end
- if #enums > 1 then
+function lineplot.legend(plt, labels, enums)
+ if #enums > 1 then
+ local lg = graph.plot()
+ lg.units, lg.clip = false, false
+ for q, en in ipairs(enums) do
local label = collate(en)
add_legend(lg, q, 'line', webcolor(q), label)
add_legend(lg, q, q, webcolor(q))
end
+ plt:set_legend(lg)
end
+end
- plt:set_legend(lg)
+local function idents_get_column_indexes(t, exprs)
+ local jxs = {}
+ for i, expr in ipairs(exprs) do
+ if not expr.name then error('invalid enumeration factor') end
+ jxs[i] = t:col_index(expr.name)
+ end
+ return jxs
+end
- local cat = {}
- for p, lab in ipairs(labels) do
- cat[2*p-1] = p - 0.5
- cat[2*p] = collate(lab)
+local function stat_expr_get_functions(exprs)
+ local jys = {}
+ for i, expr in ipairs(exprs) do
+ local stat_name, yexpr = get_stat(expr)
+ local s = stat_lookup[stat_name]
+ jys[i] = {
+ f = s.f,
+ f0 = s.f0,
+ fini = s.fini,
+ name = expr_print.expr(expr),
+ expr = yexpr,
+ }
end
+ return jys
+end
+
+local function expr_get_functions(exprs)
+ local jys = {}
+ for i, expr in ipairs(exprs) do
+ jys[i] = {
+ name = expr_print.expr(expr),
+ expr = expr,
+ }
+ end
+ return jys
+end
+
+local function schema_from_plot_descr(plot_descr, t)
+ local l = mini.lexer(plot_descr)
+ local actions = plot_actions_gen(t)
+ return mini.gschema(l, actions)
+end
+
+local function gdt_table_category_plot(plotter, t, plot_descr, opt)
+ local show_plot = true
+ if opt then show_plot = (opt.show ~= false) end
+
+ local schema = schema_from_plot_descr(plot_descr, t)
+ local jxs = idents_get_column_indexes(t, schema.x)
+ local jys = stat_expr_get_functions(schema.y)
+ local jes = idents_get_column_indexes(t, schema.enums)
+
+ local labels, enums, val = rect_funcbin(t, jxs, jys, jes)
+
+ local plt, cat = plotter.create(labels, enums, val)
plt:set_categories('x', cat)
plt.xlab_angle = math.pi/4
- if show_plot then plt:show() end
+ plotter.legend(plt, labels, enums)
+ if show_plot then plt:show() end
return plt
end
-local function gdt_table_xyplot(t, jx, jys, jes, opt)
+local function gdt_table_xyplot(t, plot_descr, opt)
local show_plot = true
if opt then show_plot = (opt.show ~= false) end
- local path, webcolor = graph.path, graph.webcolor
-
local use_lines = opt and opt.lines
- local use_markers = true
- if opt then
- if opt.markers == false then use_markers = false end
- end
+ local use_markers = opt and (opt.markers ~= false) or true
- jx = type(jx) == 'number' and jx or t:col_index(jx)
- jes = treat_column_refs(t, jes)
- jys = treat_column_refs(t, jys)
+ local schema = schema_from_plot_descr(plot_descr, t)
+ local jxs = expr_get_functions(schema.x)
+ local jys = expr_get_functions(schema.y)
+ local jes = idents_get_column_indexes(t, schema.enums)
+ local jx = jxs[1]
+
+ local eval_set, eval_scope = eval_scalar_gen(t)
+ local eval = expr_print.eval
local enums = {}
local n = #t
@@ -314,14 +382,15 @@ local function gdt_table_xyplot(t, jx, jys, jes, opt)
lg.units, lg.clip = false, false
local mult = #enums * #jys
for p = 1, #jys do
- local name = t:get_header(jys[p])
+ local name = jys[p].name
for q, enum in ipairs(enums) do
local ln = path()
local path_method = ln.move_to
for i = 1, n do
+ eval_set(i)
local e = collate_factors(t, i, jes)
if compare_list(enum, e) then
- local x, y = t:get(i, jx), t:get(i, jys[p])
+ local x, y = eval(jx.expr, eval_scope), eval(jys[p].expr, eval_scope)
if x and y then
path_method(ln, x, y)
path_method = ln.line_to
@@ -355,10 +424,11 @@ local function gdt_table_xyplot(t, jx, jys, jes, opt)
return plt
end
-local function gdt_table_reduce(t_src, jxs, jys, jes)
- jxs = treat_column_refs(t_src, jxs)
- jys = treat_column_funcrefs(t_src, jys)
- jes = treat_column_refs(t_src, jes)
+local function gdt_table_reduce(t_src, schema_descr)
+ local schema = schema_from_plot_descr(schema_descr, t_src)
+ local jxs = idents_get_column_indexes(t_src, schema.x)
+ local jys = stat_expr_get_functions(schema.y)
+ local jes = idents_get_column_indexes(t_src, schema.enums)
local labels, enums, val = rect_funcbin(t_src, jxs, jys, jes)
@@ -385,18 +455,7 @@ local function gdt_table_reduce(t_src, jxs, jys, jes)
return t
end
-local function count_args(...)
- return select('#', ...)
-end
-
-local function set_elements(X, P, i, ...)
- for k = 1, P do
- local v = select(k, ...)
- X:set(i, k, v)
- end
-end
-
-gdt.barplot = gdt_table_barplot
-gdt.plot = gdt_table_lineplot
+gdt.barplot = function(t, spec, opt) return gdt_table_category_plot(barplot, t, spec, opt) end
+gdt.plot = function(t, spec, opt) return gdt_table_category_plot(lineplot, t, spec, opt) end
gdt.xyplot = gdt_table_xyplot
gdt.reduce = gdt_table_reduce
generated by cgit v1.2.3 (git 2.39.1) at 2025年09月13日 12:39:01 +0000

AltStyle によって変換されたページ (->オリジナル) /