diff --git a/libraries/AP_Motors/examples/AP_Motors_test/run_heli_comparison.py b/libraries/AP_Motors/examples/AP_Motors_test/run_heli_comparison.py index 2c8f8b9fc5..6ff651bbc6 100644 --- a/libraries/AP_Motors/examples/AP_Motors_test/run_heli_comparison.py +++ b/libraries/AP_Motors/examples/AP_Motors_test/run_heli_comparison.py @@ -16,17 +16,13 @@ import os import subprocess -import shutil import csv from matplotlib import pyplot as plt from argparse import ArgumentParser -import time -# ============================================================================== + class DataPoints: - - # -------------------------------------------------------------------------- # Instantiate the object and parse the data from the provided file # file: path to the csv file to be parsed def __init__(self, file): @@ -71,9 +67,7 @@ class DataPoints: # Make data immutable for field in self.data.keys(): self.data[field] = tuple(self.data[field]) - # -------------------------------------------------------------------------- - # -------------------------------------------------------------------------- # get the data from a given field # field: dict index, name of field data to be returned # lim_tf: limit bool, return limit cases or not @@ -86,23 +80,20 @@ class DataPoints: if (flag == lim_tf): ret.append(data) return ret - # -------------------------------------------------------------------------- - # -------------------------------------------------------------------------- def get_fields(self): return self.data.keys() - # -------------------------------------------------------------------------- -# ============================================================================== -frame_class_lookup = {6:'Single_Heli', 11:'Dual_Heli'} +frame_class_lookup = {6: 'Single_Heli', 11: 'Dual_Heli'} + +swash_type_lookup = {0: 'H3', + 1: 'H1', + 2: 'H3_140', + 3: 'H3_120', + 4: 'H4_90', + 5: 'H4_45'} -swash_type_lookup = {0:'H3', - 1:'H1', - 2:'H3_140', - 3:'H3_120', - 4:'H4_90', - 5:'H4_45',} # Run sweep over range of types def run_sweep(frame_class, swash_type, dir_name): @@ -117,26 +108,24 @@ def run_sweep(frame_class, swash_type, dir_name): print('Running motors test for frame class = %s (%i), swash = %s (%i)' % (frame_class_lookup[fc], fc, swash_type_lookup[swash], swash)) filename = '%s_%s_motor_test.csv' % (frame_class_lookup[fc], swash_type_lookup[swash]) - os.system('./build/linux/examples/AP_Motors_test s frame_class=%d swash=%d > %s/%s' % (fc,swash,dir_name,filename)) + os.system('./build/linux/examples/AP_Motors_test s frame_class=%d swash=%d > %s/%s' % (fc, swash, dir_name, filename)) print('Frame class = %s, swash = %s complete\n' % (frame_class_lookup[fc], swash_type_lookup[swash])) -# ============================================================================== if __name__ == '__main__': - BLUE = [0,0,1] - RED = [1,0,0] - BLACK = [0,0,0] - + BLUE = [0, 0, 1] + RED = [1, 0, 0] + BLACK = [0, 0, 0] # Build input parser parser = ArgumentParser(description='Find logs in which the input string is found in messages') - parser.add_argument("-H","--head", type=int, help='number of commits to roll back the head for comparing the work done') - parser.add_argument("-f","--frame-class", type=int, dest='frame_class', nargs="+", default=(6,11), help="list of frame classes to run comparison on. Defaults to test all helis.") - parser.add_argument("-s","--swash-type", type=int, dest='swash_type', nargs="+", default=(0,1,2,3,4,5), help="list of swashplate types to run comparison on. Defaults to test all types.") - parser.add_argument("-c","--compare", action='store_true', help='Compare only, do not re-run tests') - parser.add_argument("-p","--plot", action='store_true', help='Plot comparison results') + parser.add_argument("-H", "--head", type=int, help='number of commits to roll back the head for comparing the work done') + parser.add_argument("-f", "--frame-class", type=int, dest='frame_class', nargs="+", default=(6, 11), help="list of frame classes to run comparison on. Defaults to test all helis.") + parser.add_argument("-s", "--swash-type", type=int, dest='swash_type', nargs="+", default=(0, 1, 2, 3, 4, 5), help="list of swashplate types to run comparison on. Defaults to test all types.") + parser.add_argument("-c", "--compare", action='store_true', help='Compare only, do not re-run tests') + parser.add_argument("-p", "--plot", action='store_true', help='Plot comparison results') args = parser.parse_args() dir_name = 'motors_comparison' @@ -188,14 +177,14 @@ if __name__ == '__main__': print('ERROR: Could not rewind HEAD. Exited with error:\n%s\n%s' % (result.stdout, result.stderr)) quit() - # Rebuild + # Rebuild os.system('./waf clean') run_sweep(args.frame_class, args.swash_type, original_name) # Move back to active branch print('Returning to original branch, commit = %s\n' % latest_commit) - cmd = 'git switch -' + cmd = 'git switch -' result = subprocess.run([cmd], shell=True, capture_output=True, text=True) print('\n%s\n' % cmd) @@ -218,11 +207,11 @@ if __name__ == '__main__': print('\t failed!\n') print('\tInputs max change:') - INPUTS = ['Roll','Pitch','Yaw','Thr'] + INPUTS = ['Roll', 'Pitch', 'Yaw', 'Thr'] input_diff = {} for field in INPUTS: - input_diff[field] = [i-j for i,j in zip(old_points.data[field], new_points.data[field])] - print('\t\t%s: %f' % (field, max(map(abs,input_diff[field])))) + input_diff[field] = [i-j for i, j in zip(old_points.data[field], new_points.data[field])] + print('\t\t%s: %f' % (field, max(map(abs, input_diff[field])))) # Find number of motors num_motors = 0 @@ -235,15 +224,15 @@ if __name__ == '__main__': output_diff = {} for i in range(num_motors): field = 'Mot%i' % (i+1) - output_diff[field] = [i-j for i,j in zip(old_points.data[field], new_points.data[field])] - print('\t\t%s: %f' % (field, max(map(abs,output_diff[field])))) + output_diff[field] = [i-j for i, j in zip(old_points.data[field], new_points.data[field])] + print('\t\t%s: %f' % (field, max(map(abs, output_diff[field])))) print('\tLimits max change:') - LIMITS = ['LimR','LimP','LimY','LimThD','LimThU'] + LIMITS = ['LimR', 'LimP', 'LimY', 'LimThD', 'LimThU'] limit_diff = {} for field in LIMITS: - limit_diff[field] = [i-j for i,j in zip(old_points.data[field], new_points.data[field])] - print('\t\t%s: %f' % (field, max(map(abs,limit_diff[field])))) + limit_diff[field] = [i-j for i, j in zip(old_points.data[field], new_points.data[field])] + print('\t\t%s: %f' % (field, max(map(abs, limit_diff[field])))) print('\n') if not args.plot: @@ -264,25 +253,25 @@ if __name__ == '__main__': fig.suptitle('%s Outputs' % name, fontsize=16) for i in range(num_motors): field = 'Mot%i' % (i+1) - ax[0,i].plot(old_points.data[field], color=RED) - ax[0,i].plot(new_points.data[field], color=BLUE) - ax[0,i].set_ylabel(field) - ax[0,i].set_xlabel('Test No') - ax[1,i].plot(output_diff[field], color=BLACK) - ax[1,i].set_ylabel('Change in %s' % field) - ax[1,i].set_xlabel('Test No') + ax[0, i].plot(old_points.data[field], color=RED) + ax[0, i].plot(new_points.data[field], color=BLUE) + ax[0, i].set_ylabel(field) + ax[0, i].set_xlabel('Test No') + ax[1, i].plot(output_diff[field], color=BLACK) + ax[1, i].set_ylabel('Change in %s' % field) + ax[1, i].set_xlabel('Test No') plt.tight_layout(rect=[0, 0.0, 1, 0.95]) fig, ax = plt.subplots(2, 5, figsize=fig_size) fig.suptitle(name + ' Limits', fontsize=16) for i, field in enumerate(LIMITS): - ax[0,i].plot(old_points.data[field], color=RED) - ax[0,i].plot(new_points.data[field], color=BLUE) - ax[0,i].set_ylabel(field) - ax[0,i].set_xlabel('Test No') - ax[1,i].plot(limit_diff[field], color=BLACK) - ax[1,i].set_ylabel('Change in %s' % field) - ax[1,i].set_xlabel('Test No') + ax[0, i].plot(old_points.data[field], color=RED) + ax[0, i].plot(new_points.data[field], color=BLUE) + ax[0, i].set_ylabel(field) + ax[0, i].set_xlabel('Test No') + ax[1, i].plot(limit_diff[field], color=BLACK) + ax[1, i].set_ylabel('Change in %s' % field) + ax[1, i].set_xlabel('Test No') plt.tight_layout(rect=[0, 0.0, 1, 0.95]) if args.plot: