1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
local expr_lexer = require 'expr-lexer'
local AST = require 'expr-actions'
local format, concat = string.format, table.concat
local oper_table = expr_lexer.operators
local ex_print
--- Check whether a name can be printed bare, without brackets.
-- A bare name starts with a lowercase/uppercase letter or underscore and
-- continues with alphanumerics or underscores.
-- @param s the name to check
-- @return `s` itself when the whole string matches, otherwise nil
local function is_ident_simple(s)
  local ident_pattern = '^[%l%u_][%w_]*$'
  local whole = s:match(ident_pattern)
  return whole
end
--- Render an operator node, adding parentheses around operands that would
-- otherwise be regrouped when the text is read back.
-- @param e operator node: `e.operator` is the operator text, `e[1]` the first
--   operand and, for binary nodes, `e[2]` the second operand
-- @param prio priority of `e.operator` (looked up in `oper_table` by ex_print)
-- @return the rendered text
local function op_print(e, prio)
  if #e == 1 then
    local c, c_prio = ex_print(e[1])
    -- An operand of equal priority must be wrapped too: otherwise -(-a)
    -- prints as "--a", and -(a+b) as "-a+b" when '+' shares the priority
    -- of '-'.
    if c_prio <= prio then c = format('(%s)', c) end
    return format("%s%s", e.operator, c)
  else
    local a, a_prio = ex_print(e[1])
    local b, b_prio = ex_print(e[2])
    -- Binary operators are treated as left-associative: an equal-priority
    -- left operand is printed bare, but an equal-priority right operand is
    -- wrapped so that a-(b-c) does not come out as "a-b-c".
    -- NOTE(review): if the parser makes '^' right-associative, an
    -- equal-priority LEFT operand of '^' would need parentheses instead;
    -- confirm against the parser.
    if a_prio < prio then a = format('(%s)', a) end
    if b_prio <= prio then b = format('(%s)', b) end
    -- Operators with priority below 2 are printed with surrounding spaces;
    -- presumably these are the word/comparison operators (e.g. 'and', 'or')
    -- that cannot be glued to their operands -- TODO confirm in expr-lexer.
    local temp = (prio < 2 and "%s %s %s" or "%s%s%s")
    return format(temp, a, e.operator, b)
  end
end
--- Render a sequence of expression nodes as comma-separated text.
-- @param e sequence of expression nodes
-- @return the rendered items joined with ", " ("" for an empty sequence)
local function exlist_print(e)
  local pieces, n = {}, #e
  local i = 1
  while i <= n do
    -- Keep only the text; ex_print's second result (priority) is dropped.
    pieces[i] = (ex_print(e[i]))
    i = i + 1
  end
  return concat(pieces, ', ')
end
-- Priority reported for atoms (numbers, names, literals, function calls):
-- one above the highest operator priority, so op_print never wraps them.
local high_prio = expr_lexer.max_oper_prio + 1

--- Render an expression AST node as text.
-- This assigns the `ex_print` local forward-declared earlier in the file, so
-- op_print and exlist_print (defined above) can call it.
-- @param e expression node: number, name string, literal node (`e.literal`),
--   function call node (`e.func`, `e.arg`) or operator node (`e.operator`)
-- @return the rendered text (number nodes are returned unchanged)
-- @return the node's priority, used by op_print to decide on parentheses
ex_print = function(e)
  local kind = type(e)
  if kind == 'number' then
    return e, high_prio
  end
  if kind == 'string' then
    -- Names that are not plain identifiers are printed inside brackets.
    local name = is_ident_simple(e) and e or format('[%s]', e)
    return name, high_prio
  end
  if e.literal then
    return format('%q', e.literal), high_prio
  end
  if e.func then
    local inner = (ex_print(e.arg))
    return format('%s(%s)', e.func, inner), high_prio
  end
  local prio = oper_table[e.operator]
  return (op_print(e, prio)), prio
end
--- Render a schema node as "<y list> ~ <x list> : <conds list>".
-- @param e schema node with expression-list fields `y`, `x` and `conds`
-- @return the rendered schema text
-- NOTE(review): an empty `conds` list renders as "", leaving a dangling
-- " : " at the end of the output; confirm the parser accepts that form.
local function schema_print(e)
  local left = exlist_print(e.y)
  local right = exlist_print(e.x)
  local conditions = exlist_print(e.conds)
  return format("%s ~ %s : %s", left, right, conditions)
end
-- Convert a Lua boolean into the numeric truth value used by expressions.
local function bool01(flag)
  return flag and 1 or 0
end

-- Binary operator implementations keyed by operator text. Comparison and
-- logical operators yield 1/0; 'and'/'or' treat any non-zero operand as true.
local binary_ops = {
  ['+'] = function(a, b) return a + b end,
  ['-'] = function(a, b) return a - b end,
  ['*'] = function(a, b) return a * b end,
  ['/'] = function(a, b) return a / b end,
  ['^'] = function(a, b) return a ^ b end,
  ['='] = function(a, b) return bool01(a == b) end,
  ['>'] = function(a, b) return bool01(a > b) end,
  ['<'] = function(a, b) return bool01(a < b) end,
  ['!='] = function(a, b) return bool01(a ~= b) end,
  ['>='] = function(a, b) return bool01(a >= b) end,
  ['<='] = function(a, b) return bool01(a <= b) end,
  ['and'] = function(a, b) return bool01(a ~= 0 and b ~= 0) end,
  ['or'] = function(a, b) return bool01(a ~= 0 or b ~= 0) end,
}

--- Apply a binary operator to two already-evaluated operands.
-- @param op operator text
-- @param a left operand
-- @param b right operand
-- @return the operation's result
-- @raise "unknown operation: <op>" when `op` has no implementation
local function eval_operator(op, a, b)
  local impl = binary_ops[op]
  if impl == nil then
    error('unknown operation: ' .. op)
  end
  return impl(a, b)
end
--- Evaluate an expression AST against a scope.
-- Names are resolved with `scope.ident(name, ...)` and function nodes with
-- `scope.func(node)`. When a function argument or an operand evaluates to
-- nil (or false), the enclosing node yields no value, so missing values
-- propagate upward instead of raising.
-- @param expr expression node (number, name string, literal node, function
--   call node or operator node)
-- @param scope table providing the `ident` and `func` resolvers
-- @param ... extra arguments forwarded to every `scope.ident` call
-- @return the computed value, or nothing when an input was missing
-- @raise "unknown function: <name>" when `scope.func` returns no function
local function eval(expr, scope, ...)
  local kind = type(expr)
  if kind == 'number' then return expr end
  if kind == 'string' then return scope.ident(expr, ...) end
  if expr.literal then return expr.literal end

  if expr.func then
    local arg_value = eval(expr.arg, scope, ...)
    if not arg_value then return end
    local fn = scope.func(expr)
    if not fn then error('unknown function: ' .. expr.func) end
    return fn(arg_value)
  end

  if #expr == 1 then
    -- NOTE(review): a single-operand node is always negated, whatever
    -- `expr.operator` holds; correct only while '-' is the sole unary
    -- operator -- confirm.
    local operand = eval(expr[1], scope, ...)
    if not operand then return end
    return -operand
  end

  -- Both operands are evaluated (left first) before checking for missing ones.
  local lhs = eval(expr[1], scope, ...)
  local rhs = eval(expr[2], scope, ...)
  if not (lhs and rhs) then return end
  return eval_operator(expr.operator, lhs, rhs)
end
--- Collect the names of all variables referenced by an expression.
-- Walks the AST recursively and records every node that `AST.is_variable`
-- recognizes, using the variable name (its second result) as a set key.
-- @param expr expression node (number, variable, literal node, function call
--   node or operator node)
-- @param list optional set to fill in place; a new table is created when
--   omitted (callers passing their own table keep working as before)
-- @return the set, mapping each referenced variable name to `true`
local function ref_list_rec(expr, list)
  list = list or {}
  if type(expr) == 'number' then
    return list
  end
  -- Query once and reuse both results instead of calling is_variable twice.
  local is_var, var_name = AST.is_variable(expr)
  if is_var then
    list[var_name] = true
  elseif expr.literal then
    -- Literal nodes reference no variables.
  elseif expr.func then
    ref_list_rec(expr.arg, list)
  else
    -- Operator node: one operand (unary) or two (binary).
    ref_list_rec(expr[1], list)
    if #expr ~= 1 then
      ref_list_rec(expr[2], list)
    end
  end
  return list
end
-- Public interface: printers for schemas, single expressions and expression
-- lists, the evaluator, and the variable-reference collector.
return {
  schema = schema_print,
  expr = ex_print,
  expr_list = exlist_print,
  eval = eval,
  references = ref_list_rec,
}
|