From c0a8e6c96a2c1da0d05794a44bd437976689c6ee Mon Sep 17 00:00:00 2001 From: Peter Mullen Date: Fri, 15 Mar 2024 22:21:13 -0700 Subject: [PATCH] AP_Scripting: lua range finder driver that returns distances to a simulated sea floor for testing ArduSub surface tracking modes --- .../examples/sub_test_synthetic_seafloor.lua | 726 ++++++++++++++++++ 1 file changed, 726 insertions(+) create mode 100644 libraries/AP_Scripting/examples/sub_test_synthetic_seafloor.lua diff --git a/libraries/AP_Scripting/examples/sub_test_synthetic_seafloor.lua b/libraries/AP_Scripting/examples/sub_test_synthetic_seafloor.lua new file mode 100644 index 0000000000..a7815dec89 --- /dev/null +++ b/libraries/AP_Scripting/examples/sub_test_synthetic_seafloor.lua @@ -0,0 +1,726 @@ + +-- sub_test_synthetic_seafloor.lua +-- A simulated range finder driver returns distances based on the +-- vehicle location over a simulated seafloor topography. Used for +-- testing bottom tracking modes in ArduSub. +-- +-- The synthetic seafloor topography is defined as a 1-D height profile (h=f(x)) that is extruded +-- in the perpendicular direction to define a 2-D height function (h=f(x,y)). +-- +-- The following Parameters can be set by the test script to control how this driver behaves +-- SCR_USER1 is an index into a table of configuration bundles. +-- SCR_USER2 is the average bottom depth in meters +-- SCR_USER3 is a bit field that controls driver logging. +-- + + +local UPDATE_PERIOD_MS = 50 + +local TEST_ID_STR = "STSS" +local SCRIPT_NAME = "sub_test_synthetic_seafloor.lua" + +local enable_logger_write = true +local enable_gcs_send_data = false +local enable_gcs_send_range = false + +-- Copied from libraries/AP_RangeFinder/AP_RangeFinder.h enum RangeFinder::Type {}. +local RNGFND_TYPE_LUA = 36.0 +-- Copied from libraries/AP_RangeFinder/AP_RangeFinder.h enum RangeFinder::Status {}. +local RNGFND_STATUS_NO_DATA = 1 +local RNGFND_STATUS_GOOD = 4 +-- Copied from libraries/AP_RangeFinder/AP_RangeFinder.h +local SIGNAL_QUALITY_MIN = 0 +local SIGNAL_QUALITY_MAX = 100 + + + +------------------------------------------------------------------------------- + +-- gcs messaging function. Pass two strings and the function will drop frequent messages + +local gcs_send_funcfactory = function(name, eat_messages_period_s, msg_severity) + local gcs_send_times = {} + local gcs_eaten_count = {} + + ---@param str1 string + ---@param str2 string|nil + return function(str1, str2) + if not str1 or #str1 == 0 then return end + if not msg_severity then msg_severity = 6 end + if not eat_messages_period_s then eat_messages_period_s = 1.0 end + + local send_str + + if not str2 or #str2 == 0 then + send_str = string.format("%s: %s", name, str1) + + else + local time_curr_s = millis():tofloat() / 1000.0 + local time_first_s = gcs_send_times[str1] + if (time_first_s) then + local dur_since_first = time_curr_s - time_first_s + if dur_since_first < eat_messages_period_s then + if not gcs_eaten_count[str1] then + gcs_eaten_count[str1] = 0 + end + gcs_eaten_count[str1] = gcs_eaten_count[str1] + 1 + return + end + end + + local eaten_count = gcs_eaten_count[str1] + if eaten_count then + gcs_eaten_count[str1] = nil + send_str = string.format("%s: %s %s (+%i)", name, str1, str2, eaten_count) + else + send_str = string.format("%s: %s %s", name, str1, str2) + end + + gcs_send_times[str1] = time_curr_s + end + + gcs:send_text(msg_severity, send_str) + end +end + +local send = gcs_send_funcfactory(TEST_ID_STR, 1.0, 3) +-- A send function used to send true range messages at a different rate. +local send_quick = gcs_send_funcfactory(TEST_ID_STR, 0.5, 3) + +local function fatal_error(error) + send(string.format("FATAL ERROR '%s': %s", SCRIPT_NAME, error)) +end + +------------------------------------------------------------------------------- + +-- Profile definition and intersection used in Sea Floor Model + +local function section_factory(x0, z0, x1, z1) + + -- Use cross product to see if self intersects with seg. + -- Lines are defined by the two points (x0, z0) and (x0 + vx, z0 + vz). + local function intersect(self, seg) + local den = - self.vx*seg.vz + self.vz*seg.vx + if math.abs(den) < 1.0e-6 then + return nil + end + local x1m0 = seg.x0 - self.x0 + local z1m0 = seg.z0 - self.z0 + local t0 = - x1m0*seg.vz + z1m0*seg.vx + local t1 = self.vx*z1m0 - self.vz*x1m0 + -- The returned values are the distance along the lines where the intersection occurs. + -- The first is along self and the second is along seg + return t0 / den, t1 / den + end + + return { + x0 = x0, + z0 = z0, + vx = x1 - x0, + vz = z1 - z0, + intersect = intersect, + } +end + +-- psi is an angle in radians rotating around the y axis with zero at the z axis +local function ray_factory(x, z, psi) + return section_factory(x, z, x + math.sin(psi), z + math.cos(psi)) +end + +local function profile_factory(vertices) + local sections = {} + + -- The first section covers to -infinity + if #vertices > 0 then + local vertex1 = vertices[1] + table.insert(sections, section_factory(vertex1[1], vertex1[2], vertex1[1] - 1, vertex1[2])) + end + + local last_vertex + for _, vertex in pairs(vertices) do + if not last_vertex then + last_vertex = vertex + else + local section = section_factory(last_vertex[1], last_vertex[2], vertex[1], vertex[2]) + if math.abs(section.vx) > 1.0e-6 or math.abs(section.vy) > 1.0e-6 then + -- Only add section if it has non-zero length, otherwise ignore this vertex. + table.insert(sections, section) + last_vertex = vertex + end + end + end + + -- The last section covers to +infinity + if last_vertex then + table.insert(sections, section_factory(last_vertex[1], last_vertex[2], last_vertex[1] + 1, last_vertex[2])) + end + + + -- Returns the distance to the closest section. + sections.intersect = function(self, ray) + local d + + for i, segment in ipairs(self) do + + -- Return the distance from the start of the ray to the intersection + -- with the segment. If there is no intersection or the intersection is + -- not on the segment s < 0 or s > 1 then return nil. Note: the r distance + -- can be negative if the intersection is on the ray before its start (r < 0) + local function find_valid_intersection() + local s, r = segment:intersect(ray) + if not s then + return nil + end + if s < 0 then + return nil + end + if i ~= 1 and i ~= #self then + if s > 1 then + return nil + end + end + return r + end + + local d1 = find_valid_intersection() + if d1 then + -- Found a valid intersection, look for the shortest distance + if not d or d > d1 then + d = d1 + end + end + end + + return d + end + + return sections +end + +do -- Some code to test functionality and illustrate usage + local section = section_factory(0, 1, 1, 1) + + local function test(x, z, psi, s_expected, r_expected) + local ray = ray_factory(x, z, psi) + + -- Test the section:intersect method. + local s_actual, r_actual = section:intersect(ray) + + if not s_actual or not s_expected then + if s_actual or s_expected then + local actual_str = not s_actual and "nil" or string.format("%.2f", s_actual) + local expected_str = not s_expected and "nil" or string.format("%.2f", s_expected) + send(string.format("intersect not nil : x %.2f, z %.2f, psi %.2f", x, z, psi) .. + string.format(" : s_actual %s, s_expected %s", actual_str, expected_str)) + end + else + if math.abs(s_actual - s_expected) > 1.0e-6 or math.abs(r_actual - r_expected) > 1.0e-6 then + send(string.format("intersect failed : x %.2f, z %.2f, psi %.2f", x, z, psi) .. + string.format(" : s_actual %.2f, s_expected %.2f, r_actual %.2f, r_expected %.2f", + s_actual, s_expected, r_actual, r_expected)) + end + end + end + + test(-1, 0, 0, -1, 1) + test(0, 0, 0, 0, 1) + test(1, 0, 0, 1, 1) + test(2, 0, 0, 2, 1) + + test(-1, 0.5, 0, -1, .5) + test(0, 0.5, 0, 0, .5) + test(1, 0.5, 0, 1, .5) + test(2, 0.5, 0, 2, .5) + + test(-1, 1, 0, -1, 0) + test(0, 1, 0, 0, 0) + test(1, 1, 0, 1, 0) + test(2, 1, 0, 2, 0) + + test(-1, 2, 0, -1, -1) + test(0, 2, 0, 0, -1) + test(1, 2, 0, 1, -1) + test(2, 2, 0, 2, -1) + + test(0, 0, math.pi/2, nil, 0) + test(0, 1, math.pi/2, nil, 0) + test(1, 1, math.pi/2, nil, 0) + test(1, 2, math.pi/2, nil, 0) + + test(0, 0, math.pi/4, 1, math.sqrt(2)) + test(1, 0, math.pi/4, 2, math.sqrt(2)) + test(0, 0, -math.pi/4, -1, math.sqrt(2)) + test(1, 0, -math.pi/4, 0, math.sqrt(2)) +end + +do -- Some code to test functionality and illustrate usage + local profile = profile_factory({{0, 1}, {1, 2}}) + + local function test(x, z, psi, d_expected) + local ray = ray_factory(x, z, psi) + + -- Test the profile:intersect method. + local d_actual = profile:intersect(ray) + + if math.abs(d_actual - d_expected) > 1.0e-6 then + send(string.format("intersect failed x %.2f, z %.2f, psi %.2f, d_actual %.2f, d_expected %.2f", + x, z, psi, d_actual, d_expected)) + end + end + + test(-20, 0, 0, 1) + test(-1, 0, 0, 1) + test(0, 0, 0, 1) + test(0.5, 0, 0, 1.5) + test(1, 0, 0, 2) + test(1.5, 0, 0, 2) + test(20, 0, 0, 2) + + test(-20, 1, 0, 0) + test(-1, 1, 0, 0) + test(0, 1, 0, 0) + test(0.5, 1, 0, 0.5) + test(1, 1, 0, 1) + test(1.5, 1, 0, 1) + test(20, 1, 0, 1) + + test(-20, 2, 0, -1) + test(-1, 2, 0, -1) + test(0, 2, 0, -1) + test(0.5, 2, 0, -0.5) + test(1, 2, 0, 0) + test(1.5, 2, 0, 0) + test(20, 2, 0, 0) + + test(-20, 0, math.pi/4, math.sqrt(2)) + test(-1, 0, math.pi/4, math.sqrt(2)) + test(-0.99, 0, math.pi/4, 2*math.sqrt(2)) + test(0, 0, math.pi/4, 2*math.sqrt(2)) + test(0.5, 0, math.pi/4, 2*math.sqrt(2)) + test(1, 0, math.pi/4, 2*math.sqrt(2)) + test(1.5, 0, math.pi/4, 2*math.sqrt(2)) + test(20, 0, math.pi/4, 2*math.sqrt(2)) + + test(-20, 0, -math.pi/4, math.sqrt(2)) + test(-1, 0, -math.pi/4, math.sqrt(2)) + test(0, 0, -math.pi/4, math.sqrt(2)) + test(0.5, 0, -math.pi/4, math.sqrt(2)) + test(1, 0, -math.pi/4, math.sqrt(2)) + test(1.5, 0, -math.pi/4, 1.25*math.sqrt(2)) + test(2, 0, -math.pi/4, 1.5*math.sqrt(2)) + test(3, 0, -math.pi/4, 2*math.sqrt(2)) + test(20, 0, -math.pi/4, 2*math.sqrt(2)) +end + + +------------------------------------------------------------------------------- + +-- NoiseModel + +---@class NoiseModelConfig +---@field mean number -- Mean of the noise distribution (mean == 0 and std_dev == 0 => no gaussian noise) +---@field std_dev number -- Standard deviation of the noise distribution +---@field outlier_rate_ops number -- Rate of outliers outliers/second (0 => no outliers) +---@field outlier_mean number -- Mean of outliers distribution +---@field outlier_std_dev number -- Stardard deviation of outliers distribution +---@field delay_s number -- Delay between measurement request and measurement return (0 => no delay) +---@field callback_interval_ms number -- Delay between calls of the add noise function + +-- This factory creates a function that will take a measurement m and add noise to it +---@param config NoiseModelConfig +local function add_noise_funcfactory(config) + + local function identity_funcfactory(pre_func) + if pre_func then + return pre_func + end + return function(m) return m end + end + + + local function noise_funcfactory(mean, std_dev, rate_ops, callback_interval_ms, pre_func) + + -- Use the Box-Muller algorithm to generate normally distributed error that is added to the sample. + local function box_muller_func(m) + return m + mean + std_dev * math.sqrt(-2 * math.log(math.random())) * math.cos(2 * math.pi * math.random()) + end + + local function gaussian_noise_funcfactory() + if std_dev == 0.0 and pre_func then + return function(m) return pre_func(m) + mean end + end + + if std_dev == 0.0 then + return function(m) return m + mean end + end + + if pre_func then + return function(m) return box_muller_func(pre_func(m)) end + end + + return box_muller_func + end + + -- Just simple normally distributed noise + if rate_ops == 0.0 then + return gaussian_noise_funcfactory() + end + + -- Use poisson distribution to generate outliers + + -- Create a function to generate outliers + local outlier_func = gaussian_noise_funcfactory() + + -- Rate of outlier events in a callback interval + local rate_opi = rate_ops * callback_interval_ms / 1000.0 + + -- Poisson probability of zero events in an interval - Poisson formula is just exp in this case + local poisson_prob_zero = math.exp(-rate_opi) + + return function(m) + -- Poisson probability of 1 or more events in this interval is 1-P(0) + if math.random() > poisson_prob_zero then + return outlier_func(m) -- NOTE: pre_func is invoked in outlier_func. + end + if pre_func then + return pre_func(m) + end + return m + end + end + + + local function delay_funcfactory(delay_s, callback_interval_ms, pre_func) + + if delay_s == 0.0 or callback_interval_ms == 0 then + return identity_funcfactory(pre_func) + end + + local delay_line = {} + local delay_count = math.ceil(delay_s / callback_interval_ms * 1000.0) + if delay_count <= 0 then + return identity_funcfactory(pre_func) + end + + local next_idx = -1 + local function delay_func(m) + if pre_func then + m = pre_func(m) + end + + if next_idx < 1 then + for i = 1, delay_count do + delay_line[i] = m + end + next_idx = 1 + end + + local m_delay = delay_line[next_idx] + delay_line[next_idx] = m + next_idx = next_idx + 1 + if next_idx > #delay_line then + next_idx = 1 + end + return m_delay + end + + return delay_func + end + + local func + + -- Check for adding gaussian noise to measrement + if config.mean ~= 0.0 or config.std_dev ~= 0.0 then + func = noise_funcfactory(config.mean, config.std_dev, 0.0, 0.0) + end + + -- Check for adding an outlier measurement + if config.outlier_rate_ops ~= 0 then + func = noise_funcfactory(config.outlier_mean, config.outlier_std_dev, + config.outlier_rate_ops, config.callback_interval_ms, func) + end + + -- Check for delaying the measurement + if config.delay_s ~= 0.0 then + func = delay_funcfactory(config.delay_s, config.callback_interval_ms, func) + end + + if func == nil then + func = identity_funcfactory() + end + + return func +end + +------------------------------------------------------------------------------- + +-- Range Model + +---@class RangeModel +---@field get_range function(RangeModel, location_ud): number +---@field sub_z_m number +---@field bottom_z_m number +---@field range_m number +---@field set_origin function(RangeModel, location_ud) +---@field is_origin_valid function(RangeModel) + +---@return RangeModel +local function range_model_factory(model_bearing_N_rad, model_depth_m, vertices) + + ---@type Location_ud + local origin_loc + + local profile = profile_factory(vertices) + + ---@param sub_loc Location_ud + local function get_range(self, sub_loc) + + -- Figure out the depth of the sub in absolute frame + sub_loc:change_alt_frame(0) + self.sub_z_m = -sub_loc:alt()/100 + + -- If the origin has not been set then do not use the profile. + -- The seafloor model needs to have an origin. The origin of the model is set to + -- the location of the sub when it is armed. If the model’s origin hasn’t been set, + -- the driver returns distances based on a constant depth sea floor. + if not origin_loc then + self.bottom_z_m = model_depth_m + self.range_m = model_depth_m - self.sub_z_m + return self.range_m + end + + -- The model origin has been set so use the profile. + -- N in name means relative to Earth frame, M in name means relative to Sea floor model frame. + local sub_bearing_N_rad = origin_loc:get_bearing(sub_loc) + local sub_distance_M_m = origin_loc:get_distance(sub_loc) + + local sub_bearing_M_rad = sub_bearing_N_rad - model_bearing_N_rad + local sub_northly_M_m = math.cos(sub_bearing_M_rad) * sub_distance_M_m + + -- profile:intersect has an origin at zero so we have to add in the model + -- depth to get the seafloor depth at this location. + -- For the sea floor depth: + -- bottom_z = model_depth_m - profile(0) + -- Positive z is down, positive profile is up + local bottom_ray = ray_factory(sub_northly_M_m, 0, 0) + self.bottom_z_m = model_depth_m - profile:intersect(bottom_ray) + + -- For the sub range measurement: + -- model_range = model_depth_m - sub_z + self.range_m = self.bottom_z_m - self.sub_z_m + return self.range_m + end + + return { + get_range = get_range, + sub_z_m = 0, + bottom_z_m = model_depth_m, + range_m = model_depth_m, + set_origin = function(origin) origin_loc = origin end, + is_origin_valid = function() return origin_loc ~= nil end, + } +end + +------------------------------------------------------------------------------- + +-- Range Finder Driver + +-- The range finder backend is initialized in the update_init function. +---@type AP_RangeFinder_Backend_ud +local rngfnd_backend + +-- The range_model and add_noise_func are initialized when vehicle is armed +---@type RangeModel +local range_model + +local measurement_noise_func +local signal_quality_noise_func + + + +local function range_finder_driver(sub_loc) + local rf_state = RangeFinder_State() + + -- The full state udata must be initialized. + rf_state:last_reading(millis():toint()) + rf_state:voltage(0) + + -- If no location, then return no data + if not sub_loc then + rf_state:status(RNGFND_STATUS_NO_DATA) + rf_state:range_valid_count(0) + rf_state:distance(0) + rf_state:signal_quality(SIGNAL_QUALITY_MIN) + rngfnd_backend:handle_script_msg(rf_state) -- state as arg + return + end + + -- Generate a simulated range measurement + local true_range_m = range_model:get_range(sub_loc) + local range_m = measurement_noise_func(true_range_m) + local signal_quality = signal_quality_noise_func(SIGNAL_QUALITY_MAX) + + -- Return this measurement to the range finder backend + rf_state:status(RNGFND_STATUS_GOOD) + rf_state:range_valid_count(10) + rf_state:distance(range_m) + rf_state:signal_quality(signal_quality) + rngfnd_backend:handle_script_msg(rf_state) -- state as arg + + -- Log this data + if enable_logger_write then + logger:write('RNFN', 'sub_z,bottom_z,true_range,range,quality', 'fffff', 'mmmmm', '-----', + range_model.sub_z_m, range_model.bottom_z_m, true_range_m, range_m, signal_quality) + -- This data can be viewed in mavexplorer with the following command: + -- graph RNFN.sub_z RNFN.bottom_z RNFN.true_range RNFN.range + end + if enable_gcs_send_data then + send("RNGFND", string.format("true range %.2f, range %.2f, sub_z %.2f, bottom_z %.2f", + true_range_m, range_m, range_model.sub_z_m, range_model.bottom_z_m)) + end + if enable_gcs_send_range then + send_quick("#TR#", string.format("%7.2f", true_range_m)) + end +end + +------------------------------------------------------------------------------- + +local function initialize_model() + + -- query SCR_USERx for parameters + -- SCR_USER1 is a code for which config bundle to use + local config_index = param:get('SCR_USER1') + if not config_index then + config_index = 1 + end + -- SCR_USER2 is the bottom depth + local bottom_depth_m = param:get('SCR_USER2') + if not bottom_depth_m or bottom_depth_m < 1 then + bottom_depth_m = 50 + end + -- SCR_USER3 contains bits for logging + local logging_bits = param:get('SCR_USER3') + if not logging_bits then + logging_bits = 1 + end + + + -- Set logging flags from logging_bits + local lb_str = tostring(math.floor(logging_bits)) + enable_logger_write = string.sub(lb_str, -1, -1) == '1' + enable_gcs_send_data = string.sub(lb_str, -2, -2) == '1' + enable_gcs_send_range = string.sub(lb_str, -3, -3) == '1' + + local config_simple_ridge = {{5, 0}, {30, 10}, {40, 10}, {50, 0}} + local config_ridge_plateau = {{5, 0}, {30, 10}, {40, 10}, {50, 0}, {70, 0}, {90, -10}} + + local config_range_model = { + model_bearing_N_rad = math.pi, + vertices = config_simple_ridge, + } + + local config_measurement_noise = { + mean = 0.0, + std_dev = 0.0, + outlier_rate_ops = 0.0, + outlier_mean = 0.0, + outlier_std_dev = 0.0, + delay_s = 0.0, + callback_interval_ms = UPDATE_PERIOD_MS, + } + + local config_signal_quality_noise = { + mean = 0.0, + std_dev = 0.0, + outlier_rate_ops = 0.0, + outlier_mean = 0.0, + outlier_std_dev = 0.0, + delay_s = 0.0, + callback_interval_ms = UPDATE_PERIOD_MS, + } + + -- config_index = 1 is default with no noise + -- config_index = 2 has a little noise + if config_index == 2 then + config_range_model.vertices = config_ridge_plateau + config_measurement_noise.std_dev = .1 + config_measurement_noise.outlier_rate_ops = .2 + config_measurement_noise.outlier_mean = 5 + config_measurement_noise.outlier_std_dev = 2 + config_measurement_noise.delay_s = 0.00 + end + + range_model = range_model_factory(config_range_model.model_bearing_N_rad, + bottom_depth_m, config_range_model.vertices) + measurement_noise_func = add_noise_funcfactory(config_measurement_noise) + + + -- Constrain signal quality values + local signal_quality_noise_pre = add_noise_funcfactory(config_signal_quality_noise) + signal_quality_noise_func = function(m) + m = signal_quality_noise_pre(m) + if m > SIGNAL_QUALITY_MAX then + return SIGNAL_QUALITY_MAX + end + if m < SIGNAL_QUALITY_MIN then + return SIGNAL_QUALITY_MIN + end + return m + end + +end + +------------------------------------------------------------------------------- + +-- update functions + +local function update_run() + + local loc_c = ahrs:get_location() + + -- Check if we have to set or clear the origin + if arming:is_armed() ~= range_model.is_origin_valid() then + if arming:is_armed() then + if loc_c then + send("Starting to use the sea floor model for range data.") + range_model.set_origin(loc_c) + end + else + send("Stopping sea floor model range data. Starting to use flat sea floor for range data.") + range_model.set_origin(nil) + end + end + + -- Update with range finder driver + range_finder_driver(loc_c) + + return update_run, UPDATE_PERIOD_MS +end + +local function update_init() + if Parameter('RNGFND1_TYPE'):get() ~= RNGFND_TYPE_LUA then + return fatal_error("LUA range finder driver not enabled") + end + if rangefinder:num_sensors() < 1 then + return fatal_error("LUA range finder driver not connected") + end + rngfnd_backend = rangefinder:get_backend(0) + if not rngfnd_backend then + return fatal_error("Range Finder 1 does not exist") + end + if (rngfnd_backend:type() ~= RNGFND_TYPE_LUA) then + return fatal_error("Range Finder 1 is not a LUA driver") + end + + initialize_model() + + if not range_model or not measurement_noise_func or not signal_quality_noise_func then + return fatal_error("Could not initialize model") + end + + return update_run, 0 +end + +send(string.format("Loaded %s", SCRIPT_NAME)) + +return update_init, 0