diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index cfb27adb..0bcc14a9 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -14,6 +14,7 @@ WaitTimeout, MetadataCommands, ) +from databricks.sql.thrift_api.TCLIService import ttypes if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -402,7 +403,7 @@ def execute_command( lz4_compression: bool, cursor: Cursor, use_cloud_fetch: bool, - parameters: List[Dict[str, Any]], + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, row_limit: Optional[int] = None, @@ -437,9 +438,11 @@ def execute_command( for param in parameters: sea_parameters.append( StatementParameter( - name=param["name"], - value=param["value"], - type=param["type"] if "type" in param else None, + name=param.name, + value=( + None if param.value is None else param.value.stringValue + ), + type=param.type, ) ) diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index 79def9b7..453e5944 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -168,7 +168,13 @@ def patch_server_supports_native_params(self, supports_native_params: bool = Tru finally: pass - def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column): + def _inline_roundtrip( + self, + params: dict, + paramstyle: ParamStyle, + target_column, + extra_params: dict = {}, + ): """This INSERT, SELECT, DELETE dance is necessary because simply selecting ``` "SELECT %(param)s" @@ -183,7 +189,9 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column) SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" DELETE_QUERY = "DELETE FROM pysql_e2e_inline_param_test_table" - with self.connection(extra_params={"use_inline_params": True}) as conn: + with self.connection( + 
extra_params={"use_inline_params": True, **extra_params} + ) as conn: with conn.cursor() as cursor: cursor.execute(INSERT_QUERY, parameters=params) with conn.cursor() as cursor: @@ -198,6 +206,7 @@ def _native_roundtrip( parameters: Union[Dict, List[Dict]], paramstyle: ParamStyle, parameter_structure: ParameterStructure, + extra_params: dict = {}, ): if parameter_structure == ParameterStructure.POSITIONAL: _query = self.POSITIONAL_PARAMSTYLE_QUERY @@ -205,7 +214,9 @@ def _native_roundtrip( _query = self.NAMED_PARAMSTYLE_QUERY elif paramstyle == ParamStyle.PYFORMAT: _query = self.PYFORMAT_PARAMSTYLE_QUERY - with self.connection(extra_params={"use_inline_params": False}) as conn: + with self.connection( + extra_params={"use_inline_params": False, **extra_params} + ) as conn: with conn.cursor() as cursor: cursor.execute(_query, parameters=parameters) return cursor.fetchone() @@ -216,6 +227,7 @@ def _get_one_result( approach: ParameterApproach = ParameterApproach.NONE, paramstyle: ParamStyle = ParamStyle.NONE, parameter_structure: ParameterStructure = ParameterStructure.NONE, + extra_params: dict = {}, ): """When approach is INLINE then we use %(param)s paramstyle and a connection with use_inline_params=True When approach is NATIVE then we use :param paramstyle and a connection with use_inline_params=False @@ -228,12 +240,16 @@ def _get_one_result( params, paramstyle=ParamStyle.PYFORMAT, target_column=self._get_inline_table_column(params.get("p")), + extra_params=extra_params, ) elif approach == ParameterApproach.NATIVE: # native mode can use either ParamStyle.NAMED or ParamStyle.PYFORMAT # native mode can use either ParameterStructure.NAMED or ParameterStructure.POSITIONAL return self._native_roundtrip( - params, paramstyle=paramstyle, parameter_structure=parameter_structure + params, + paramstyle=paramstyle, + parameter_structure=parameter_structure, + extra_params=extra_params, ) def _quantize(self, input: Union[float, int], place_value=2) -> Decimal: @@ -379,7 
+395,20 @@ def test_dbsqlparameter_single( assert self._eq(result.col, primitive) @pytest.mark.parametrize("use_inline_params", (True, False, "silent")) - def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_use_inline_off_by_default_with_warning( + self, use_inline_params, caplog, extra_params + ): """ use_inline_params should be False by default. If a user explicitly sets use_inline_params, don't warn them about it. @@ -389,7 +418,7 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) {"use_inline_params": use_inline_params} if use_inline_params else {} ) - with self.connection(extra_params=extra_args) as conn: + with self.connection(extra_params={**extra_args, **extra_params}) as conn: with conn.cursor() as cursor: with self.patch_server_supports_native_params( supports_native_params=True @@ -404,9 +433,20 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) "Consider using native parameters." not in caplog.text ), "Log message should not be supressed" - def test_positional_native_params_with_defaults(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_positional_native_params_with_defaults(self, extra_params): query = "SELECT ? 
col" - with self.cursor() as cursor: + with self.cursor(extra_params=extra_params) as cursor: result = cursor.execute(query, parameters=[1]).fetchone() assert result.col == 1 @@ -422,10 +462,23 @@ def test_positional_native_params_with_defaults(self): ["foo", "bar", "baz"], ), ) - def test_positional_native_multiple(self, params): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_positional_native_multiple(self, params, extra_params): query = "SELECT ? `foo`, ? `bar`, ? `baz`" - with self.cursor(extra_params={"use_inline_params": False}) as cursor: + with self.cursor( + extra_params={"use_inline_params": False, **extra_params} + ) as cursor: result = cursor.execute(query, params).fetchone() expected = [i.value if isinstance(i, DbsqlParameterBase) else i for i in params] @@ -433,8 +486,19 @@ def test_positional_native_multiple(self, params): assert set(outcome) == set(expected) - def test_readme_example(self): - with self.cursor() as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_readme_example(self, extra_params): + with self.cursor(extra_params=extra_params) as cursor: result = cursor.execute( "SELECT :param `p`, * FROM RANGE(10)", {"param": "foo"} ).fetchall() @@ -498,11 +562,24 @@ def test_native_recursive_complex_type( class TestInlineParameterSyntax(PySQLPytestTestCase): """The inline parameter approach uses pyformat markers""" - def test_params_as_dict(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_params_as_dict(self, extra_params): query = "SELECT %(foo)s foo, %(bar)s bar, %(baz)s baz" params = {"foo": 1, "bar": 2, "baz": 3} - with 
self.connection(extra_params={"use_inline_params": True}) as conn: + with self.connection( + extra_params={"use_inline_params": True, **extra_params} + ) as conn: with conn.cursor() as cursor: result = cursor.execute(query, parameters=params).fetchone() @@ -510,7 +587,18 @@ def test_params_as_dict(self): assert result.bar == 2 assert result.baz == 3 - def test_params_as_sequence(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_params_as_sequence(self, extra_params): """One side-effect of ParamEscaper using Python string interpolation to inline the values is that it can work with "ordinal" parameters, but only if a user writes parameter markers that are not defined with PEP-249. This test exists to prove that it works in the ideal case. @@ -520,27 +608,53 @@ def test_params_as_sequence(self): query = "SELECT %s foo, %s bar, %s baz" params = (1, 2, 3) - with self.connection(extra_params={"use_inline_params": True}) as conn: + with self.connection( + extra_params={"use_inline_params": True, **extra_params} + ) as conn: with conn.cursor() as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.foo == 1 assert result.bar == 2 assert result.baz == 3 - def test_inline_ordinals_can_break_sql(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_inline_ordinals_can_break_sql(self, extra_params): """With inline mode, ordinal parameters can break the SQL syntax because `%` symbols are used to wildcard match within LIKE statements. This test just proves that's the case. 
""" query = "SELECT 'samsonite', %s WHERE 'samsonite' LIKE '%sonite'" params = ["luggage"] - with self.cursor(extra_params={"use_inline_params": True}) as cursor: + with self.cursor( + extra_params={"use_inline_params": True, **extra_params} + ) as cursor: with pytest.raises( TypeError, match="not enough arguments for format string" ): cursor.execute(query, parameters=params) - def test_inline_named_dont_break_sql(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_inline_named_dont_break_sql(self, extra_params): """With inline mode, ordinal parameters can break the SQL syntax because `%` symbols are used to wildcard match within LIKE statements. This test just proves that's the case. @@ -550,39 +664,80 @@ def test_inline_named_dont_break_sql(self): SELECT col_1 FROM base WHERE col_1 LIKE CONCAT(%(one)s, 'onite') """ params = {"one": "%(one)s"} - with self.cursor(extra_params={"use_inline_params": True}) as cursor: + with self.cursor( + extra_params={"use_inline_params": True, **extra_params} + ) as cursor: result = cursor.execute(query, parameters=params).fetchone() print("hello") - def test_native_ordinals_dont_break_sql(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_native_ordinals_dont_break_sql(self, extra_params): """This test accompanies test_inline_ordinals_can_break_sql to prove that ordinal parameters work in native mode for the exact same query, if we use the right marker `?` """ query = "SELECT 'samsonite', ? 
WHERE 'samsonite' LIKE '%sonite'" params = ["luggage"] - with self.cursor(extra_params={"use_inline_params": False}) as cursor: + with self.cursor( + extra_params={"use_inline_params": False, **extra_params} + ) as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.samsonite == "samsonite" assert result.luggage == "luggage" - def test_inline_like_wildcard_breaks(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_inline_like_wildcard_breaks(self, extra_params): """One flaw with the ParameterEscaper is that it fails if a query contains a SQL LIKE wildcard %. This test proves that's the case. """ query = "SELECT 1 `col` WHERE 'foo' LIKE '%'" params = {"param": "bar"} - with self.cursor(extra_params={"use_inline_params": True}) as cursor: + with self.cursor( + extra_params={"use_inline_params": True, **extra_params} + ) as cursor: with pytest.raises(ValueError, match="unsupported format character"): result = cursor.execute(query, parameters=params).fetchone() - def test_native_like_wildcard_works(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_native_like_wildcard_works(self, extra_params): """This is a mirror of test_inline_like_wildcard_breaks that proves that LIKE wildcards work under the native approach. 
""" query = "SELECT 1 `col` WHERE 'foo' LIKE '%'" params = {"param": "bar"} - with self.cursor(extra_params={"use_inline_params": False}) as cursor: + with self.cursor( + extra_params={"use_inline_params": False, **extra_params} + ) as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.col == 1 diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index 7eae8e5a..aec9d347 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -13,6 +13,7 @@ _filter_session_configuration, ) from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.exc import ( @@ -355,7 +356,12 @@ def test_command_execution_advanced( "status": {"state": "SUCCEEDED"}, } mock_http_client._make_request.return_value = execute_response - param = {"name": "param1", "value": "value1", "type": "STRING"} + # execute_command reads param.value.stringValue, so the value must be a + # TSparkParameterValue wrapper rather than a bare str. + param = ttypes.TSparkParameter( + name="param1", + value=ttypes.TSparkParameterValue(stringValue="value1"), + type="STRING", + ) with patch.object(sea_client, "get_execution_result"): sea_client.execute_command(