From a8b4999d965a3c12cee333391a633e71f4af817f Mon Sep 17 00:00:00 2001 From: Iampete1 Date: Thu, 2 May 2024 20:52:19 +0100 Subject: [PATCH] AP_Scripting: tests: enforce correct types in docs --- libraries/AP_Scripting/tests/docs_check.py | 100 +++++++++++++++++++-- 1 file changed, 95 insertions(+), 5 deletions(-) diff --git a/libraries/AP_Scripting/tests/docs_check.py b/libraries/AP_Scripting/tests/docs_check.py index a8bc1320dd..f2c6b120b3 100644 --- a/libraries/AP_Scripting/tests/docs_check.py +++ b/libraries/AP_Scripting/tests/docs_check.py @@ -6,14 +6,16 @@ python ./libraries/AP_Scripting/tests/docs_check.py "./libraries/AP_Scripting/do AP_FLAKE8_CLEAN ''' -import optparse, sys +import optparse, sys, re class method(object): - def __init__(self, global_name, local_name, num_args, full_line): + def __init__(self, global_name, local_name, num_args, full_line, returns, params): self.global_name = global_name self.local_name = local_name self.num_args = num_args self.full_line = full_line + self.returns = returns + self.params = params def __str__(self): ret_str = "%s\n" % (self.full_line) @@ -22,20 +24,91 @@ class method(object): ret_str += "\tFunction: %s\n" % (self.local_name) else: ret_str += "\tGlobal: %s\n" % (self.global_name) - ret_str += "\tNum Args: %s\n\n" % (self.num_args) + ret_str += "\tNum Args: %s\n" % (self.num_args) + + ret_str += "\tParams:\n" + for param_type in self.params: + ret_str += "\t\t%s\n" % param_type + + ret_str += "\tReturns:\n" + for return_type in self.returns: + ret_str += "\t\t%s\n" % return_type + + ret_str +="\n" return ret_str + def type_compare(self, A, B): + if (((len(A) == 1) and (A[0] == 'UNKNOWN')) or + ((len(B) == 1) and (B[0] == 'UNKNOWN'))): + # UNKNOWN is a special case used for manual bindings + return True + + if len(A) != len(B): + return False + + for i in range(len(A)): + if A[i] != B[i]: + return False + + return True + + def types_compare(self, A, B): + if len(A) != len(B): + return False + + for i in range(len(A)): + if not self.type_compare(A[i], B[i]): + return False + + return True + + def check_types(self, other): + if not self.types_compare(self.returns, other.returns): + return False + + if not self.types_compare(self.params, other.params): + return False + + return True + def __eq__(self, other): return (self.global_name == other.global_name) and (self.local_name == other.local_name) and (self.num_args == other.num_args) +def get_return_type(line): + try: + match = re.findall("^---@return (\w+(\|(\w+))*)", line) + all_types = match[0][0] + return all_types.split("|") + + except: + raise Exception("Could not get return type in: %s" % line) + +def get_param_type(line): + try: + match = re.findall("^---@param \w+\?? (\w+(\|(\w+))*)", line) + all_types = match[0][0] + return all_types.split("|") + + except: + raise Exception("Could not get param type in: %s" % line) + def parse_file(file_name): methods = [] + returns = [] + params = [] with open(file_name) as fp: while True: line = fp.readline() if not line: break + # Acuminate return and params to associate with next function + if line.startswith("---@return"): + returns.append(get_return_type(line)) + + if line.startswith("---@param"): + params.append(get_param_type(line)) + # only consider functions if not line.startswith("function"): continue @@ -57,12 +130,18 @@ def parse_file(file_name): # get arguments function_name, args = function_line.split("(",1) - args = args[0:args.find(")")-1] + args = args[0:args.find(")")] if len(args) == 0: num_args = 0 else: num_args = args.count(",") + 1 + # ... shows up in arg list but not @param, add a unknown param + if args.endswith("..."): + params.append(["UNKNOWN"]) + + if num_args != len(params): + raise Exception("Missing \"---@param\" for function: %s", line) # get global/class name and function name local_name = "" @@ -71,7 +150,9 @@ def parse_file(file_name): else: global_name = function_name - methods.append(method(global_name, local_name, num_args, line)) + methods.append(method(global_name, local_name, num_args, line, returns, params)) + returns = [] + params = [] return methods @@ -92,6 +173,15 @@ def compare(expected_file_name, got_file_name): print("Multiple definitions of:") print(meth) pass_check = False + + elif not meth.check_types(got): + print("Type error:") + print("Want:") + print(meth) + print("Got:") + print(got) + pass_check = False + found = True if not found: