diff --git a/tests/cases/arithmetic_decimal/power.test b/tests/cases/arithmetic_decimal/power.test index 37a0712d5..0d4889e0f 100644 --- a/tests/cases/arithmetic_decimal/power.test +++ b/tests/cases/arithmetic_decimal/power.test @@ -1,5 +1,5 @@ ### SUBSTRAIT_SCALAR_TEST: v1.0 -### SUBSTRAIT_INCLUDE: 'extensions/functions_arithmetic_decimal.yaml' +### SUBSTRAIT_INCLUDE: '/extensions/functions_arithmetic_decimal.yaml' # basic: Basic examples without any special cases power(8::dec<38, 0>, 2::dec<38, 0>) = 64::fp64 diff --git a/tests/coverage/coverage.py b/tests/coverage/coverage.py index ebea43273..705dc3273 100755 --- a/tests/coverage/coverage.py +++ b/tests/coverage/coverage.py @@ -109,22 +109,27 @@ def update_test_count(test_case_files: list, function_registry: FunctionRegistry for test_file in test_case_files: for test_case in test_file.testcases: function_variant = function_registry.get_function( - test_case.func_name, test_case.get_arg_types() + test_case.func_name, + test_file.include, + test_case.get_arg_types(), + test_case.get_return_type(), ) if function_variant: if ( - function_variant.return_type != test_case.get_return_type() - and not test_case.is_return_type_error() + not test_case.is_return_type_error() + and not function_registry.is_same_type( + function_variant.return_type, test_case.get_return_type() + ) ): error( - f"Return type mismatch in function {test_case.func_name}: " + f"Return type mismatch in function {test_case.get_signature()}: " f"{function_variant.return_type} != {test_case.get_return_type()}" ) num_tests_with_no_matching_function += 1 continue function_variant.increment_test_count() else: - error(f"Function not found: {test_case.func_name}({test_case.args})") + error(f"Function not found: {test_case.get_signature()}") num_tests_with_no_matching_function += 1 return num_tests_with_no_matching_function diff --git a/tests/coverage/extensions.py b/tests/coverage/extensions.py index ba8a43299..e66681d27 100644 --- a/tests/coverage/extensions.py +++ b/tests/coverage/extensions.py @@ -3,6 +3,7 @@ import yaml from tests.coverage.antlr_parser.FuncTestCaseLexer import FuncTestCaseLexer +from tests.coverage.nodes import SubstraitError enable_debug = False @@ -122,11 +123,10 @@ def get_supported_kernels_from_impls(func): return overloads @staticmethod - def add_functions_to_map(func_list, function_map, suffix, extension): + def add_functions_to_map(func_list, function_map, suffix, extension, uri): dup_idx = 0 for func in func_list: name = func["name"] - uri = extension[5:] # strip the ../.. if name in function_map: debug( f"Duplicate function name: {name} renaming to {name}_{suffix} extension: {extension}" @@ -163,14 +163,19 @@ def read_substrait_extensions(dir_path: str): suffix = suffix[ suffix.rfind("/") + 1 : ] # strip the path and get the name of the extension + uri = f"/extensions/{suffix}.yaml" suffix = suffix[suffix.find("_") + 1 :] # get the suffix after the last _ - dependencies[suffix] = Extension.get_base_uri() + extension + dependencies[suffix] = Extension.get_base_uri() + uri with open(extension, "r") as fh: data = yaml.load(fh, Loader=yaml.FullLoader) if "scalar_functions" in data: Extension.add_functions_to_map( - data["scalar_functions"], scalar_functions, suffix, extension + data["scalar_functions"], + scalar_functions, + suffix, + extension, + uri, ) if "aggregate_functions" in data: Extension.add_functions_to_map( @@ -178,10 +183,15 @@ def read_substrait_extensions(dir_path: str): aggregate_functions, suffix, extension, + uri, ) if "window_functions" in data: Extension.add_functions_to_map( - data["window_functions"], scalar_functions, suffix, extension + data["window_functions"], + scalar_functions, + suffix, + extension, + uri, ) return FunctionRegistry( @@ -263,13 +273,45 @@ def add_functions(self, functions, func_type): fun_arr.append(function) self.registry[f_name] = fun_arr - def get_function(self, name: str, args: object) -> [FunctionVariant]: + @staticmethod + def is_type_any(func_arg_type): + return func_arg_type[:3] == "any" + + @staticmethod + def is_same_type(func_arg_type, arg_type): + arg_type_base = arg_type.split("<")[0] + if func_arg_type == arg_type_base: + return True + return FunctionRegistry.is_type_any(func_arg_type) + + def get_function( + self, name: str, uri: str, args: object, return_type + ) -> [FunctionVariant]: functions = self.registry.get(name, None) if functions is None: return None for function in functions: + if uri != function.uri: + continue + if not isinstance(return_type, SubstraitError) and not self.is_same_type( + function.return_type, return_type + ): + continue if function.args == args: return function + if len(function.args) != len(args) and not ( + function.variadic and len(args) >= len(function.args) + ): + continue + is_match = True + for i, arg in enumerate(args): + j = i if i < len(function.args) else len(function.args) - 1 + if not self.is_same_type(function.args[j], arg): + is_match = False + break + if is_match: + return function + return None def get_extension_list(self): return list(self.extensions) diff --git a/tests/coverage/nodes.py b/tests/coverage/nodes.py index c79d1463f..e12a2d9cc 100644 --- a/tests/coverage/nodes.py +++ b/tests/coverage/nodes.py @@ -59,7 +59,7 @@ def get_arg_types(self): return [arg.get_base_type() for arg in self.args] def get_signature(self): - return f"{self.func_name}({', '.join([arg.type for arg in self.args])})" + return f"{self.func_name}({', '.join([arg.type for arg in self.args])}) = {self.get_return_type()}" @dataclass diff --git a/tests/coverage/test_coverage.py b/tests/coverage/test_coverage.py index 7d9adc66a..0e700ae1a 100644 --- a/tests/coverage/test_coverage.py +++ b/tests/coverage/test_coverage.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 +import os + import pytest from antlr4 import InputStream from tests.coverage.case_file_parser import parse_stream, parse_one_file +from tests.coverage.extensions import Extension from tests.coverage.visitor import ParseError from tests.coverage.nodes import CaseLiteral @@ -411,6 +414,10 @@ def test_parse_errors_with_bad_aggregate_testcases(input_func_test, expected_mes "f37('1991-01-01T01:02:03.123456'::pts<6>, '1991-01-01T04:05:06.456'::precision_timestamp<6>) = 123456::i64", "f38('1991-01-01T01:02:03.456+05:30'::ptstz<3>) = '1991-01-01T00:00:00+15:30'::precision_timestamp_tz<3>", "f39('1991-01-01T01:02:03.123456+05:30'::ptstz<6>) = '1991-01-01T00:00:00+15:30'::precision_timestamp_tz<6>", + "logb(10::fp64, -inf::fp64) [on_domain_error:ERROR] = ", + "bitwise_and(-31766::dec<5, 0>, 900::dec<3, 0>) = 896::dec<5, 0>", + "or(true::bool, true::bool) = true::bool", + "between(5::i8, 0::i8, 127::i8) = true::bool", ], ) def test_parse_various_scalar_func_argument_types(input_func_test): @@ -436,3 +443,47 @@ def test_parse_various_aggregate_scalar_func_argument_types(input_func_test): ) test_file = parse_string(header + input_func_test + "\n") assert len(test_file.testcases) == 1 + + +@pytest.mark.parametrize( + "func_name, func_args, func_ret, func_uri, expected_failure", + [ + # lt for i8 with correct uri + ("lt", ["i8", "i8"], "bool", "/extensions/functions_comparison.yaml", False), + ("add", ["i8", "i8"], "i8", "/extensions/functions_arithmetic.yaml", False), + ( + "add", + ["dec", "dec"], + "dec", + "/extensions/functions_arithmetic_decimal.yaml", + False, + ), + ( + "bitwise_xor", + ["dec", "dec"], + "dec", + "/extensions/functions_arithmetic_decimal.yaml", + False, + ), + # negative case, lt for i8 with wrong uri + ("lt", ["i8", "i8"], "bool", "/extensions/functions_datetime.yaml", True), + ( + "add", + ["i8", "i8"], + "i8", + "/extensions/functions_arithmetic_decimal.yaml", + True, + ), + ("add", ["dec", "dec"], "dec", "/extensions/functions_arithmetic.yaml", True), + ("max", ["dec", "dec"], "dec", "/extensions/functions_arithmetic.yaml", True), + ], +) +def test_uri_match_in_get_function( + func_name, func_args, func_ret, func_uri, expected_failure +): + script_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_path = os.path.join(script_dir, "../../extensions") + registry = Extension.read_substrait_extensions(extensions_path) + + function = registry.get_function(func_name, func_uri, func_args, func_ret) + assert (function is None) == expected_failure