Skip to content

Commit fdf1761

Browse files
Fix issue tracking recursive references to custom annotations (#197)
* Handle recursive custom annotations dependencies * Fix tests and lint Co-authored-by: Brad Girardeau <[email protected]>
1 parent ff84331 commit fdf1761

7 files changed

+158
-91
lines changed

stone/backends/python_client.py

-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@
108108

109109

110110
class PythonClientBackend(CodeBackend):
111-
# pylint: disable=attribute-defined-outside-init
112111

113112
cmdline_parser = _cmdline_parser
114113
supported_auth_types = None

stone/backends/python_types.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414

1515
from stone.ir import AnnotationType, ApiNamespace
1616
from stone.ir import (
17-
get_custom_annotations_for_alias,
18-
get_custom_annotations_recursive,
1917
is_alias,
2018
is_boolean_type,
19+
is_composite_type,
2120
is_bytes_type,
2221
is_list_type,
2322
is_map_type,
@@ -642,7 +641,7 @@ def _generate_custom_annotation_processors(self, ns, data_type, extra_annotation
642641
dt, _, _ = unwrap(data_type)
643642
if is_struct_type(dt) or is_union_type(dt):
644643
annotation_types_seen = set()
645-
for annotation in get_custom_annotations_recursive(dt):
644+
for _, annotation in dt.recursive_custom_annotations:
646645
if annotation.annotation_type not in annotation_types_seen:
647646
yield (annotation.annotation_type,
648647
generate_func_call(
@@ -672,7 +671,12 @@ def _generate_custom_annotation_processors(self, ns, data_type, extra_annotation
672671

673672
# annotations applied directly to this type (through aliases or
674673
# passed in from the caller)
675-
for annotation in itertools.chain(get_custom_annotations_for_alias(data_type),
674+
indirect_annotations = dt.recursive_custom_annotations if is_composite_type(dt) else set()
675+
all_annotations = (data_type.recursive_custom_annotations
676+
if is_composite_type(data_type) else set())
677+
remaining_annotations = [annotation for _, annotation in
678+
all_annotations.difference(indirect_annotations)]
679+
for annotation in itertools.chain(remaining_annotations,
676680
extra_annotations):
677681
yield (annotation.annotation_type,
678682
generate_func_call(

stone/frontend/ir_generator.py

+71
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
Int32,
3838
Int64,
3939
is_alias,
40+
is_composite_type,
4041
is_field_type,
4142
is_list_type,
4243
is_map_type,
@@ -297,6 +298,7 @@ def generate_IR(self):
297298
self._populate_field_defaults()
298299
self._populate_enumerated_subtypes()
299300
self._populate_route_attributes()
301+
self._populate_recursive_custom_annotations()
300302
self._populate_examples()
301303
self._validate_doc_refs()
302304
self._validate_annotations()
@@ -802,6 +804,75 @@ def _populate_union_type_attributes(self, env, data_type):
802804
data_type.set_attributes(
803805
data_type._ast_node.doc, api_type_fields, parent_type, catch_all_field)
804806

807+
def _populate_recursive_custom_annotations(self):
808+
"""
809+
Populates custom annotations applied to fields recursively. This is done in
810+
a separate pass because it requires all fields and routes to be defined so that
811+
recursive chains can be followed accurately.
812+
"""
813+
data_types_seen = set()
814+
815+
def recurse(data_type):
816+
# primitive types do not have annotations
817+
if not is_composite_type(data_type):
818+
return set()
819+
820+
# if we have already analyzed data type, just return result
821+
if data_type.recursive_custom_annotations is not None:
822+
return data_type.recursive_custom_annotations
823+
824+
# handle cycles safely (annotations will be found first time at top level)
825+
if data_type in data_types_seen:
826+
return set()
827+
data_types_seen.add(data_type)
828+
829+
annotations = set()
830+
831+
# collect data types from subtypes recursively
832+
if is_struct_type(data_type) or is_union_type(data_type):
833+
for field in data_type.fields:
834+
annotations.update(recurse(field.data_type))
835+
# annotations can be defined directly on fields
836+
annotations.update([(field, annotation)
837+
for annotation in field.custom_annotations])
838+
elif is_alias(data_type):
839+
annotations.update(recurse(data_type.data_type))
840+
# annotations can be defined directly on aliases
841+
annotations.update([(data_type, annotation)
842+
for annotation in data_type.custom_annotations])
843+
elif is_list_type(data_type):
844+
annotations.update(recurse(data_type.data_type))
845+
elif is_map_type(data_type):
846+
# only map values support annotations for now
847+
annotations.update(recurse(data_type.value_data_type))
848+
elif is_nullable_type(data_type):
849+
annotations.update(recurse(data_type.data_type))
850+
851+
data_type.recursive_custom_annotations = annotations
852+
return annotations
853+
854+
for namespace in self.api.namespaces.values():
855+
namespace_annotations = set()
856+
for data_type in namespace.data_types:
857+
namespace_annotations.update(recurse(data_type))
858+
859+
for alias in namespace.aliases:
860+
namespace_annotations.update(recurse(alias))
861+
862+
for route in namespace.routes:
863+
namespace_annotations.update(recurse(route.arg_data_type))
864+
namespace_annotations.update(recurse(route.result_data_type))
865+
namespace_annotations.update(recurse(route.error_data_type))
866+
867+
# record annotation types as dependencies of the namespace. this allows for
868+
# an optimization when processing custom annotations to ignore annotation
869+
# types that are not applied to the data type, rather than recursing into it
870+
for _, annotation in namespace_annotations:
871+
if annotation.annotation_type.namespace.name != namespace.name:
872+
namespace.add_imported_namespace(
873+
annotation.annotation_type.namespace,
874+
imported_annotation_type=True)
875+
805876
def _populate_field_defaults(self):
806877
"""
807878
Populate the defaults of each field. This is done in a separate pass

stone/ir/data_types.py

+32-84
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,30 @@ def generic_type_name(v):
5555
return type(v).__name__
5656

5757

58+
def record_custom_annotation_imports(annotation, namespace):
59+
"""
60+
Records imports for custom annotations in the given namespace.
61+
62+
"""
63+
# first, check the annotation *type*
64+
if annotation.annotation_type.namespace.name != namespace.name:
65+
namespace.add_imported_namespace(
66+
annotation.annotation_type.namespace,
67+
imported_annotation_type=True)
68+
69+
# second, check if we need to import the annotation itself
70+
71+
# the annotation namespace is currently not actually used in the
72+
# backends, which reconstruct the annotation from the annotation
73+
# type directly. This could be changed in the future, and at
74+
# the IR level it makes sense to include the dependency
75+
76+
if annotation.namespace.name != namespace.name:
77+
namespace.add_imported_namespace(
78+
annotation.namespace,
79+
imported_annotation=True)
80+
81+
5882
class DataType(object):
5983
"""
6084
Abstract class representing a data type.
@@ -118,6 +142,12 @@ class Composite(DataType): # pylint: disable=abstract-method
118142
Composite types are any data type which can be constructed using primitive
119143
data types and other composite types.
120144
"""
145+
def __init__(self):
146+
super(Composite, self).__init__()
147+
# contains custom annotations that apply to any containing data types (recursively)
148+
# format is (location, CustomAnnotation) to indicate a custom annotation is applied
149+
# to a location (Field or Alias)
150+
self.recursive_custom_annotations = None
121151

122152

123153
class Nullable(Composite):
@@ -781,22 +811,7 @@ def set_attributes(self, doc, fields, parent_type=None):
781811
# they are treated as globals at the IR level
782812
for field in self.fields:
783813
for annotation in field.custom_annotations:
784-
# first, check the annotation *type*
785-
if annotation.annotation_type.namespace.name != self.namespace.name:
786-
self.namespace.add_imported_namespace(
787-
annotation.annotation_type.namespace,
788-
imported_annotation_type=True)
789-
790-
# second, check if we need to import the annotation itself
791-
792-
# the annotation namespace is currently not actually used in the
793-
# backends, which reconstruct the annotation from the annotation
794-
# type directly. This could be changed in the future, and at
795-
# the IR level it makes sense to include the dependency
796-
if annotation.namespace.name != self.namespace.name:
797-
self.namespace.add_imported_namespace(
798-
annotation.namespace,
799-
imported_annotation=True)
814+
record_custom_annotation_imports(annotation, self.namespace)
800815

801816
# Indicate that the attributes of the type have been populated.
802817
self._is_forward_ref = False
@@ -901,7 +916,6 @@ class Struct(UserDefined):
901916
"""
902917
Defines a product type: Composed of other primitive and/or struct types.
903918
"""
904-
# pylint: disable=attribute-defined-outside-init
905919

906920
composite_type = 'struct'
907921

@@ -1359,7 +1373,6 @@ def __repr__(self):
13591373

13601374
class Union(UserDefined):
13611375
"""Defines a tagged union. Fields are variants."""
1362-
# pylint: disable=attribute-defined-outside-init
13631376

13641377
composite_type = 'union'
13651378

@@ -1830,25 +1843,7 @@ def set_annotations(self, annotations):
18301843
elif isinstance(annotation, CustomAnnotation):
18311844
# Note: we don't need to do this for builtin annotations because
18321845
# they are treated as globals at the IR level
1833-
1834-
# first, check the annotation *type*
1835-
if annotation.annotation_type.namespace.name != self.namespace.name:
1836-
self.namespace.add_imported_namespace(
1837-
annotation.annotation_type.namespace,
1838-
imported_annotation_type=True)
1839-
1840-
# second, check if we need to import the annotation itself
1841-
1842-
# the annotation namespace is currently not actually used in the
1843-
# backends, which reconstruct the annotation from the annotation
1844-
# type directly. This could be changed in the future, and at
1845-
# the IR level it makes sense to include the dependency
1846-
1847-
if annotation.namespace.name != self.namespace.name:
1848-
self.namespace.add_imported_namespace(
1849-
annotation.namespace,
1850-
imported_annotation=True)
1851-
1846+
record_custom_annotation_imports(annotation, self.namespace)
18521847
self.custom_annotations.append(annotation)
18531848
else:
18541849
raise InvalidSpec("Aliases only support 'Redacted' and custom annotations, not %r" %
@@ -2002,53 +1997,6 @@ def unwrap(data_type):
20021997
data_type = data_type.data_type
20031998
return data_type, unwrapped_nullable, unwrapped_alias
20041999

2005-
def get_custom_annotations_for_alias(data_type):
2006-
"""
2007-
Given a Stone data type, returns all custom annotations applied to it.
2008-
"""
2009-
# annotations can only be applied to Aliases, but they can be wrapped in
2010-
# Nullable. also, Aliases pointing to other Aliases don't automatically
2011-
# inherit their custom annotations, so we might have to traverse.
2012-
result = []
2013-
data_type, _ = unwrap_nullable(data_type)
2014-
while is_alias(data_type):
2015-
result.extend(data_type.custom_annotations)
2016-
data_type, _ = unwrap_nullable(data_type.data_type)
2017-
return result
2018-
2019-
def get_custom_annotations_recursive(data_type):
2020-
"""
2021-
Given a Stone data type, returns all custom annotations applied to any of
2022-
its memebers, as well as submembers, ..., to an arbitrary depth.
2023-
"""
2024-
# because Stone structs can contain references to themselves (or otherwise
2025-
# be cyclical), we need ot keep track of the data types we've already seen
2026-
data_types_seen = set()
2027-
2028-
def recurse(data_type):
2029-
if data_type in data_types_seen:
2030-
return
2031-
data_types_seen.add(data_type)
2032-
2033-
dt, _, _ = unwrap(data_type)
2034-
if is_struct_type(dt) or is_union_type(dt):
2035-
for field in dt.fields:
2036-
for annotation in recurse(field.data_type):
2037-
yield annotation
2038-
for annotation in field.custom_annotations:
2039-
yield annotation
2040-
elif is_list_type(dt):
2041-
for annotation in recurse(dt.data_type):
2042-
yield annotation
2043-
elif is_map_type(dt):
2044-
for annotation in recurse(dt.value_data_type):
2045-
yield annotation
2046-
2047-
for annotation in get_custom_annotations_for_alias(data_type):
2048-
yield annotation
2049-
2050-
return recurse(data_type)
2051-
20522000

20532001
def is_alias(data_type):
20542002
return isinstance(data_type, Alias)

test/test_python_gen.py

-2
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,6 @@ def test_json_encoder(self):
229229
self.assertEqual(json_encode(bv.Nullable(bv.String()), u'abc'), json.dumps('abc'))
230230

231231
def test_json_encoder_union(self):
232-
# pylint: disable=attribute-defined-outside-init
233232
class S(object):
234233
_all_field_names_ = {'f'}
235234
_all_fields_ = [('f', bv.String())]
@@ -331,7 +330,6 @@ def _get_val_data_type(cls, tag, cp):
331330
self.assertEqual(json_encode(bv.Union(U), u, old_style=True), json.dumps({'g': m}))
332331

333332
def test_json_encoder_error_messages(self):
334-
# pylint: disable=attribute-defined-outside-init
335333
class S3(object):
336334
_all_field_names_ = {'j'}
337335
_all_fields_ = [('j', bv.UInt64(max_value=10))]

test/test_python_types.py

+1
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def test_struct_with_custom_annotations(self):
171171
StructField('unannotated_field', Int32(), None, None),
172172
])
173173
struct.fields[0].set_annotations([annotation])
174+
struct.recursive_custom_annotations = set([annotation])
174175

175176
result = self._evaluate_struct(ns, struct)
176177

test/test_stone.py

+46
Original file line numberDiff line numberDiff line change
@@ -4886,6 +4886,52 @@ def test_custom_annotations(self):
48864886

48874887
struct = api.namespaces['test'].data_type_by_name['TestStruct']
48884888
self.assertEqual(struct.fields[0].custom_annotations[0], annotation)
4889+
self.assertEqual(struct.recursive_custom_annotations, set([
4890+
(alias, api.namespaces['test'].annotation_by_name['VeryImportant']),
4891+
(struct.fields[0], api.namespaces['test'].annotation_by_name['SortaImportant']),
4892+
]))
4893+
4894+
# Test recursive references are captured
4895+
ns2 = textwrap.dedent("""\
4896+
namespace testchain
4897+
4898+
import test
4899+
4900+
alias TestAliasChain = String
4901+
@test.SortaImportant
4902+
4903+
struct TestStructChain
4904+
f test.TestStruct
4905+
g List(TestAliasChain)
4906+
""")
4907+
ns3 = textwrap.dedent("""\
4908+
namespace teststruct
4909+
4910+
import testchain
4911+
4912+
struct TestStructToStruct
4913+
f testchain.TestStructChain
4914+
""")
4915+
ns4 = textwrap.dedent("""\
4916+
namespace testalias
4917+
4918+
import testchain
4919+
4920+
struct TestStructToAlias
4921+
f testchain.TestAliasChain
4922+
""")
4923+
4924+
api = specs_to_ir([('test.stone', text), ('testchain.stone', ns2),
4925+
('teststruct.stone', ns3), ('testalias.stone', ns4)])
4926+
4927+
struct_namespaces = [ns.name for ns in
4928+
api.namespaces['teststruct'].get_imported_namespaces(
4929+
consider_annotation_types=True)]
4930+
self.assertTrue('test' in struct_namespaces)
4931+
alias_namespaces = [ns.name for ns in
4932+
api.namespaces['testalias'].get_imported_namespaces(
4933+
consider_annotation_types=True)]
4934+
self.assertTrue('test' in alias_namespaces)
48894935

48904936

48914937
if __name__ == '__main__':

0 commit comments

Comments
 (0)