Skip to content

Commit

Permalink
added ability to map custom service field names
Browse files Browse the repository at this point in the history
Signed-off-by: Allison Thackston <[email protected]>
  • Loading branch information
athackst committed Feb 14, 2020
1 parent f88d12a commit 309ece6
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 76 deletions.
8 changes: 6 additions & 2 deletions doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ How are ROS 1 and 2 services associated with each other?
--------------------------------------------------------

Automatic mapping between ROS 1 and 2 services is performed similar to messages.
Except that currently different field names are not supported.

How can I specify custom mapping rule for services?
---------------------------------------------------
Expand All @@ -73,7 +72,12 @@ In case of services, each mapping rule can have one of two types:
- a ``ros1_service_name``
- a ``ros2_service_name``

A custom field mapping is currently not supported for services.
3. A field mapping rule is defined by the attributes of a message mapping rule and:

- a dictionary ``request_fields_1_to_2`` or ``response_fields_1_to_2`` mapping ROS 1 field selections to ROS 2 field selections.
A field selection is a sequence of field names separated by ``.``, that specifies the path to a field starting from a message similar to the message fields.
All fields must be listed explicitly - not listed fields are not mapped implicitly when their names match.


How can I install mapping rule files?
-------------------------------------
Expand Down
18 changes: 9 additions & 9 deletions resource/interface_factories.cpp.em
Original file line number Diff line number Diff line change
Expand Up @@ -278,20 +278,20 @@ void ServiceFactory<
@[ for field in service["fields"][type.lower()]]@
@[ if field["array"]]@
req@(to).@(field["ros" + to]["name"]).resize(req@(frm).@(field["ros" + frm]["name"]).size());
auto @(field["ros1"]["name"])1_it = req1.@(field["ros1"]["name"]).begin();
auto @(field["ros2"]["name"])2_it = req2.@(field["ros2"]["name"]).begin();
auto @(field["ros" + frm]["name"])@(frm)_it = req@(frm).@(field["ros" + frm]["name"]).begin();
auto @(field["ros" + to]["name"])@(to)_it = req@(to).@(field["ros" + to]["name"]).begin();
while (
@(field["ros1"]["name"])1_it != req1.@(field["ros1"]["name"]).end() &&
@(field["ros2"]["name"])2_it != req2.@(field["ros2"]["name"]).end()
@(field["ros" + frm]["name"])@(frm)_it != req@(frm).@(field["ros" + frm]["name"]).end() &&
@(field["ros" + to]["name"])@(to)_it != req@(to).@(field["ros" + to]["name"]).end()
) {
auto & @(field["ros1"]["name"])1 = *(@(field["ros1"]["name"])1_it++);
auto & @(field["ros2"]["name"])2 = *(@(field["ros2"]["name"])2_it++);
auto & @(field["ros" + frm]["name"])@(frm) = *(@(field["ros" + frm]["name"])@(frm)_it++);
auto & @(field["ros" + to]["name"])@(to) = *(@(field["ros" + to]["name"])@(to)_it++);
@[ else]@
auto & @(field["ros1"]["name"])1 = req1.@(field["ros1"]["name"]);
auto & @(field["ros2"]["name"])2 = req2.@(field["ros2"]["name"]);
auto & @(field["ros" + frm]["name"])@(frm) = req@(frm).@(field["ros" + frm]["name"]);
auto & @(field["ros" + to]["name"])@(to) = req@(to).@(field["ros" + to]["name"]);
@[ end if]@
@[ if field["basic"]]@
@(field["ros2"]["name"])@(to) = @(field["ros1"]["name"])@(frm);
@(field["ros" + to]["name"])@(to) = @(field["ros" + frm]["name"])@(frm);
@[ else]@
Factory<@(field["ros1"]["cpptype"]),@(field["ros2"]["cpptype"])>::convert_@(frm)_to_@(to)(@
@(field["ros2"]["name"])@(frm), @(field["ros1"]["name"])@(to));
Expand Down
145 changes: 80 additions & 65 deletions ros1_bridge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,16 @@
def generate_cpp(output_path, template_dir):
rospack = rospkg.RosPack()
data = generate_messages(rospack)
message_string_pairs = {
(
'%s/%s' % (m.ros1_msg.package_name, m.ros1_msg.message_name),
'%s/%s' % (m.ros2_msg.package_name, m.ros2_msg.message_name))
for m in data['mappings']}
data.update(
generate_services(rospack, message_string_pairs=message_string_pairs))
data.update(generate_services(rospack))

template_file = os.path.join(template_dir, 'get_mappings.cpp.em')
output_file = os.path.join(output_path, 'get_mappings.cpp')
data_for_template = {
'mappings': data['mappings'], 'services': data['services']}
expand_template(template_file, data_for_template, output_file)

unique_package_names = set(data['ros2_package_names_msg'] + data['ros2_package_names_srv'])
unique_package_names = set(
data['ros2_package_names_msg'] + data['ros2_package_names_srv'])
# skip builtin_interfaces since there is a custom implementation
unique_package_names -= {'builtin_interfaces'}
data['ros2_package_names'] = list(unique_package_names)
Expand Down Expand Up @@ -147,7 +142,8 @@ def generate_cpp(output_path, template_dir):
if s['ros2_package'] == ros2_package_name and
s['ros2_name'] == interface.message_name],
}
template_file = os.path.join(template_dir, 'interface_factories.cpp.em')
template_file = os.path.join(
template_dir, 'interface_factories.cpp.em')
output_file = os.path.join(
output_path, '%s__%s__%s__factories.cpp' %
(ros2_package_name, interface_type, interface.message_name))
Expand All @@ -158,8 +154,10 @@ def generate_messages(rospack=None):
ros1_msgs = get_ros1_messages(rospack=rospack)
ros2_package_names, ros2_msgs, mapping_rules = get_ros2_messages()

package_pairs = determine_package_pairs(ros1_msgs, ros2_msgs, mapping_rules)
message_pairs = determine_message_pairs(ros1_msgs, ros2_msgs, package_pairs, mapping_rules)
package_pairs = determine_package_pairs(
ros1_msgs, ros2_msgs, mapping_rules)
message_pairs = determine_message_pairs(
ros1_msgs, ros2_msgs, package_pairs, mapping_rules)

mappings = []
# add custom mapping for builtin_interfaces
Expand All @@ -179,7 +177,8 @@ def generate_messages(rospack=None):
msg_idx.ros2_put(ros2_msg)

for ros1_msg, ros2_msg in message_pairs:
mapping = determine_field_mapping(ros1_msg, ros2_msg, mapping_rules, msg_idx)
mapping = determine_field_mapping(
ros1_msg, ros2_msg, mapping_rules, msg_idx)
if mapping:
mappings.append(mapping)

Expand Down Expand Up @@ -209,7 +208,8 @@ def generate_messages(rospack=None):
('%s/%s' % (m.ros1_msg.package_name, m.ros1_msg.message_name),
'%s/%s' % (m.ros2_msg.package_name, m.ros2_msg.message_name)), file=sys.stderr)
for d in m.depends_on_ros2_messages:
print(' -', '%s/%s' % (d.package_name, d.message_name), file=sys.stderr)
print(' -', '%s/%s' %
(d.package_name, d.message_name), file=sys.stderr)
print(file=sys.stderr)

return {
Expand Down Expand Up @@ -254,7 +254,8 @@ def get_ros2_messages():
resources = ament_index_python.get_resources(resource_type)
for package_name, prefix_path in resources.items():
pkgs.append(package_name)
resource, _ = ament_index_python.get_resource(resource_type, package_name)
resource, _ = ament_index_python.get_resource(
resource_type, package_name)
interfaces = resource.splitlines()
message_names = {
i[4:-4]
Expand All @@ -271,7 +272,8 @@ def get_ros2_messages():
continue
if 'mapping_rules' not in export.attributes:
continue
rule_file = os.path.join(package_path, export.attributes['mapping_rules'])
rule_file = os.path.join(
package_path, export.attributes['mapping_rules'])
with open(rule_file, 'r') as h:
content = yaml.safe_load(h)
if not isinstance(content, list):
Expand Down Expand Up @@ -307,7 +309,8 @@ def get_ros2_services():
resources = ament_index_python.get_resources(resource_type)
for package_name, prefix_path in resources.items():
pkgs.append(package_name)
resource, _ = ament_index_python.get_resource(resource_type, package_name)
resource, _ = ament_index_python.get_resource(
resource_type, package_name)
interfaces = resource.splitlines()
service_names = {
i[4:-4]
Expand All @@ -324,7 +327,8 @@ def get_ros2_services():
continue
if 'mapping_rules' not in export.attributes:
continue
rule_file = os.path.join(package_path, export.attributes['mapping_rules'])
rule_file = os.path.join(
package_path, export.attributes['mapping_rules'])
with open(rule_file, 'r') as h:
content = yaml.safe_load(h)
if not isinstance(content, list):
Expand Down Expand Up @@ -385,7 +389,8 @@ def __init__(self, data, expected_package_name):
self.ros2_package_name = data['ros2_package_name']
self.package_mapping = (len(data) == 2)
else:
raise Exception('Ignoring a rule without a ros1_package_name and/or ros2_package_name')
raise Exception(
'Ignoring a rule without a ros1_package_name and/or ros2_package_name')

def is_package_mapping(self):
return self.package_mapping
Expand Down Expand Up @@ -427,7 +432,7 @@ def is_field_mapping(self):
return self.fields_1_to_2 is not None

def __str__(self):
return 'MessageMappingRule(%s <-> %s)' % (self.ros1_package_name, self.ros2_package_name)
return 'MessageMappingRule(%s::%s <-> %s::%s)' % (self.ros1_package_name, self.ros1_message_name, self.ros2_package_name, self.ros2_message_name)


class ServiceMappingRule(MappingRule):
Expand Down Expand Up @@ -466,7 +471,7 @@ def __init__(self, data, expected_package_name):
'Mapping for package %s contains unknown field(s)' % self.ros2_package_name)

def __str__(self):
return 'ServiceMappingRule(%s <-> %s)' % (self.ros1_package_name, self.ros2_package_name)
return 'ServiceMappingRule(%s::%s <-> %s::%s)' % (self.ros1_package_name, self.ros1_service_name, self.ros2_package_name, self.ros2_service_name)


def determine_package_pairs(ros1_msgs, ros2_msgs, mapping_rules):
Expand Down Expand Up @@ -544,35 +549,33 @@ def determine_message_pairs(ros1_msgs, ros2_msgs, package_pairs, mapping_rules):

return pairs


def determine_common_services(
ros1_srvs, ros2_srvs, mapping_rules, message_string_pairs=None
):
if message_string_pairs is None:
message_string_pairs = set()

pairs = []
def determine_common_services(ros1_srvs, ros2_srvs, mapping_rules):
pairs = set()
services = []
# determine service names considered equal between ROS 1 and ROS 2
for ros1_srv in ros1_srvs:
for ros2_srv in ros2_srvs:
if ros1_srv.package_name == ros2_srv.package_name:
if ros1_srv.message_name == ros2_srv.message_name:
pairs.append((ros1_srv, ros2_srv))
try:
ros2_srv = ros2_srvs[ros2_srvs.index(ros1_srv)]
except:
print ("No matching pair for %s" % str(ros1_srv), file=sys.stdout)
continue
pairs.add((ros1_srv, ros2_srv))

# add manual service mapping rules
for rule in mapping_rules:
for ros1_srv in ros1_srvs:
for ros2_srv in ros2_srvs:
if rule.ros1_package_name == ros1_srv.package_name and \
rule.ros2_package_name == ros2_srv.package_name:
if rule.ros1_service_name is None and rule.ros2_service_name is None:
if ros1_srv.message_name == ros2_srv.message_name:
pairs.append((ros1_srv, ros2_srv))
else:
if (
rule.ros1_service_name == ros1_srv.message_name and
rule.ros2_service_name == ros2_srv.message_name
):
pairs.append((ros1_srv, ros2_srv))
ros1_rule = Message(rule.ros1_package_name, rule.ros1_service_name)
ros2_rule = Message(rule.ros2_package_name, rule.ros2_service_name)
try:
ros1_srv = ros1_srvs[ros1_srvs.index(ros1_rule)]
except:
print ("No matching srv for rule %s" % str(ros1_rule), file=sys.stderr)
continue
try:
ros2_srv = ros2_srvs[ros2_srvs.index(ros2_rule)]
except:
print ("No matching srv for rule %s" % str(ros2_rule), file=sys.stderr)
continue
pairs.add((ros1_srv, ros2_srv))

for pair in pairs:
ros1_spec = load_ros1_service(pair[0])
Expand All @@ -599,12 +602,10 @@ def determine_common_services(
ros2_type = str(ros2_fields[direction][i].type)
ros1_name = ros1_field[1]
ros2_name = ros2_fields[direction][i].name
if ros1_type != ros2_type or ros1_name != ros2_name:
# if the message types have a custom mapping their names
# might not be equal, therefore check the message pairs
if (ros1_type, ros2_type) not in message_string_pairs:
match = False
break

if ros1_type != ros2_type:
match = False
break
output[direction].append({
'basic': False if '/' in ros1_type else True,
'array': True if '[]' in ros1_type else False,
Expand Down Expand Up @@ -664,11 +665,14 @@ def consume_field(field):
selected_fields.append(field)

fields = ros1_field_selection.split('.')
current_field = [f for f in parent_ros1_spec.parsed_fields() if f.name == fields[0]][0]
current_field = [f for f in parent_ros1_spec.parsed_fields()
if f.name == fields[0]][0]
consume_field(current_field)
for field in fields[1:]:
parent_ros1_spec = load_ros1_message(msg_idx.ros1_get_from_field(current_field))
current_field = [f for f in parent_ros1_spec.parsed_fields() if f.name == field][0]
parent_ros1_spec = load_ros1_message(
msg_idx.ros1_get_from_field(current_field))
current_field = [
f for f in parent_ros1_spec.parsed_fields() if f.name == field][0]
consume_field(current_field)

return tuple(selected_fields)
Expand All @@ -677,11 +681,14 @@ def consume_field(field):
def get_ros2_selected_fields(ros2_field_selection, parent_ros2_spec, msg_idx):
selected_fields = []
fields = ros2_field_selection.split('.')
current_field = [f for f in parent_ros2_spec.fields if f.name == fields[0]][0]
current_field = [
f for f in parent_ros2_spec.fields if f.name == fields[0]][0]
selected_fields.append(current_field)
for field in fields[1:]:
parent_ros2_spec = load_ros2_message(msg_idx.ros2_get_from_field(current_field))
current_field = [f for f in parent_ros2_spec.fields if f.name == field][0]
parent_ros2_spec = load_ros2_message(
msg_idx.ros2_get_from_field(current_field))
current_field = [
f for f in parent_ros2_spec.fields if f.name == field][0]
selected_fields.append(current_field)
return tuple(selected_fields)

Expand Down Expand Up @@ -720,7 +727,8 @@ def determine_field_mapping(ros1_msg, ros2_msg, mapping_rules, msg_idx):
for ros1_field_selection, ros2_field_selection in rule.fields_1_to_2.items():
try:
ros1_selected_fields = \
get_ros1_selected_fields(ros1_field_selection, ros1_spec, msg_idx)
get_ros1_selected_fields(
ros1_field_selection, ros1_spec, msg_idx)
except IndexError:
print(
"A manual mapping refers to an invalid field '%s' " % ros1_field_selection +
Expand All @@ -730,7 +738,8 @@ def determine_field_mapping(ros1_msg, ros2_msg, mapping_rules, msg_idx):
continue
try:
ros2_selected_fields = \
get_ros2_selected_fields(ros2_field_selection, ros2_spec, msg_idx)
get_ros2_selected_fields(
ros2_field_selection, ros2_spec, msg_idx)
except IndexError:
print(
"A manual mapping refers to an invalid field '%s' " % ros2_field_selection +
Expand All @@ -749,7 +758,8 @@ def determine_field_mapping(ros1_msg, ros2_msg, mapping_rules, msg_idx):
if ros1_field.name.lower() == ros2_field.name:
# get package name and message name from ROS 1 field type
if ros2_field.type.pkg_name:
update_ros1_field_information(ros1_field, ros1_msg.package_name)
update_ros1_field_information(
ros1_field, ros1_msg.package_name)
mapping.add_field_pair(ros1_field, ros2_field)
break
else:
Expand All @@ -773,7 +783,8 @@ def determine_field_mapping(ros1_msg, ros2_msg, mapping_rules, msg_idx):

def load_ros1_message(ros1_msg):
msg_context = genmsg.MsgContext.create_default()
message_path = os.path.join(ros1_msg.prefix_path, ros1_msg.message_name + '.msg')
message_path = os.path.join(
ros1_msg.prefix_path, ros1_msg.message_name + '.msg')
try:
spec = genmsg.msg_loader.load_msg_from_file(
msg_context, message_path, '%s/%s' % (ros1_msg.package_name, ros1_msg.message_name))
Expand All @@ -784,10 +795,12 @@ def load_ros1_message(ros1_msg):

def load_ros1_service(ros1_srv):
srv_context = genmsg.MsgContext.create_default()
srv_path = os.path.join(ros1_srv.prefix_path, ros1_srv.message_name + '.srv')
srv_path = os.path.join(ros1_srv.prefix_path,
ros1_srv.message_name + '.srv')
srv_name = '%s/%s' % (ros1_srv.package_name, ros1_srv.message_name)
try:
spec = genmsg.msg_loader.load_srv_from_file(srv_context, srv_path, srv_name)
spec = genmsg.msg_loader.load_srv_from_file(
srv_context, srv_path, srv_name)
except genmsg.InvalidMsgSpec:
return None
return spec
Expand All @@ -798,7 +811,8 @@ def load_ros2_message(ros2_msg):
ros2_msg.prefix_path, 'share', ros2_msg.package_name, 'msg',
ros2_msg.message_name + '.msg')
try:
spec = rosidl_adapter.parser.parse_message_file(ros2_msg.package_name, message_path)
spec = rosidl_adapter.parser.parse_message_file(
ros2_msg.package_name, message_path)
except rosidl_adapter.parser.InvalidSpecification:
return None
return spec
Expand All @@ -809,7 +823,8 @@ def load_ros2_service(ros2_srv):
ros2_srv.prefix_path, 'share', ros2_srv.package_name, 'srv',
ros2_srv.message_name + '.srv')
try:
spec = rosidl_adapter.parser.parse_service_file(ros2_srv.package_name, srv_path)
spec = rosidl_adapter.parser.parse_service_file(
ros2_srv.package_name, srv_path)
except rosidl_adapter.parser.InvalidSpecification:
return None
return spec
Expand Down

0 comments on commit 309ece6

Please sign in to comment.