16
16
limitations under the License.
17
17
-----------------------------------------------------------------
18
18
"""
19
-
19
+ import itertools
20
20
from abc import ABC , abstractmethod
21
21
from collections .abc import Callable
22
22
from typing import ClassVar , Optional , Union
23
23
24
24
from app .translator .const import DEFAULT_VALUE_TYPE
25
- from app .translator .core .context_vars import return_only_first_query_ctx_var
25
+ from app .translator .core .const import TOKEN_TYPE
26
+ from app .translator .core .context_vars import return_only_first_query_ctx_var , wrap_query_with_meta_info_ctx_var
26
27
from app .translator .core .custom_types .tokens import LogicalOperatorType , OperatorType
27
28
from app .translator .core .custom_types .values import ValueType
28
29
from app .translator .core .escape_manager import EscapeManager
29
30
from app .translator .core .exceptions .core import NotImplementedException , StrictPlatformException
30
31
from app .translator .core .exceptions .parser import UnsupportedOperatorException
31
32
from app .translator .core .functions import PlatformFunctions
32
33
from app .translator .core .mapping import DEFAULT_MAPPING_NAME , BasePlatformMappings , LogSourceSignature , SourceMapping
33
- from app .translator .core .models .field import Field , FieldValue , Keyword
34
+ from app .translator .core .models .field import Field , FieldField , FieldValue , Keyword
34
35
from app .translator .core .models .functions .base import Function , RenderedFunctions
35
36
from app .translator .core .models .identifier import Identifier
36
37
from app .translator .core .models .platform_details import PlatformDetails
37
38
from app .translator .core .models .query_container import MetaInfoContainer , RawQueryContainer , TokenizedQueryContainer
38
39
from app .translator .core .str_value_manager import StrValue , StrValueManager
39
- from app .translator .core .tokenizer import TOKEN_TYPE
40
40
41
41
42
- class BaseQueryFieldValue (ABC ):
42
+ class BaseFieldValueRender (ABC ):
43
43
details : PlatformDetails = None
44
44
escape_manager : EscapeManager = None
45
45
str_value_manager : StrValueManager = None
46
46
47
47
def __init__ (self , or_token : str ):
48
- self .field_value : dict [str , Callable [[str , DEFAULT_VALUE_TYPE ], str ]] = {
48
+ self .modifiers_map : dict [str , Callable [[str , DEFAULT_VALUE_TYPE ], str ]] = {
49
49
OperatorType .EQ : self .equal_modifier ,
50
50
OperatorType .NOT_EQ : self .not_equal_modifier ,
51
51
OperatorType .LT : self .less_modifier ,
@@ -155,11 +155,20 @@ def apply_value(self, value: Union[str, int], value_type: str = ValueType.value)
155
155
return self .escape_manager .escape (value , value_type )
156
156
157
157
def apply_field_value (self , field : str , operator : Identifier , value : DEFAULT_VALUE_TYPE ) -> str :
158
- if modifier_function := self .field_value .get (operator .token_type ):
158
+ if modifier_function := self .modifiers_map .get (operator .token_type ):
159
159
return modifier_function (field , value )
160
160
raise UnsupportedOperatorException (operator .token_type )
161
161
162
162
163
+ class BaseFieldFieldRender (ABC ):
164
+ operators_map : ClassVar [dict [str , str ]] = {}
165
+
166
+ def apply_field_field (self , field_left : str , operator : Identifier , field_right : str ) -> str :
167
+ if mapped_operator := self .operators_map .get (operator .token_type ):
168
+ return f"{ field_left } { mapped_operator } { field_right } "
169
+ raise UnsupportedOperatorException (operator .token_type )
170
+
171
+
163
172
class QueryRender (ABC ):
164
173
comment_symbol : str = None
165
174
details : PlatformDetails = None
@@ -180,6 +189,13 @@ def render_not_supported_functions(self, not_supported_functions: list) -> str:
180
189
not_supported_functions_str = "\n " .join (line_template + func .lstrip () for func in not_supported_functions )
181
190
return "\n \n " + self .wrap_with_comment (f"{ self .unsupported_functions_text } \n { not_supported_functions_str } " )
182
191
192
+ def wrap_with_not_supported_functions (self , query : str , not_supported_functions : Optional [list ] = None ) -> str :
193
+ if not_supported_functions and wrap_query_with_meta_info_ctx_var .get ():
194
+ rendered_not_supported = self .render_not_supported_functions (not_supported_functions )
195
+ return query + rendered_not_supported
196
+
197
+ return query
198
+
183
199
def wrap_with_comment (self , value : str ) -> str :
184
200
return f"{ self .comment_symbol } { value } "
185
201
@@ -199,13 +215,14 @@ class PlatformQueryRender(QueryRender):
199
215
group_token = "(%s)"
200
216
query_parts_delimiter = " "
201
217
202
- field_value_map = BaseQueryFieldValue (or_token = or_token )
218
+ field_field_render = BaseFieldFieldRender ()
219
+ field_value_render = BaseFieldValueRender (or_token = or_token )
203
220
204
221
raw_log_field_pattern_map : ClassVar [dict [str , str ]] = None
205
222
206
223
def __init__ (self ):
207
224
super ().__init__ ()
208
- self .operator_map = {
225
+ self .logical_operators_map = {
209
226
LogicalOperatorType .AND : f" { self .and_token } " ,
210
227
LogicalOperatorType .OR : f" { self .or_token } " ,
211
228
LogicalOperatorType .NOT : f" { self .not_token } " ,
@@ -233,31 +250,34 @@ def map_field(self, field: Field, source_mapping: SourceMapping) -> list[str]:
233
250
234
251
def apply_token (self , token : Union [FieldValue , Keyword , Identifier ], source_mapping : SourceMapping ) -> str :
235
252
if isinstance (token , FieldValue ):
236
- if token .alias :
237
- field_name = token .alias .name
238
- else :
239
- mapped_fields = self .map_field (token .field , source_mapping )
240
- if len (mapped_fields ) > 1 :
241
- return self .group_token % self .operator_map [LogicalOperatorType .OR ].join (
242
- [
243
- self .field_value_map .apply_field_value (
244
- field = field , operator = token .operator , value = token .value
245
- )
246
- for field in mapped_fields
247
- ]
248
- )
249
-
250
- field_name = mapped_fields [0 ]
251
-
252
- return self .field_value_map .apply_field_value (field = field_name , operator = token .operator , value = token .value )
253
-
253
+ mapped_fields = [token .alias .name ] if token .alias else self .map_field (token .field , source_mapping )
254
+ joined = self .logical_operators_map [LogicalOperatorType .OR ].join (
255
+ [
256
+ self .field_value_render .apply_field_value (field = field , operator = token .operator , value = token .value )
257
+ for field in mapped_fields
258
+ ]
259
+ )
260
+ return self .group_token % joined if len (mapped_fields ) > 1 else joined
261
+ if isinstance (token , FieldField ):
262
+ alias_left , field_left = token .alias_left , token .field_left
263
+ mapped_fields_left = [alias_left .name ] if alias_left else self .map_field (field_left , source_mapping )
264
+ alias_right , field_right = token .alias_right , token .field_right
265
+ mapped_fields_right = [alias_right .name ] if alias_right else self .map_field (field_right , source_mapping )
266
+ cross_paired_fields = list (itertools .product (mapped_fields_left , mapped_fields_right ))
267
+ joined = self .logical_operators_map [LogicalOperatorType .OR ].join (
268
+ [
269
+ self .field_field_render .apply_field_field (pair [0 ], token .operator , pair [1 ])
270
+ for pair in cross_paired_fields
271
+ ]
272
+ )
273
+ return self .group_token % joined if len (cross_paired_fields ) > 1 else joined
254
274
if isinstance (token , Function ):
255
275
func_render = self .platform_functions .manager .get_in_query_render (token .name )
256
276
return func_render .render (token , source_mapping )
257
277
if isinstance (token , Keyword ):
258
- return self .field_value_map .apply_field_value (field = "" , operator = token .operator , value = token .value )
278
+ return self .field_value_render .apply_field_value (field = "" , operator = token .operator , value = token .value )
259
279
if token .token_type in LogicalOperatorType :
260
- return self .operator_map .get (token .token_type )
280
+ return self .logical_operators_map .get (token .token_type )
261
281
262
282
return token .token_type
263
283
@@ -273,8 +293,8 @@ def generate_query(self, tokens: list[TOKEN_TYPE], source_mapping: SourceMapping
273
293
raise StrictPlatformException (self .details .name , "" , source_mapping .source_id , sorted (unmapped_fields ))
274
294
return "" .join (result_values )
275
295
276
- def wrap_query_with_meta_info (self , meta_info : MetaInfoContainer , query : str ) -> str :
277
- if meta_info and (meta_info .id or meta_info .title ):
296
+ def wrap_with_meta_info (self , query : str , meta_info : Optional [ MetaInfoContainer ] ) -> str :
297
+ if wrap_query_with_meta_info_ctx_var . get () and meta_info and (meta_info .id or meta_info .title ):
278
298
meta_info_dict = {
279
299
"name: " : meta_info .title ,
280
300
"uuid: " : meta_info .id ,
@@ -307,11 +327,8 @@ def finalize_query(
307
327
** kwargs , # noqa: ARG002
308
328
) -> str :
309
329
query = self ._join_query_parts (prefix , query , functions )
310
- query = self .wrap_query_with_meta_info (meta_info = meta_info , query = query )
311
- if not_supported_functions :
312
- rendered_not_supported = self .render_not_supported_functions (not_supported_functions )
313
- return query + rendered_not_supported
314
- return query
330
+ query = self .wrap_with_meta_info (query , meta_info )
331
+ return self .wrap_with_not_supported_functions (query , not_supported_functions )
315
332
316
333
@staticmethod
317
334
def unique_queries (queries_map : dict [str , str ]) -> dict [str , dict [str ]]:
@@ -342,7 +359,7 @@ def _get_source_mappings(self, source_mapping_ids: list[str]) -> list[SourceMapp
342
359
343
360
return source_mappings
344
361
345
- def _generate_from_raw_query_container (self , query_container : RawQueryContainer ) -> str :
362
+ def generate_from_raw_query_container (self , query_container : RawQueryContainer ) -> str :
346
363
return self .finalize_query (
347
364
prefix = "" , query = query_container .query , functions = "" , meta_info = query_container .meta_info
348
365
)
@@ -380,7 +397,7 @@ def generate_raw_log_fields(self, fields: list[Field], source_mapping: SourceMap
380
397
defined_raw_log_fields .append (prefix )
381
398
return "\n " .join (defined_raw_log_fields )
382
399
383
- def _generate_from_tokenized_query_container (self , query_container : TokenizedQueryContainer ) -> str :
400
+ def generate_from_tokenized_query_container (self , query_container : TokenizedQueryContainer ) -> str :
384
401
queries_map = {}
385
402
errors = []
386
403
source_mappings = self ._get_source_mappings (query_container .meta_info .source_mapping_ids )
@@ -417,6 +434,6 @@ def _generate_from_tokenized_query_container(self, query_container: TokenizedQue
417
434
418
435
def generate (self , query_container : Union [RawQueryContainer , TokenizedQueryContainer ]) -> str :
419
436
if isinstance (query_container , RawQueryContainer ):
420
- return self ._generate_from_raw_query_container (query_container )
437
+ return self .generate_from_raw_query_container (query_container )
421
438
422
- return self ._generate_from_tokenized_query_container (query_container )
439
+ return self .generate_from_tokenized_query_container (query_container )
0 commit comments