diff --git a/ros1_bridge/__init__.py b/ros1_bridge/__init__.py index e57517d7..83c75855 100644 --- a/ros1_bridge/__init__.py +++ b/ros1_bridge/__init__.py @@ -75,8 +75,7 @@ def generate_cpp(output_path, template_dir): 'mappings': data['mappings'], 'services': data['services']} expand_template(template_file, data_for_template, output_file) - unique_package_names = set( - data['ros2_package_names_msg'] + data['ros2_package_names_srv']) + unique_package_names = set(data['ros2_package_names_msg'] + data['ros2_package_names_srv']) # skip builtin_interfaces since there is a custom implementation unique_package_names -= {'builtin_interfaces'} data['ros2_package_names'] = list(unique_package_names) @@ -142,8 +141,7 @@ def generate_cpp(output_path, template_dir): if s['ros2_package'] == ros2_package_name and s['ros2_name'] == interface.message_name], } - template_file = os.path.join( - template_dir, 'interface_factories.cpp.em') + template_file = os.path.join(template_dir, 'interface_factories.cpp.em') output_file = os.path.join( output_path, '%s__%s__%s__factories.cpp' % (ros2_package_name, interface_type, interface.message_name)) @@ -154,10 +152,8 @@ def generate_messages(rospack=None): ros1_msgs = get_ros1_messages(rospack=rospack) ros2_package_names, ros2_msgs, mapping_rules = get_ros2_messages() - package_pairs = determine_package_pairs( - ros1_msgs, ros2_msgs, mapping_rules) - message_pairs = determine_message_pairs( - ros1_msgs, ros2_msgs, package_pairs, mapping_rules) + package_pairs = determine_package_pairs(ros1_msgs, ros2_msgs, mapping_rules) + message_pairs = determine_message_pairs(ros1_msgs, ros2_msgs, package_pairs, mapping_rules) mappings = [] # add custom mapping for builtin_interfaces @@ -177,8 +173,7 @@ def generate_messages(rospack=None): msg_idx.ros2_put(ros2_msg) for ros1_msg, ros2_msg in message_pairs: - mapping = determine_field_mapping( - ros1_msg, ros2_msg, mapping_rules, msg_idx) + mapping = determine_field_mapping(ros1_msg, ros2_msg, mapping_rules, msg_idx) if mapping: mappings.append(mapping) @@ -208,8 +203,7 @@ def generate_messages(rospack=None): ('%s/%s' % (m.ros1_msg.package_name, m.ros1_msg.message_name), '%s/%s' % (m.ros2_msg.package_name, m.ros2_msg.message_name)), file=sys.stderr) for d in m.depends_on_ros2_messages: - print(' -', '%s/%s' % - (d.package_name, d.message_name), file=sys.stderr) + print(' -', '%s/%s' % (d.package_name, d.message_name), file=sys.stderr) print(file=sys.stderr) return { @@ -221,12 +215,10 @@ def generate_messages(rospack=None): } -def generate_services(rospack=None, message_string_pairs=None): +def generate_services(rospack=None): ros1_srvs = get_ros1_services(rospack=rospack) ros2_pkgs, ros2_srvs, mapping_rules = get_ros2_services() - services = determine_common_services( - ros1_srvs, ros2_srvs, mapping_rules, - message_string_pairs=message_string_pairs) + services = determine_common_services(ros1_srvs, ros2_srvs, mapping_rules) return { 'services': services, 'ros2_package_names_srv': ros2_pkgs, @@ -254,8 +246,7 @@ def get_ros2_messages(): resources = ament_index_python.get_resources(resource_type) for package_name, prefix_path in resources.items(): pkgs.append(package_name) - resource, _ = ament_index_python.get_resource( - resource_type, package_name) + resource, _ = ament_index_python.get_resource(resource_type, package_name) interfaces = resource.splitlines() message_names = { i[4:-4] @@ -272,8 +263,7 @@ def get_ros2_messages(): continue if 'mapping_rules' not in export.attributes: continue - rule_file = os.path.join( - package_path, export.attributes['mapping_rules']) + rule_file = os.path.join(package_path, export.attributes['mapping_rules']) with open(rule_file, 'r') as h: content = yaml.safe_load(h) if not isinstance(content, list): @@ -309,8 +299,7 @@ def get_ros2_services(): resources = ament_index_python.get_resources(resource_type) for package_name, prefix_path in resources.items(): pkgs.append(package_name) - resource, _ = ament_index_python.get_resource( - resource_type, package_name) + resource, _ = ament_index_python.get_resource(resource_type, package_name) interfaces = resource.splitlines() service_names = { i[4:-4] @@ -327,8 +316,7 @@ def get_ros2_services(): continue if 'mapping_rules' not in export.attributes: continue - rule_file = os.path.join( - package_path, export.attributes['mapping_rules']) + rule_file = os.path.join(package_path, export.attributes['mapping_rules']) with open(rule_file, 'r') as h: content = yaml.safe_load(h) if not isinstance(content, list): @@ -389,8 +377,7 @@ def __init__(self, data, expected_package_name): self.ros2_package_name = data['ros2_package_name'] self.package_mapping = (len(data) == 2) else: - raise Exception( - 'Ignoring a rule without a ros1_package_name and/or ros2_package_name') + raise Exception('Ignoring a rule without a ros1_package_name and/or ros2_package_name') def is_package_mapping(self): return self.package_mapping @@ -432,7 +419,10 @@ def is_field_mapping(self): return self.fields_1_to_2 is not None def __str__(self): - return 'MessageMappingRule(%s::%s <-> %s::%s)' % (self.ros1_package_name, self.ros1_message_name, self.ros2_package_name, self.ros2_message_name) + return 'MessageMappingRule(%s::%s <-> %s::%s)' % (self.ros1_package_name, + self.ros1_message_name, + self.ros2_package_name, + self.ros2_message_name) class ServiceMappingRule(MappingRule): @@ -471,7 +461,10 @@ def __init__(self, data, expected_package_name): 'Mapping for package %s contains unknown field(s)' % self.ros2_package_name) def __str__(self): - return 'ServiceMappingRule(%s::%s <-> %s::%s)' % (self.ros1_package_name, self.ros1_service_name, self.ros2_package_name, self.ros2_service_name) + return 'ServiceMappingRule(%s::%s <-> %s::%s)' % (self.ros1_package_name, + self.ros1_service_name, + self.ros2_package_name, + self.ros2_service_name) def determine_package_pairs(ros1_msgs, ros2_msgs, mapping_rules): @@ -549,83 +542,104 @@ def determine_message_pairs(ros1_msgs, ros2_msgs, package_pairs, mapping_rules): return pairs + +def srv_field_mapping(ros_type, ros1_name, ros2_name): + return { + 'basic': False if '/' in ros_type else True, + 'array': True if '[]' in ros_type else False, + 'ros1': { + 'name': ros1_name, + 'type': ros_type.rstrip('[]'), + 'cpptype': ros_type.rstrip('[]').replace('/', '::') + }, + 'ros2': { + 'name': ros2_name, + 'type': ros_type.rstrip('[]'), + 'cpptype': ros_type.rstrip('[]').replace('/', '::msg::') + } + } + + def determine_common_services(ros1_srvs, ros2_srvs, mapping_rules): - pairs = set() services = [] - # determine service names considered equal between ROS 1 and ROS 2 + mapping_pairs = [(rule.ros1_package_name, rule.ros1_service_name) for rule in mapping_rules] + + # fill in rules for matching names for ros1_srv in ros1_srvs: - try: - ros2_srv = ros2_srvs[ros2_srvs.index(ros1_srv)] - except: - print ("No matching pair for %s" % str(ros1_srv), file=sys.stdout) - continue - pairs.add((ros1_srv, ros2_srv)) + # add ros1/ros2 srv pairs if a mapping doesn't already exist + # and names are the same + if(not (ros1_srv.package_name, ros1_srv.message_name) in mapping_pairs): + if(ros1_srv in ros2_srvs): + data = { + 'ros1_package_name': ros1_srv.package_name, + 'ros2_package_name': ros1_srv.package_name, + 'ros1_service_name': ros1_srv.message_name, + 'ros2_service_name': ros1_srv.message_name, + 'request_fields_1_to_2': {}, + 'response_fields_1_to_2': {} + } + mapping_rules.append(ServiceMappingRule(data, ros1_srv.package_name)) + else: + print("No matching ros2 srv for %s" % str(ros1_srv), file=sys.stdout) + continue - # add manual service mapping rules + # fill in missing fields if they match for rule in mapping_rules: - ros1_rule = Message(rule.ros1_package_name, rule.ros1_service_name) - ros2_rule = Message(rule.ros2_package_name, rule.ros2_service_name) - try: - ros1_srv = ros1_srvs[ros1_srvs.index(ros1_rule)] - except: - print ("No matching srv for rule %s" % str(ros1_rule), file=sys.stderr) - continue - try: - ros2_srv = ros2_srvs[ros2_srvs.index(ros2_rule)] - except: - print ("No matching srv for rule %s" % str(ros2_rule), file=sys.stderr) - continue - pairs.add((ros1_srv, ros2_srv)) - - for pair in pairs: - ros1_spec = load_ros1_service(pair[0]) - ros2_spec = load_ros2_service(pair[1]) + ros1_srv = Message(rule.ros1_package_name, rule.ros1_service_name) + ros2_srv = Message(rule.ros2_package_name, rule.ros2_service_name) + ros1_spec = load_ros1_service(ros1_srvs[ros1_srvs.index(ros1_srv)]) + ros2_spec = load_ros2_service(ros2_srvs[ros2_srvs.index(ros2_srv)]) + match = True ros1_fields = { 'request': ros1_spec.request.fields(), 'response': ros1_spec.response.fields() } ros2_fields = { - 'request': ros2_spec.request.fields, - 'response': ros2_spec.response.fields + 'request': [(str(field.type), field.name) for field in ros2_spec.request.fields], + 'response': [(str(field.type), field.name) for field in ros2_spec.response.fields] + } + rule_fields = { + 'request': rule.request_fields_1_to_2, + 'response': rule.response_fields_1_to_2 } output = { 'request': [], 'response': [] } - match = True for direction in ['request', 'response']: - if len(ros1_fields[direction]) != len(ros2_fields[direction]): + if(len(ros1_fields[direction]) != len(ros2_fields[direction])): match = False - break - for i, ros1_field in enumerate(ros1_fields[direction]): - ros1_type = ros1_field[0] - ros2_type = str(ros2_fields[direction][i].type) - ros1_name = ros1_field[1] - ros2_name = ros2_fields[direction][i].name - - if ros1_type != ros2_type: - match = False - break - output[direction].append({ - 'basic': False if '/' in ros1_type else True, - 'array': True if '[]' in ros1_type else False, - 'ros1': { - 'name': ros1_name, - 'type': ros1_type.rstrip('[]'), - 'cpptype': ros1_type.rstrip('[]').replace('/', '::') - }, - 'ros2': { - 'name': ros2_name, - 'type': ros2_type.rstrip('[]'), - 'cpptype': ros2_type.rstrip('[]').replace('/', '::msg::') - } - }) + print("ros1 and ros2 %s fields must have matching length %s" % + (direction, rule), file=sys.stderr) + continue + for (ros_type, ros1_name) in ros1_fields[direction]: + # if a rule exists for this item, add in the mapping + if (rule_fields[direction] and ros1_name in rule_fields[direction]): + ros2_name = rule_fields[direction][ros1_name] + + if ((ros_type, ros2_name) in ros2_fields[direction]): + output[direction].append(srv_field_mapping(ros_type, ros1_name, ros2_name)) + else: + print("Invalid rule field pair for %s<-->%s in %s" % + (ros1_name, ros2_name, rule), file=sys.stderr) + match = False + continue + # A rule doesn't exist for this item, try to match by name + else: + ros2_name = ros1_name + if ((ros_type, ros2_name) in ros2_fields[direction]): + output[direction].append(srv_field_mapping(ros_type, ros1_name, ros2_name)) + else: + print("Invalid matching field pair for %s<-->%s in %s" % + (ros1_name, ros2_name, rule), file=sys.stderr) + match = False + continue if match: services.append({ - 'ros1_name': pair[0].message_name, - 'ros2_name': pair[1].message_name, - 'ros1_package': pair[0].package_name, - 'ros2_package': pair[1].package_name, + 'ros1_name': rule.ros1_service_name, + 'ros2_name': rule.ros2_service_name, + 'ros1_package': rule.ros1_package_name, + 'ros2_package': rule.ros2_package_name, 'fields': output }) return services @@ -651,7 +665,7 @@ def get_ros1_selected_fields(ros1_field_selection, parent_ros1_spec, msg_idx): in ros1_field_selection :type msg_idx: MessageIndex - :return: a tuple of genmsg.msgs.Field objets with additional attributes `pkg_name` + :return: a tuple of genmsg.msgs.Field objects with additional attributes `pkg_name` and `msg_name` as defined by `update_ros1_field_information`, corresponding to traversing `parent_ros1_spec` recursively following `ros1_field_selection` @@ -665,14 +679,11 @@ def consume_field(field): selected_fields.append(field) fields = ros1_field_selection.split('.') - current_field = [f for f in parent_ros1_spec.parsed_fields() - if f.name == fields[0]][0] + current_field = [f for f in parent_ros1_spec.parsed_fields() if f.name == fields[0]][0] consume_field(current_field) for field in fields[1:]: - parent_ros1_spec = load_ros1_message( - msg_idx.ros1_get_from_field(current_field)) - current_field = [ - f for f in parent_ros1_spec.parsed_fields() if f.name == field][0] + parent_ros1_spec = load_ros1_message(msg_idx.ros1_get_from_field(current_field)) + current_field = [f for f in parent_ros1_spec.parsed_fields() if f.name == field][0] consume_field(current_field) return tuple(selected_fields) @@ -681,14 +692,11 @@ def consume_field(field): def get_ros2_selected_fields(ros2_field_selection, parent_ros2_spec, msg_idx): selected_fields = [] fields = ros2_field_selection.split('.') - current_field = [ - f for f in parent_ros2_spec.fields if f.name == fields[0]][0] + current_field = [f for f in parent_ros2_spec.fields if f.name == fields[0]][0] selected_fields.append(current_field) for field in fields[1:]: - parent_ros2_spec = load_ros2_message( - msg_idx.ros2_get_from_field(current_field)) - current_field = [ - f for f in parent_ros2_spec.fields if f.name == field][0] + parent_ros2_spec = load_ros2_message(msg_idx.ros2_get_from_field(current_field)) + current_field = [f for f in parent_ros2_spec.fields if f.name == field][0] selected_fields.append(current_field) return tuple(selected_fields) @@ -727,8 +735,7 @@ def determine_field_mapping(ros1_msg, ros2_msg, mapping_rules, msg_idx): for ros1_field_selection, ros2_field_selection in rule.fields_1_to_2.items(): try: ros1_selected_fields = \ - get_ros1_selected_fields( - ros1_field_selection, ros1_spec, msg_idx) + get_ros1_selected_fields(ros1_field_selection, ros1_spec, msg_idx) except IndexError: print( "A manual mapping refers to an invalid field '%s' " % ros1_field_selection + @@ -738,8 +745,7 @@ def determine_field_mapping(ros1_msg, ros2_msg, mapping_rules, msg_idx): continue try: ros2_selected_fields = \ - get_ros2_selected_fields( - ros2_field_selection, ros2_spec, msg_idx) + get_ros2_selected_fields(ros2_field_selection, ros2_spec, msg_idx) except IndexError: print( "A manual mapping refers to an invalid field '%s' " % ros2_field_selection + @@ -758,8 +764,7 @@ def determine_field_mapping(ros1_msg, ros2_msg, mapping_rules, msg_idx): if ros1_field.name.lower() == ros2_field.name: # get package name and message name from ROS 1 field type if ros2_field.type.pkg_name: - update_ros1_field_information( - ros1_field, ros1_msg.package_name) + update_ros1_field_information(ros1_field, ros1_msg.package_name) mapping.add_field_pair(ros1_field, ros2_field) break else: @@ -783,8 +788,7 @@ def determine_field_mapping(ros1_msg, ros2_msg, mapping_rules, msg_idx): def load_ros1_message(ros1_msg): msg_context = genmsg.MsgContext.create_default() - message_path = os.path.join( - ros1_msg.prefix_path, ros1_msg.message_name + '.msg') + message_path = os.path.join(ros1_msg.prefix_path, ros1_msg.message_name + '.msg') try: spec = genmsg.msg_loader.load_msg_from_file( msg_context, message_path, '%s/%s' % (ros1_msg.package_name, ros1_msg.message_name)) @@ -795,12 +799,10 @@ def load_ros1_message(ros1_msg): def load_ros1_service(ros1_srv): srv_context = genmsg.MsgContext.create_default() - srv_path = os.path.join(ros1_srv.prefix_path, - ros1_srv.message_name + '.srv') + srv_path = os.path.join(ros1_srv.prefix_path, ros1_srv.message_name + '.srv') srv_name = '%s/%s' % (ros1_srv.package_name, ros1_srv.message_name) try: - spec = genmsg.msg_loader.load_srv_from_file( - srv_context, srv_path, srv_name) + spec = genmsg.msg_loader.load_srv_from_file(srv_context, srv_path, srv_name) except genmsg.InvalidMsgSpec: return None return spec @@ -811,8 +813,7 @@ def load_ros2_message(ros2_msg): ros2_msg.prefix_path, 'share', ros2_msg.package_name, 'msg', ros2_msg.message_name + '.msg') try: - spec = rosidl_adapter.parser.parse_message_file( - ros2_msg.package_name, message_path) + spec = rosidl_adapter.parser.parse_message_file(ros2_msg.package_name, message_path) except rosidl_adapter.parser.InvalidSpecification: return None return spec @@ -823,8 +824,7 @@ def load_ros2_service(ros2_srv): ros2_srv.prefix_path, 'share', ros2_srv.package_name, 'srv', ros2_srv.message_name + '.srv') try: - spec = rosidl_adapter.parser.parse_service_file( - ros2_srv.package_name, srv_path) + spec = rosidl_adapter.parser.parse_service_file(ros2_srv.package_name, srv_path) except rosidl_adapter.parser.InvalidSpecification: return None return spec