1
+ import ast
2
+ import keyword
3
+ from pathlib import Path
4
+ from typing import List , Dict
5
+
6
+ import astor
7
+ from lionwebpython .language import Language , Concept , Interface , Containment , Property
8
+ from lionwebpython .language .classifier import Classifier
9
+ from lionwebpython .language .enumeration import Enumeration
10
+ from lionwebpython .language .primitive_type import PrimitiveType
11
+ from lionwebpython .language .reference import Reference
12
+ from lionwebpython .lionweb_version import LionWebVersion
13
+
14
+ from pylasu .lionweb .starlasu import StarLasuBaseLanguage
15
+ from pylasu .lionweb .utils import to_snake_case
16
+
17
+
18
+ def make_cond (enumeration_name : str , member_name : str ):
19
+ return ast .Compare (
20
+ left = ast .Name (id = "serialized" , ctx = ast .Load ()),
21
+ ops = [ast .Eq ()],
22
+ comparators = [
23
+ ast .Attribute (
24
+ value = ast .Attribute (
25
+ value = ast .Name (id = enumeration_name , ctx = ast .Load ()),
26
+ attr = member_name ,
27
+ ctx = ast .Load ()
28
+ ),
29
+ attr = "value" ,
30
+ ctx = ast .Load ()
31
+ )
32
+ ]
33
+ )
34
+
35
+ # The return: return AssignmentType.Add
36
+ def make_return (enumeration_name : str , member_name : str ):
37
+ return ast .Return (
38
+ value = ast .Attribute (
39
+ value = ast .Name (id = enumeration_name , ctx = ast .Load ()),
40
+ attr = member_name ,
41
+ ctx = ast .Load ()
42
+ )
43
+ )
44
+
45
+
46
+ def deserializer_generation (click , language : Language , output ):
47
+ import_abc = ast .ImportFrom (
48
+ module = 'abc' ,
49
+ names = [ast .alias (name = 'ABC' , asname = None )],
50
+ level = 0
51
+ )
52
+ import_dataclass = ast .ImportFrom (
53
+ module = 'dataclasses' ,
54
+ names = [ast .alias (name = 'dataclass' , asname = None )],
55
+ level = 0
56
+ )
57
+ import_enum = ast .ImportFrom (
58
+ module = "enum" ,
59
+ names = [ast .alias (name = "Enum" , asname = None )],
60
+ level = 0
61
+ )
62
+ import_typing = ast .ImportFrom (
63
+ module = 'typing' ,
64
+ names = [ast .alias (name = 'Optional' , asname = None )],
65
+ level = 0
66
+ )
67
+ import_starlasu = ast .ImportFrom (
68
+ module = 'pylasu.model.metamodel' ,
69
+ names = [ast .alias (name = 'Expression' , asname = 'StarLasuExpression' ),
70
+ ast .alias (name = 'PlaceholderElement' , asname = 'StarLasuPlaceholderElement' ),
71
+ ast .alias (name = 'Named' , asname = 'StarLasuNamed' ),
72
+ ast .alias (name = 'TypeAnnotation' , asname = 'StarLasuTypeAnnotation' ),
73
+ ast .alias (name = 'Parameter' , asname = 'StarLasuParameter' ),
74
+ ast .alias (name = 'Statement' , asname = 'StarLasuStatement' ),
75
+ ast .alias (name = 'EntityDeclaration' , asname = 'StarLasuEntityDeclaration' ),
76
+ ast .alias (name = 'BehaviorDeclaration' , asname = 'StarLasuBehaviorDeclaration' ),
77
+ ast .alias (name = 'Documentation' , asname = 'StarLasuDocumentation' )],
78
+ level = 0
79
+ )
80
+ import_node = ast .ImportFrom (
81
+ module = 'pylasu.model' ,
82
+ names = [ast .alias (name = 'Node' , asname = None )],
83
+ level = 0
84
+ )
85
+ import_ast = ast .ImportFrom (
86
+ module = 'ast' ,
87
+ names = [ast .alias (name = e .get_name (), asname = None ) for e in language .get_elements () if not isinstance (e , PrimitiveType )],
88
+ level = 0
89
+ )
90
+ import_primitives = ast .ImportFrom (
91
+ module = 'primitive_types' ,
92
+ names = [ast .alias (name = e .get_name (), asname = None ) for e in language .get_elements () if isinstance (e , PrimitiveType )],
93
+ level = 0
94
+ )
95
+ module = ast .Module (body = [import_abc , import_dataclass , import_typing , import_enum , import_starlasu , import_node ,
96
+ import_ast , import_primitives ],
97
+ type_ignores = [])
98
+
99
+
100
+
101
+ for e in language .get_elements ():
102
+ if isinstance (e , Enumeration ):
103
+ arg_serialized = ast .arg (arg = "serialized" , annotation = ast .Name (id = "str" , ctx = ast .Load ()))
104
+ # The raise: raise ValueError(f"...")
105
+ raise_stmt = ast .Raise (
106
+ exc = ast .Call (
107
+ func = ast .Name (id = "ValueError" , ctx = ast .Load ()),
108
+ args = [
109
+ ast .JoinedStr (values = [
110
+ ast .Constant (value = f"Invalid value for { e .get_name ()} : " ),
111
+ ast .FormattedValue (
112
+ value = ast .Name (id = "serialized" , ctx = ast .Load ()),
113
+ conversion = - 1
114
+ )
115
+ ])
116
+ ],
117
+ keywords = []
118
+ ),
119
+ cause = None
120
+ )
121
+ # The function body
122
+ literals = e .get_literals ()
123
+ current_if = ast .If (
124
+ test = make_cond (e .get_name (), literals [0 ].get_name ()),
125
+ body = [make_return (e .get_name (), literals [0 ].get_name ())],
126
+ orelse = []
127
+ )
128
+ root_if = current_if
129
+
130
+ for literal in literals [1 :]:
131
+ next_if = ast .If (
132
+ test = make_cond (e .get_name (), literal .get_name ()),
133
+ body = [make_return (e .get_name (), literal .get_name ())],
134
+ orelse = []
135
+ )
136
+ current_if .orelse = [next_if ]
137
+ current_if = next_if
138
+
139
+ # Final else
140
+ current_if .orelse = [raise_stmt ]
141
+
142
+ # Function definition
143
+ func_def = ast .FunctionDef (
144
+ name = f"_deserialize_{ to_snake_case (e .get_name ())} " ,
145
+ args = ast .arguments (
146
+ posonlyargs = [],
147
+ args = [arg_serialized ],
148
+ kwonlyargs = [],
149
+ kw_defaults = [],
150
+ defaults = []
151
+ ),
152
+ body = [root_if ],
153
+ decorator_list = [],
154
+ returns = ast .Constant (value = e .get_name ())
155
+ )
156
+ module .body .append (func_def )
157
+
158
+ generated_code = astor .to_source (module )
159
+ output_path = Path (output )
160
+ output_path .mkdir (parents = True , exist_ok = True )
161
+ click .echo (f"📂 Saving deserializer to: { output } " )
162
+ with Path (f"{ output } /deserializer.py" ).open ("w" , encoding = "utf-8" ) as f :
163
+ f .write (generated_code )
0 commit comments