Skip to content

Commit

Permalink
add support define metamethods for record
Browse files Browse the repository at this point in the history
Signed-off-by: Jianhui Zhao <[email protected]>
  • Loading branch information
zhaojh329 committed Jul 7, 2024
1 parent 7607db9 commit 5fa3f17
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
79 changes: 78 additions & 1 deletion ffi.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ enum {
CTYPE_FUNC,
};

enum {
METATYPE_FLAG_INDEX = 1 << 0,
METATYPE_FLAG_TOSTRING = 1 << 1
};

struct crecord;
struct carray;
struct cfunc;
Expand Down Expand Up @@ -112,6 +117,8 @@ struct crecord_field {

struct crecord {
ffi_type ft;
int mt_ref;
uint8_t mflags;
uint8_t nfield:5;
uint8_t is_union:1;
uint8_t anonymous:1;
Expand Down Expand Up @@ -622,6 +629,16 @@ static int __cdata_tostring(lua_State *L, struct cdata *cd)
static int cdata_tostring(lua_State *L)
{
struct cdata *cd = luaL_checkudata(L, 1, CDATA_MT);
struct ctype *ct = cd->ct;

if (ct->type == CTYPE_RECORD && ct->rc->mflags & METATYPE_FLAG_TOSTRING) {
lua_rawgeti(L, LUA_REGISTRYINDEX, ct->rc->mt_ref);
lua_getfield(L, -1, "__tostring");
lua_pushvalue(L, 1);
lua_call(L, 1, 1);
return 1;
}

return __cdata_tostring(L, cd);
}

Expand Down Expand Up @@ -1016,6 +1033,7 @@ static struct crecord_field *cdata_crecord_find_field(
static int cdata_index_crecord(lua_State *L, struct cdata *cd, struct ctype *ct, bool to)
{
void *ptr = cdata_type(cd) == CTYPE_PTR ? cdata_ptr_ptr(cd) : cdata_ptr(cd);
struct crecord *rc = ct->rc;
struct crecord_field *field;
size_t offset = 0;
const char *name;
Expand All @@ -1035,8 +1053,24 @@ static int cdata_index_crecord(lua_State *L, struct cdata *cd, struct ctype *ct,
lua_pop(L, 2);
}

field = cdata_crecord_find_field(ct->rc->fields, ct->rc->nfield, name, &offset);
field = cdata_crecord_find_field(rc->fields, rc->nfield, name, &offset);
if (!field) {
if (to) {
if (rc->mflags & METATYPE_FLAG_INDEX) {
lua_rawgeti(L, LUA_REGISTRYINDEX, rc->mt_ref);
lua_getfield(L, -1, "__index");

if (lua_isfunction(L, -1))
return 1;

if (lua_istable(L, -1)) {
lua_getfield(L, -1, name);
if (!lua_isnil(L, -1))
return 1;
}
}
}

__ctype_tostring(L, ct);
return luaL_error(L, "ctype '%s' has no member named '%s'", lua_tostring(L, -1), name);
}
Expand Down Expand Up @@ -1307,6 +1341,10 @@ static int ctype_gc(lua_State *L)
int i;
for (i = 0; i < ct->rc->nfield; i++)
free(ct->rc->fields[i]);

if (ct->rc->mt_ref != LUA_REFNIL)
luaL_unref(L, LUA_REGISTRYINDEX, ct->rc->mt_ref);

free(ct->rc);
}

Expand Down Expand Up @@ -1615,6 +1653,8 @@ static int cparse_record(lua_State *L, struct ctype *ct, bool is_union)
if (!ct->rc)
return luaL_error(L, "no mem");

ct->rc->mt_ref = LUA_REFNIL;

memcpy(ct->rc->fields, fields, sizeof(struct crecord_field *) * nfield);

if (named) {
Expand Down Expand Up @@ -2157,6 +2197,42 @@ static int lua_ffi_cast(lua_State *L)
return 1;
}

static int lua_ffi_metatype(lua_State *L)
{
struct ctype *ct = luaL_checkudata(L, 1, CTYPE_MT);
struct crecord *rc = ct->rc;
uint8_t mflags = 0;

luaL_argcheck(L, ct->type == CTYPE_RECORD, 1, "invalid C type");

luaL_checktype(L, 2, LUA_TTABLE);

#define FIELD_CHECK(name, flag) { \
lua_getfield(L, 2, "__" name); \
if (!lua_isnil(L, -1)) { \
mflags |= METATYPE_FLAG_##flag; \
} \
lua_pop(L, 1); \
}

FIELD_CHECK("index", INDEX);
FIELD_CHECK("tostring", TOSTRING);

#undef FIELD_CHECK

rc->mflags = mflags;

if (rc->mt_ref != LUA_REFNIL)
luaL_unref(L, LUA_REGISTRYINDEX, rc->mt_ref);

lua_pushvalue(L, 2);
rc->mt_ref = luaL_ref(L, LUA_REGISTRYINDEX);

lua_settop(L, 1);

return 1;
}

static int lua_ffi_typeof(lua_State *L)
{
lua_check_ct(L, NULL, true);
Expand Down Expand Up @@ -2356,6 +2432,7 @@ static const luaL_Reg methods[] = {

{"new", lua_ffi_new},
{"cast", lua_ffi_cast},
{"metatype", lua_ffi_metatype},
{"typeof", lua_ffi_typeof},
{"addressof", lua_ffi_addressof},
{"gc", lua_ffi_gc},
Expand Down
19 changes: 19 additions & 0 deletions tests/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,25 @@ local tests = {

p[1] = 567
assert(a[1] == 567)
end,
function()
local tp = ffi.metatype(ffi.typeof('struct Point'), {
__tostring = function(self)
return string.format('x:%d,y:%d', self.x, self.y)
end,
__index = {
add = function(self, x, y)
self.x = self.x + x
self.y = self.y + y
end
}
})

local p = ffi.new(tp, {45, 67})
assert(tostring(p) == 'x:45,y:67')

p:add(1, 1)
assert(p.x == 46 and p.y == 68)
end
}

Expand Down

0 comments on commit 5fa3f17

Please sign in to comment.