lua-users home
lua-l archive

PATCH Lua 5.2 reentrant Lua metamethods

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


"Reentrant" in this case refers to executing a function call by restarting the current loop. (To be honest, I don't like using the term this way, but it was already written in the code, see CIST_REENTRY in lstate.h.) This patch makes metamethod calls written in Lua reenter the execution loop. This avoids nesting calls to luaV_execute which will trigger a "C stack overflow" error on recursive metamethods. I encountered this because of a recent thread. (http://lua-users.org/lists/lua-l/2010-09/msg00644.html) I didn't have to change very much because the most critical part, finishing an operation after the call returns, has already been implemented in luaV_finishOp for yieldable metamethods.
Very lightly tested at this point.
-- Sample script for the reentrant-metamethod patch. Every metamethod below
-- recurses hundreds of levels deep (200-level __eq/__concat, up to 500-level
-- __len). Without the patch each Lua metamethod call nests another
-- luaV_execute, and the script is expected to stop with "C stack overflow".
local debug = require "debug"

-- Raw (metamethod-free) length. Released Lua 5.2+ provides the global
-- 'rawlen'; fall back to 'table.getn' for the 5.2 work builds this sample
-- was originally written against.
local rawlen = rawlen or table.getn

local MT = {
  -- Deep compare tables: equal when both have the same raw length and every
  -- pair of elements compares equal (which may re-enter __eq).
  __eq = function(a, b)
    if rawlen(a) ~= rawlen(b) then return false end
    for i = 1, rawlen(a) do
      if a[i] ~= b[i] then return false end
    end
    return true
  end,

  -- Completely frivolous: distribute the other operand over each element of
  -- the table operand, e.g. {x, y} .. b yields (x..b)..(y..b)..b.
  -- NOTE: assumes exactly one operand is a table; with two tables the
  -- trailing operand stored in 'parts' is a table and table.concat errors.
  __concat = function(a, b)
    local parts = {}
    if type(a) == 'table' then
      for i = 1, rawlen(a) do
        parts[#parts + 1] = a[i] .. b
      end
      parts[#parts + 1] = b
    else
      parts[#parts + 1] = a
      for i = 1, rawlen(b) do
        parts[#parts + 1] = a .. b[i]
      end
    end
    return table.concat(parts)
  end,
}

-- Build two structurally identical 200-level nestings: {{{ ... '*' ... }}}.
local T, U = '*', '*'
for _ = 1, 200 do
  T = setmetatable({ T }, MT)
  U = setmetatable({ U }, MT)
end

print(T == U)        -- deep __eq recursion
print(T .. '-' .. T) -- Fun fact: this always produces symmetrical output

-- Sum(1..n): give every number a __len metamethod so #n recurses n levels.
-- Values below 1 yield 0, so #0, negatives and fractions terminate instead
-- of recursing without bound (#n is unchanged for positive integers).
debug.setmetatable(1, {
  __len = function(n)
    if n < 1 then return 0 end
    return n + #(n - 1)
  end,
})

-- Find more palindromes: print each i in 5..500 whose Sum(1..i) reads the
-- same forwards and backwards.
for i = 5, 500 do
  local sum = #i            -- computed once; reused when printing
  local s = tostring(sum)
  local a, b = 1, #s        -- #s on a string is its byte length
  while b > a and string.sub(s, a, a) == string.sub(s, b, b) do
    a, b = a + 1, b - 1
  end
  if b <= a then
    print(i, sum)
  end
end

-- Do not leave the number metatable installed for whoever runs next.
debug.setmetatable(1, nil)
This patch and the sample code are in the public domain.
--
- tom
telliamed@whoopdedo.org
diff -urN lua-5.2.0-work4-orig/src/lvm.c lua-5.2.0-work4/src/lvm.c
--- lua-5.2.0-work4-orig/src/lvm.c	2010-06-30 10:11:17 -0400
+++ lua-5.2.0-work4/src/lvm.c	2010-09-22 02:23:13 -0400
@@ -81,8 +81,8 @@
 }
 
 
-static void callTM (lua_State *L, const TValue *f, const TValue *p1,
- const TValue *p2, TValue *p3, int hasres) {
+static int callTM (lua_State *L, const TValue *f, const TValue *p1,
+ const TValue *p2, TValue *p3, int hasres, int reent) {
 ptrdiff_t result = savestack(L, p3);
 setobj2s(L, L->top++, f); /* push function */
 setobj2s(L, L->top++, p1); /* 1st argument */
@@ -91,15 +91,22 @@
 setobj2s(L, L->top++, p3); /* 3rd argument */
 luaD_checkstack(L, 0);
 /* metamethod may yield only when called from Lua code */
- luaD_call(L, L->top - (4 - hasres), hasres, isLua(L->ci));
+ if (reent) {
+ if (!luaD_precall(L, L->top - (4 - hasres), hasres))
+ return 0;
+ luaC_checkGC(L);
+ }
+ else
+ luaD_call(L, L->top - (4 - hasres), hasres, isLua(L->ci));
 if (hasres) { /* if has result, move it to its place */
 p3 = restorestack(L, result);
 setobjs2s(L, p3, --L->top);
 }
+ return 1;
 }
 
 
-void luaV_gettable (lua_State *L, const TValue *t, TValue *key, StkId val) {
+int luaV_gettable_ (lua_State *L, const TValue *t, TValue *key, StkId val, int reent) {
 int loop;
 for (loop = 0; loop < MAXTAGLOOP; loop++) {
 const TValue *tm;
@@ -109,23 +116,23 @@
 if (!ttisnil(res) || /* result is not nil? */
 (tm = fasttm(L, h->metatable, TM_INDEX)) == NULL) { /* or no TM? */
 setobj2s(L, val, res);
- return;
+ return 1;
 }
 /* else will try the tag method */
 }
 else if (ttisnil(tm = luaT_gettmbyobj(L, t, TM_INDEX)))
 luaG_typeerror(L, t, "index");
 if (ttisfunction(tm)) {
- callTM(L, tm, t, key, val, 1);
- return;
+ return callTM(L, tm, t, key, val, 1, reent);
 }
 t = tm; /* else repeat with 'tm' */
 }
 luaG_runerror(L, "loop in gettable");
+ return 1;
 }
 
 
-void luaV_settable (lua_State *L, const TValue *t, TValue *key, StkId val) {
+int luaV_settable_ (lua_State *L, const TValue *t, TValue *key, StkId val, int reent) {
 int loop;
 TValue temp;
 for (loop = 0; loop < MAXTAGLOOP; loop++) {
@@ -137,33 +144,31 @@
 (tm = fasttm(L, h->metatable, TM_NEWINDEX)) == NULL) { /* or no TM? */
 setobj2t(L, oldval, val);
 luaC_barrierback(L, obj2gco(h), val);
- return;
+ return 1;
 }
 /* else will try the tag method */
 }
 else if (ttisnil(tm = luaT_gettmbyobj(L, t, TM_NEWINDEX)))
 luaG_typeerror(L, t, "index");
 if (ttisfunction(tm)) {
- callTM(L, tm, t, key, val, 0);
- return;
+ return callTM(L, tm, t, key, val, 0, reent);
 }
 /* else repeat with 'tm' */
 setobj(L, &temp, tm); /* avoid pointing inside table (may rehash) */
 t = &temp;
 }
 luaG_runerror(L, "loop in settable");
+ return 1;
 }
 
-
-static int call_binTM (lua_State *L, const TValue *p1, const TValue *p2,
- StkId res, TMS event) {
+static const TValue *get_binTM (lua_State *L, const TValue *p1, const TValue *p2,
+ TMS event) {
 const TValue *tm = luaT_gettmbyobj(L, p1, event); /* try first operand */
 if (ttisnil(tm))
 tm = luaT_gettmbyobj(L, p2, event); /* try second operand */
- if (ttisnil(tm)) return 0;
+ if (ttisnil(tm)) return NULL;
 if (event == TM_UNM) p2 = luaO_nilobject;
- callTM(L, tm, p1, p2, res, 1);
- return 1;
+ return tm;
 }
 
 
@@ -183,10 +188,11 @@
 
 static int call_orderTM (lua_State *L, const TValue *p1, const TValue *p2,
 TMS event) {
- if (!call_binTM(L, p1, p2, L->top, event))
+ const TValue *tm = get_binTM(L, p1, p2, event);
+ if (!tm)
 return -1; /* no metamethod */
- else
- return !l_isfalse(L->top);
+ callTM(L, tm, p1, p2, L->top, 1, 0);
+ return !l_isfalse(L->top);
 }
 
 
@@ -212,33 +218,49 @@
 }
 
 
-int luaV_lessthan (lua_State *L, const TValue *l, const TValue *r) {
- int res;
+int luaV_lessthan_ (lua_State *L, const TValue *l, const TValue *r, int reent) {
 if (ttisnumber(l) && ttisnumber(r))
 return luai_numlt(L, nvalue(l), nvalue(r));
 else if (ttisstring(l) && ttisstring(r))
 return l_strcmp(rawtsvalue(l), rawtsvalue(r)) < 0;
- else if ((res = call_orderTM(L, l, r, TM_LT)) != -1)
- return res;
+ else {
+ const TValue *tm = get_binTM(L, l, r, TM_LT);
+ if (tm != NULL) {
+ if (callTM(L, tm, l, r, L->top, 1, reent))
+ return !l_isfalse(L->top);
+ else
+ return -1;
+ }
+ }
 return luaG_ordererror(L, l, r);
 }
 
 
-int luaV_lessequal (lua_State *L, const TValue *l, const TValue *r) {
- int res;
+int luaV_lessequal_ (lua_State *L, const TValue *l, const TValue *r, int reent) {
 if (ttisnumber(l) && ttisnumber(r))
 return luai_numle(L, nvalue(l), nvalue(r));
 else if (ttisstring(l) && ttisstring(r))
 return l_strcmp(rawtsvalue(l), rawtsvalue(r)) <= 0;
- else if ((res = call_orderTM(L, l, r, TM_LE)) != -1) /* first try `le' */
- return res;
- else if ((res = call_orderTM(L, r, l, TM_LT)) != -1) /* else try `lt' */
- return !res;
+ else {
+ const TValue *tm = get_binTM(L, l, r, TM_LE); /* first try `le' */
+ if (tm != NULL) {
+ if (callTM(L, tm, l, r, L->top, 1, reent))
+ return !l_isfalse(L->top);
+ else
+ return -1;
+ }
+ else if ((tm = get_binTM(L, r, l, TM_LT)) != NULL) { /* else try `lt' */
+ if (callTM(L, tm, r, l, L->top, 1, reent))
+ return l_isfalse(L->top);
+ else
+ return -1;
+ }
+ }
 return luaG_ordererror(L, l, r);
 }
 
 
-int luaV_equalval_ (lua_State *L, const TValue *t1, const TValue *t2) {
+int luaV_equalval_ (lua_State *L, const TValue *t1, const TValue *t2, int reent) {
 const TValue *tm;
 lua_assert(ttype(t1) == ttype(t2));
 switch (ttype(t1)) {
@@ -261,19 +283,25 @@
 default: return gcvalue(t1) == gcvalue(t2);
 }
 if (tm == NULL) return 0; /* no TM? */
- callTM(L, tm, t1, t2, L->top, 1); /* call TM */
- return !l_isfalse(L->top);
+ if (callTM(L, tm, t1, t2, L->top, 1, reent)) /* call TM */
+ return !l_isfalse(L->top);
+ else
+ return -1;
 }
 
 
-void luaV_concat (lua_State *L, int total) {
+int luaV_concat_ (lua_State *L, int total, int reent) {
 lua_assert(total >= 2);
 do {
 StkId top = L->top;
 int n = 2; /* number of elements handled in this pass (at least 2) */
 if (!(ttisstring(top-2) || ttisnumber(top-2)) || !tostring(L, top-1)) {
- if (!call_binTM(L, top-2, top-1, top-2, TM_CONCAT))
+ const TValue *tm = get_binTM(L, top-2, top-1, TM_CONCAT);
+ if (!tm)
 luaG_concaterror(L, top-2, top-1);
+ else
+ if (!callTM(L, tm, top-2, top-1, top-2, 1, reent))
+ return 0;
 }
 else if (tsvalue(top-1)->len == 0) /* second operand is empty? */
 (void)tostring(L, top - 2); /* result is first operand */
@@ -303,10 +331,11 @@
 total -= n-1; /* got 'n' strings to create 1 new */
 L->top -= n-1; /* poped 'n' strings and pushed one */
 } while (total > 1); /* repeat until only 1 result left */
+ return 1;
 }
 
 
-void luaV_objlen (lua_State *L, StkId ra, const TValue *rb) {
+int luaV_objlen_ (lua_State *L, StkId ra, const TValue *rb, int reent) {
 const TValue *tm;
 switch (ttype(rb)) {
 case LUA_TTABLE: {
@@ -314,11 +343,11 @@
 tm = fasttm(L, h->metatable, TM_LEN);
 if (tm) break; /* metamethod? break switch to call it */
 setnvalue(ra, cast_num(luaH_getn(h))); /* else primitive len */
- return;
+ return 1;
 }
 case LUA_TSTRING: {
 setnvalue(ra, cast_num(tsvalue(rb)->len));
- return;
+ return 1;
 }
 default: { /* try metamethod */
 tm = luaT_gettmbyobj(L, rb, TM_LEN);
@@ -327,12 +356,12 @@
 break;
 }
 }
- callTM(L, tm, rb, luaO_nilobject, ra, 1);
+ return callTM(L, tm, rb, luaO_nilobject, ra, 1, reent);
 }
 
 
-void luaV_arith (lua_State *L, StkId ra, const TValue *rb,
- const TValue *rc, TMS op) {
+int luaV_arith_ (lua_State *L, StkId ra, const TValue *rb,
+ const TValue *rc, TMS op, int reent) {
 TValue tempb, tempc;
 const TValue *b, *c;
 if ((b = luaV_tonumber(rb, &tempb)) != NULL &&
@@ -340,8 +369,14 @@
 lua_Number res = luaO_arith(op - TM_ADD + LUA_OPADD, nvalue(b), nvalue(c));
 setnvalue(ra, res);
 }
- else if (!call_binTM(L, rb, rc, ra, op))
- luaG_aritherror(L, rb, rc);
+ else {
+ const TValue *tm = get_binTM(L, rb, rc, op);
+ if (!tm)
+ luaG_aritherror(L, rb, rc);
+ else
+ return callTM(L, tm, rb, rc, ra, 1, reent);
+ }
+ return 1;
 }
 
 
@@ -393,7 +428,7 @@
 /*
 ** finish execution of an opcode interrupted by an yield
 */
-void luaV_finishOp (lua_State *L) {
+int luaV_finishOp_ (lua_State *L, int reent) {
 CallInfo *ci = L->ci;
 StkId base = ci->u.l.base;
 Instruction inst = *(ci->u.l.savedpc - 1); /* interrupted instruction */
@@ -425,7 +460,8 @@
 setobj2s(L, top - 2, top); /* put TM result in proper position */
 if (total > 1) { /* are there elements to concat? */
 L->top = top - 1; /* top is one after last element (at top-2) */
- luaV_concat(L, total); /* concat them (may yield again) */
+ if (!luaV_concat_(L, total, reent)) /* concat them (may yield again) */
+ return 0;
 }
 /* move final result to final position */
 setobj2s(L, ci->u.l.base + GETARG_A(inst), L->top - 1);
@@ -446,6 +482,7 @@
 break;
 default: lua_assert(0);
 }
+ return 1;
 }
 
 
@@ -470,6 +507,11 @@
 
 
 #define Protect(x)	{ {x;}; base = ci->u.l.base; }
+/* restart luaV_execute over new Lua function */
+#define Reentrant(x)	if(x){ \
+ ci = L->ci; \
+ ci->callstatus |= CIST_REENTRY; \
+ goto newframe; }
 
 #define checkGC(L)	Protect(luaC_checkGC(L); luai_threadyield(L);)
 
@@ -481,7 +523,7 @@
 lua_Number nb = nvalue(rb), nc = nvalue(rc); \
 setnvalue(ra, op(L, nb, nc)); \
 } \
- else { Protect(luaV_arith(L, ra, rb, rc, tm)); } }
+ else { Protect(Reentrant(!luaV_arith_(L, ra, rb, rc, tm, 1))); } }
 
 
 #define vmdispatch(o)	switch(o)
@@ -534,14 +576,14 @@
 )
 vmcase(OP_GETTABUP,
 int b = GETARG_B(i);
- Protect(luaV_gettable(L, cl->upvals[b]->v, RKC(i), ra));
+ Protect(Reentrant(!luaV_gettable_(L, cl->upvals[b]->v, RKC(i), ra, 1)));
 )
 vmcase(OP_GETTABLE,
- Protect(luaV_gettable(L, RB(i), RKC(i), ra));
+ Protect(Reentrant(!luaV_gettable_(L, RB(i), RKC(i), ra, 1)));
 )
 vmcase(OP_SETTABUP,
 int a = GETARG_A(i);
- Protect(luaV_settable(L, cl->upvals[a]->v, RKB(i), RKC(i)));
+ Protect(Reentrant(!luaV_settable_(L, cl->upvals[a]->v, RKB(i), RKC(i), 1)));
 )
 vmcase(OP_SETUPVAL,
 UpVal *uv = cl->upvals[GETARG_B(i)];
@@ -549,7 +591,7 @@
 luaC_barrier(L, uv, ra);
 )
 vmcase(OP_SETTABLE,
- Protect(luaV_settable(L, ra, RKB(i), RKC(i)));
+ Protect(Reentrant(!luaV_settable_(L, ra, RKB(i), RKC(i), 1)));
 )
 vmcase(OP_NEWTABLE,
 int b = GETARG_B(i);
@@ -563,7 +605,7 @@
 vmcase(OP_SELF,
 StkId rb = RB(i);
 setobjs2s(L, ra+1, rb);
- Protect(luaV_gettable(L, rb, RKC(i), ra));
+ Protect(Reentrant(!luaV_gettable_(L, rb, RKC(i), ra, 1)));
 )
 vmcase(OP_ADD,
 arith_op(luai_numadd, TM_ADD);
@@ -590,7 +632,7 @@
 setnvalue(ra, luai_numunm(L, nb));
 }
 else {
- Protect(luaV_arith(L, ra, rb, rb, TM_UNM));
+ Protect(Reentrant(!luaV_arith_(L, ra, rb, rb, TM_UNM, 1)));
 }
 )
 vmcase(OP_NOT,
@@ -598,13 +640,13 @@
 setbvalue(ra, res);
 )
 vmcase(OP_LEN,
- Protect(luaV_objlen(L, ra, RB(i)));
+ Protect(Reentrant(!luaV_objlen_(L, ra, RB(i), 1)));
 )
 vmcase(OP_CONCAT,
 int b = GETARG_B(i);
 int c = GETARG_C(i);
 L->top = base + c + 1; /* mark the end of concat operands */
- Protect(luaV_concat(L, c-b+1); checkGC(L);)
+ Protect(Reentrant(!luaV_concat_(L, c-b+1, 1)); checkGC(L);)
 L->top = ci->top; /* restore top */
 setobjs2s(L, RA(i), base+b);
 )
@@ -615,21 +657,27 @@
 TValue *rb = RKB(i);
 TValue *rc = RKC(i);
 Protect(
- if (equalobj(L, rb, rc) == GETARG_A(i))
+ int res = (ttype(rb) == ttype(rc)) ? luaV_equalval_(L, rb, rc, 1) : 0;
+ Reentrant(res==-1);
+ if (res == GETARG_A(i))
 dojump(GETARG_sBx(*ci->u.l.savedpc));
 )
 ci->u.l.savedpc++;
 )
 vmcase(OP_LT,
 Protect(
- if (luaV_lessthan(L, RKB(i), RKC(i)) == GETARG_A(i))
+ int res = luaV_lessthan_(L, RKB(i), RKC(i), 1);
+ Reentrant(res==-1);
+ if (res == GETARG_A(i))
 dojump(GETARG_sBx(*ci->u.l.savedpc));
 )
 ci->u.l.savedpc++;
 )
 vmcase(OP_LE,
 Protect(
- if (luaV_lessequal(L, RKB(i), RKC(i)) == GETARG_A(i))
+ int res = luaV_lessequal_(L, RKB(i), RKC(i), 1);
+ Reentrant(res==-1);
+ if (res == GETARG_A(i))
 dojump(GETARG_sBx(*ci->u.l.savedpc));
 )
 ci->u.l.savedpc++;
@@ -699,9 +747,8 @@
 return; /* external invocation: return */
 else { /* invocation via reentry: continue execution */
 ci = L->ci;
- if (b) L->top = ci->top;
 lua_assert(isLua(ci));
- lua_assert(GET_OPCODE(*((ci)->u.l.savedpc - 1)) == OP_CALL);
+		 Reentrant(!luaV_finishOp_(L, 1));
 goto newframe; /* restart luaV_execute over new Lua function */
 }
 )
diff -urN lua-5.2.0-work4-orig/src/lvm.h lua-5.2.0-work4/src/lvm.h
--- lua-5.2.0-work4-orig/src/lvm.h	2009-12-17 11:20:01 -0500
+++ lua-5.2.0-work4/src/lvm.h	2010-09-22 01:57:07 -0400
@@ -18,25 +18,35 @@
 #define tonumber(o,n)	(ttisnumber(o) || (((o) = luaV_tonumber(o,n)) != NULL))
 
 #define equalobj(L,o1,o2) \
-	(ttype(o1) == ttype(o2) && luaV_equalval_(L, o1, o2))
+	(ttype(o1) == ttype(o2) && luaV_equalval_(L, o1, o2, 0))
 
+#define luaV_lessthan(L,o1,o2)	luaV_lessthan_(L,o1,o2,0)
+#define luaV_lessequal(L,o1,o2)	luaV_lessequal_(L,o1,o2,0)
+
+#define luaV_gettable(L,t,key,val)	luaV_gettable_(L,t,key,val,0)
+#define luaV_settable(L,t,key,val)	luaV_settable_(L,t,key,val,0)
+#define luaV_arith(L,o1,o2,res,op)	luaV_arith_(L,o1,o2,res,op,0)
+#define luaV_objlen(L,o,res)	luaV_objlen_(L,o,res,0)
+
+#define luaV_finishOp(L)	luaV_finishOp_(L,0)
+#define luaV_concat(L,n)	luaV_concat_(L,n,0)
 
 /* not to called directly */
-LUAI_FUNC int luaV_equalval_ (lua_State *L, const TValue *t1, const TValue *t2);
+LUAI_FUNC int luaV_equalval_ (lua_State *L, const TValue *t1, const TValue *t2, int reent);
 
-LUAI_FUNC int luaV_lessthan (lua_State *L, const TValue *l, const TValue *r);
-LUAI_FUNC int luaV_lessequal (lua_State *L, const TValue *l, const TValue *r);
+LUAI_FUNC int luaV_lessthan_ (lua_State *L, const TValue *l, const TValue *r, int reent);
+LUAI_FUNC int luaV_lessequal_ (lua_State *L, const TValue *l, const TValue *r, int reent);
 LUAI_FUNC const TValue *luaV_tonumber (const TValue *obj, TValue *n);
 LUAI_FUNC int luaV_tostring (lua_State *L, StkId obj);
-LUAI_FUNC void luaV_gettable (lua_State *L, const TValue *t, TValue *key,
- StkId val);
-LUAI_FUNC void luaV_settable (lua_State *L, const TValue *t, TValue *key,
- StkId val);
-LUAI_FUNC void luaV_finishOp (lua_State *L);
+LUAI_FUNC int luaV_gettable_ (lua_State *L, const TValue *t, TValue *key,
+ StkId val, int reent);
+LUAI_FUNC int luaV_settable_ (lua_State *L, const TValue *t, TValue *key,
+ StkId val, int reent);
+LUAI_FUNC int luaV_finishOp_ (lua_State *L, int reent);
 LUAI_FUNC void luaV_execute (lua_State *L);
-LUAI_FUNC void luaV_concat (lua_State *L, int total);
-LUAI_FUNC void luaV_arith (lua_State *L, StkId ra, const TValue *rb,
- const TValue *rc, TMS op);
-LUAI_FUNC void luaV_objlen (lua_State *L, StkId ra, const TValue *rb);
+LUAI_FUNC int luaV_concat_ (lua_State *L, int total, int reent);
+LUAI_FUNC int luaV_arith_ (lua_State *L, StkId ra, const TValue *rb,
+ const TValue *rc, TMS op, int reent);
+LUAI_FUNC int luaV_objlen_ (lua_State *L, StkId ra, const TValue *rb, int reent);
 
 #endif

Page converted by AltStyle (-> original) /