AP_Scripting: allow auto generation of uint32 operators

This commit is contained in:
Iampete1 2024-06-14 19:37:28 +01:00 committed by Peter Barker
parent 003f931d9c
commit 574b9939a5
4 changed files with 211 additions and 114 deletions

View File

@ -854,21 +854,23 @@ global manual micros lua_micros 0 1
global manual mission_receive lua_mission_receive 0 5 depends AP_MISSION_ENABLED
userdata uint32_t creation lua_new_uint32_t 1
userdata uint32_t manual_operator __add uint32_t___add
userdata uint32_t manual_operator __sub uint32_t___sub
userdata uint32_t manual_operator __mul uint32_t___mul
userdata uint32_t manual_operator __div uint32_t___div
userdata uint32_t manual_operator __mod uint32_t___mod
userdata uint32_t operator_getter coerce_to_uint32_t
userdata uint32_t operator +
userdata uint32_t operator -
userdata uint32_t operator *
userdata uint32_t operator /
-- We know name of the generated function so we can point at it again with a manual operator so idiv is the same as div
userdata uint32_t manual_operator __idiv uint32_t___div
userdata uint32_t manual_operator __band uint32_t___band
userdata uint32_t manual_operator __bor uint32_t___bor
userdata uint32_t manual_operator __bxor uint32_t___bxor
userdata uint32_t manual_operator __shl uint32_t___shl
userdata uint32_t manual_operator __shr uint32_t___shr
userdata uint32_t manual_operator __eq uint32_t___eq
userdata uint32_t manual_operator __lt uint32_t___lt
userdata uint32_t manual_operator __le uint32_t___le
userdata uint32_t manual_operator __bnot uint32_t___bnot
userdata uint32_t operator %
userdata uint32_t operator &
userdata uint32_t operator |
userdata uint32_t operator ^
userdata uint32_t operator <<
userdata uint32_t operator >>
userdata uint32_t operator ==
userdata uint32_t operator <
userdata uint32_t operator <=
userdata uint32_t operator ~
userdata uint32_t manual_operator __tostring uint32_t___tostring
userdata uint32_t manual toint uint32_t_toint 0 1
userdata uint32_t manual tofloat uint32_t_tofloat 0 1

View File

@ -31,6 +31,7 @@ char keyword_manual[] = "manual";
char keyword_global[] = "global";
char keyword_creation[] = "creation";
char keyword_manual_operator[] = "manual_operator";
char keyword_operator_getter[] = "operator_getter";
// attributes (should include the leading ' )
char keyword_attr_enum[] = "'enum";
@ -137,7 +138,17 @@ enum operator_type {
OP_SUB = (1U << 1),
OP_MUL = (1U << 2),
OP_DIV = (1U << 3),
OP_MANUAL = (1U << 4),
OP_MOD = (1U << 4),
OP_BAND = (1U << 5),
OP_BOR = (1U << 6),
OP_BXOR = (1U << 7),
OP_SHL = (1U << 8),
OP_SHR = (1U << 9),
OP_EQ = (1U << 10),
OP_LT = (1U << 11),
OP_LE = (1U << 12),
OP_BNOT = (1U << 13),
OP_MANUAL = (1U << 14),
OP_LAST
};
@ -410,6 +421,7 @@ struct userdata {
char *dependency;
char *creation; // name of a manual creation function if set, note that this will not be used internally
int creation_args; // number of args for custom creation function
char *operator_getter; // Custom function to get values for use in operators
};
static struct userdata *parsed_userdata;
@ -934,6 +946,26 @@ void handle_operator(struct userdata *data) {
operation = OP_MUL;
} else if (strcmp(operator, "/") == 0) {
operation = OP_DIV;
} else if (strcmp(operator, "%") == 0) {
operation = OP_MOD;
} else if (strcmp(operator, "&") == 0) {
operation = OP_BAND;
} else if (strcmp(operator, "|") == 0) {
operation = OP_BOR;
} else if (strcmp(operator, "^") == 0) {
operation = OP_BXOR;
} else if (strcmp(operator, "<<") == 0) {
operation = OP_SHL;
} else if (strcmp(operator, ">>") == 0) {
operation = OP_SHR;
} else if (strcmp(operator, "==") == 0) {
operation = OP_EQ;
} else if (strcmp(operator, "<") == 0) {
operation = OP_LT;
} else if (strcmp(operator, "<=") == 0) {
operation = OP_LE;
} else if (strcmp(operator, "~") == 0) {
operation = OP_BNOT;
} else {
error(ERROR_USERDATA, "Unknown operation type: %s", operator);
}
@ -1032,6 +1064,16 @@ void handle_userdata(void) {
handle_manual(node, ALIAS_TYPE_MANUAL_OPERATOR);
node->operations |= OP_MANUAL;
} else if (strcmp(type, keyword_operator_getter) == 0) {
if (node->operator_getter != NULL) {
error(ERROR_USERDATA, "Userdata only support a single getter string");
}
char *name = next_token();
if (name == NULL) {
error(ERROR_USERDATA, "Expected a getter string for %s",node->name);
}
string_copy(&(node->operator_getter), name);
} else {
error(ERROR_USERDATA, "Unknown or unsupported type for userdata: %s", type);
}
@ -1701,7 +1743,7 @@ void emit_field(const struct userdata_field *field, const char* object_name, con
break;
case TYPE_UINT32_T:
fprintf(source, "%snew_uint32_t(L);\n", indent);
fprintf(source, "%s*static_cast<uint32_t *>(luaL_checkudata(L, -1, \"uint32_t\")) = %s%s%s%s;\n", indent, object_name, object_access, field->name, index_string);
fprintf(source, "%s*check_uint32_t(L, -1) = %s%s%s%s;\n", indent, object_name, object_access, field->name, index_string);
break;
case TYPE_NONE:
error(ERROR_INTERNAL, "Can't access a NONE field");
@ -1822,7 +1864,7 @@ int emit_references(const struct argument *arg, const char * tab) {
break;
case TYPE_UINT32_T:
fprintf(source, "%snew_uint32_t(L);\n", tab);
fprintf(source, "%s*static_cast<uint32_t *>(luaL_checkudata(L, -1, \"uint32_t\")) = data_%d;\n", tab, arg_index);
fprintf(source, "%s*check_uint32_t(L, -1) = data_%d;\n", tab, arg_index);
break;
case TYPE_STRING:
fprintf(source, "%slua_pushstring(L, data_%d);\n", tab, arg_index);
@ -2142,10 +2184,28 @@ const char * get_name_for_operation(enum operator_type op) {
return "__sub";
case OP_MUL:
return "__mul";
break;
case OP_DIV:
return "__div";
break;
case OP_MOD:
return "__mod";
case OP_BAND:
return "__band";
case OP_BOR:
return "__bor";
case OP_BXOR:
return "__bxor";
case OP_SHL:
return "__shl";
case OP_SHR:
return "__shr";
case OP_EQ:
return "__eq";
case OP_LT:
return "__lt";
case OP_LE:
return "__le";
case OP_BNOT:
return "__bnot";
case OP_MANUAL:
case OP_LAST:
return NULL;
@ -2153,6 +2213,93 @@ const char * get_name_for_operation(enum operator_type op) {
return NULL;
}
const char * get_sym_for_operation(enum operator_type op) {
switch (op) {
case OP_ADD:
return "+";
case OP_SUB:
return "-";
case OP_MUL:
return "*";
case OP_DIV:
return "/";
case OP_MOD:
return "%";
case OP_BAND:
return "&";
case OP_BOR:
return "|";
case OP_BXOR:
return "^";
case OP_SHL:
return "<<";
case OP_SHR:
return ">>";
case OP_EQ:
return "==";
case OP_LT:
return "<";
case OP_LE:
return "<=";
case OP_BNOT:
return "~";
case OP_MANUAL:
case OP_LAST:
return NULL;
}
return NULL;
}
int operation_is_bool(enum operator_type op) {
switch (op) {
case OP_ADD:
case OP_SUB:
case OP_MUL:
case OP_DIV:
case OP_MOD:
case OP_BAND:
case OP_BOR:
case OP_BXOR:
case OP_SHL:
case OP_SHR:
case OP_BNOT:
case OP_MANUAL:
case OP_LAST:
return FALSE;
case OP_EQ:
case OP_LT:
case OP_LE:
return TRUE;
}
return FALSE;
}
int operation_is_unary(enum operator_type op) {
switch (op) {
case OP_ADD:
case OP_SUB:
case OP_MUL:
case OP_DIV:
case OP_MOD:
case OP_BAND:
case OP_BOR:
case OP_BXOR:
case OP_SHL:
case OP_SHR:
case OP_EQ:
case OP_LT:
case OP_LE:
case OP_MANUAL:
case OP_LAST:
return FALSE;
case OP_BNOT:
return TRUE;
}
return FALSE;
}
void emit_operators(struct userdata *data) {
trace(TRACE_USERDATA, "Emitting operators for %s", data->name);
@ -2161,39 +2308,52 @@ void emit_operators(struct userdata *data) {
start_dependency(source, data->dependency);
for (uint32_t i = 1; i < OP_LAST; i = (i << 1)) {
const char * op_name = get_name_for_operation((data->operations) & i);
const enum operator_type type = (data->operations) & i;
const char * op_name = get_name_for_operation(type);
if (op_name == NULL) {
continue;
}
char op_sym;
switch ((data->operations) & i) {
case OP_ADD:
op_sym = '+';
break;
case OP_SUB:
op_sym = '-';
break;
case OP_MUL:
op_sym = '*';
break;
case OP_DIV:
op_sym = '/';
break;
case OP_MANUAL:
case OP_LAST:
return;
const char * op_sym = get_sym_for_operation(type);
if (op_sym == NULL) {
error(ERROR_USERDATA, "No symbol for %s operation %u", data->name, type);
}
// The generated check functions return pointers, the manual getters return a value directly
const int have_getter = data->operator_getter != NULL;
const char * access = have_getter ? "" : "*";
const char * getter_prefix = have_getter ? "" : "check_";
const char * getter = have_getter ? data->operator_getter : data->sanatized_name;
fprintf(source, "static int %s_%s(lua_State *L) {\n", data->sanatized_name, op_name);
// check number of arguments
fprintf(source, " binding_argcheck(L, 2);\n");
// check the pointers
fprintf(source, " %s *ud = check_%s(L, 1);\n", data->name, data->sanatized_name);
fprintf(source, " %s *ud2 = check_%s(L, 2);\n", data->name, data->sanatized_name);
fprintf(source, " %s %sud = %s%s(L, 1);\n", data->name, access, getter_prefix, getter);
if (!operation_is_unary(type)) {
// Need two values
fprintf(source, " %s %sud2 = %s%s(L, 2);\n", data->name, access, getter_prefix, getter);
if (operation_is_bool(type)) {
// Return bool
fprintf(source, " lua_pushboolean(L, (%sud) %s (%sud2));\n", access, op_sym, access);
} else {
// Return same type
// create a container for the result
fprintf(source, " new_%s(L);\n", data->sanatized_name);
fprintf(source, " *check_%s(L, -1) = *ud %c *ud2;\n", data->sanatized_name, op_sym);
fprintf(source, " *check_%s(L, -1) = (%sud) %s (%sud2);\n", data->sanatized_name, access, op_sym, access);
}
} else {
// Only a single value, lua pushes the same value onto the stack twice, so we still check for 2 arguments
fprintf(source, " new_%s(L);\n", data->sanatized_name);
fprintf(source, " *check_%s(L, -1) = %s (%sud);\n", data->sanatized_name, op_sym, access);
}
// return the first pointer
fprintf(source, " return 1;\n");
fprintf(source, "}\n\n");

View File

@ -46,68 +46,16 @@ int lua_new_uint32_t(lua_State *L) {
return luaL_argerror(L, args, "too many arguments");
}
*static_cast<uint32_t *>(lua_newuserdata(L, sizeof(uint32_t))) = (args == 1) ? coerce_to_uint32_t(L, 1) : 0;
luaL_getmetatable(L, "uint32_t");
lua_setmetatable(L, -2);
new_uint32_t(L);
*check_uint32_t(L, -1) = (args == 1) ? coerce_to_uint32_t(L, 1) : 0;
return 1;
}
#define UINT32_T_BOX_OP(name, sym) \
int uint32_t___##name(lua_State *L) { \
binding_argcheck(L, 2); \
\
uint32_t v1 = coerce_to_uint32_t(L, 1); \
uint32_t v2 = coerce_to_uint32_t(L, 2); \
\
new_uint32_t(L); \
*static_cast<uint32_t *>(luaL_checkudata(L, -1, "uint32_t")) = v1 sym v2; \
return 1; \
}
UINT32_T_BOX_OP(add, +)
UINT32_T_BOX_OP(sub, -)
UINT32_T_BOX_OP(mul, *)
UINT32_T_BOX_OP(div, /)
UINT32_T_BOX_OP(mod, %)
UINT32_T_BOX_OP(band, &)
UINT32_T_BOX_OP(bor, |)
UINT32_T_BOX_OP(bxor, ^)
UINT32_T_BOX_OP(shl, <<)
UINT32_T_BOX_OP(shr, >>)
#define UINT32_T_BOX_OP_BOOL(name, sym) \
int uint32_t___##name(lua_State *L) { \
binding_argcheck(L, 2); \
\
uint32_t v1 = coerce_to_uint32_t(L, 1); \
uint32_t v2 = coerce_to_uint32_t(L, 2); \
\
lua_pushboolean(L, v1 sym v2); \
return 1; \
}
UINT32_T_BOX_OP_BOOL(eq, ==)
UINT32_T_BOX_OP_BOOL(lt, <)
UINT32_T_BOX_OP_BOOL(le, <=)
#define UINT32_T_BOX_OP_UNARY(name, sym) \
int uint32_t___##name(lua_State *L) { \
binding_argcheck(L, 2); \
\
uint32_t v1 = coerce_to_uint32_t(L, 1); \
\
new_uint32_t(L); \
*static_cast<uint32_t *>(luaL_checkudata(L, -1, "uint32_t")) = sym v1; \
return 1; \
}
// DO NOT SUPPORT UNARY NEGATION
UINT32_T_BOX_OP_UNARY(bnot, ~)
int uint32_t_toint(lua_State *L) {
binding_argcheck(L, 1);
uint32_t v = *static_cast<uint32_t *>(luaL_checkudata(L, 1, "uint32_t"));
const uint32_t v = *check_uint32_t(L, 1);
lua_pushinteger(L, static_cast<lua_Integer>(v));
@ -117,7 +65,8 @@ int uint32_t_toint(lua_State *L) {
int uint32_t_tofloat(lua_State *L) {
binding_argcheck(L, 1);
uint32_t v = *static_cast<uint32_t *>(luaL_checkudata(L, 1, "uint32_t"));
const uint32_t v = *check_uint32_t(L, 1);
lua_pushnumber(L, static_cast<lua_Number>(v));
@ -127,7 +76,7 @@ int uint32_t_tofloat(lua_State *L) {
int uint32_t___tostring(lua_State *L) {
binding_argcheck(L, 1);
uint32_t v = *static_cast<uint32_t *>(luaL_checkudata(L, 1, "uint32_t"));
const uint32_t v = *check_uint32_t(L, 1);
char buf[32];
hal.util->snprintf(buf, ARRAY_SIZE(buf), "%u", (unsigned)v);

View File

@ -5,20 +5,6 @@
uint32_t coerce_to_uint32_t(lua_State *L, int arg);
int lua_new_uint32_t(lua_State *L);
int uint32_t___add(lua_State *L);
int uint32_t___sub(lua_State *L);
int uint32_t___mul(lua_State *L);
int uint32_t___div(lua_State *L);
int uint32_t___mod(lua_State *L);
int uint32_t___band(lua_State *L);
int uint32_t___bor(lua_State *L);
int uint32_t___bxor(lua_State *L);
int uint32_t___shl(lua_State *L);
int uint32_t___shr(lua_State *L);
int uint32_t___eq(lua_State *L);
int uint32_t___lt(lua_State *L);
int uint32_t___le(lua_State *L);
int uint32_t___bnot(lua_State *L);
int uint32_t___tostring(lua_State *L);
int uint32_t_toint(lua_State *L);
int uint32_t_tofloat(lua_State *L);