From 06827619802f5ecec79bc89446e0d422052d261c Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 20 Jun 2025 11:49:40 +0000 Subject: [PATCH 1/4] preliminary fix Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index 76903ccd..ab210e30 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -11,6 +11,7 @@ ResultCompression, WaitTimeout, ) +from databricks.sql.thrift_api.TCLIService import ttypes if TYPE_CHECKING: from databricks.sql.client import Cursor @@ -402,7 +403,7 @@ def execute_command( lz4_compression: bool, cursor: "Cursor", use_cloud_fetch: bool, - parameters: List[Dict[str, Any]], + parameters: List[ttypes.TSparkParameter], async_op: bool, enforce_embedded_schema_correctness: bool, ) -> Union["ResultSet", None]: @@ -436,9 +437,9 @@ def execute_command( for param in parameters: sea_parameters.append( StatementParameter( - name=param["name"], - value=param["value"], - type=param["type"] if "type" in param else None, + name=param.name, + value=param.value, + type=param.type if param.type else None, ) ) From 53a555f7902f598b9c997ea9e7c5407dd9c9a4c7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 20 Jun 2025 11:54:56 +0000 Subject: [PATCH 2/4] update unit tests to pass ttypes.TSparkParameter Signed-off-by: varun-edachali-dbx --- tests/unit/test_sea_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_sea_backend.py b/tests/unit/test_sea_backend.py index f30c92ed..a2c1aeb0 100644 --- a/tests/unit/test_sea_backend.py +++ b/tests/unit/test_sea_backend.py @@ -13,6 +13,7 @@ _filter_session_configuration, ) from databricks.sql.backend.types import SessionId, CommandId, CommandState, BackendType +from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.types import SSLOptions from databricks.sql.auth.authenticators import AuthProvider from databricks.sql.exc import ( @@ -354,7 +355,7 @@ def test_command_execution_advanced( "status": {"state": "SUCCEEDED"}, } mock_http_client._make_request.return_value = execute_response - param = {"name": "param1", "value": "value1", "type": "STRING"} + param = ttypes.TSparkParameter(name="param1", value="value1", type="STRING") with patch.object(sea_client, "get_execution_result"): sea_client.execute_command( From 5a800fd8bc7e45cb747adef0f393fa2d0e63c446 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 20 Jun 2025 12:28:13 +0000 Subject: [PATCH 3/4] remove explicit check for type attr in param Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index ab210e30..ff0a259c 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -439,7 +439,7 @@ def execute_command( StatementParameter( name=param.name, value=param.value, - type=param.type if param.type else None, + type=param.type, ) ) From cba08ac5dbdff9df920e334157cba1aef4772a48 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 8 Jul 2025 18:58:40 +0530 Subject: [PATCH 4/4] run primitive parameterised query tests Signed-off-by: varun-edachali-dbx --- src/databricks/sql/backend/sea/backend.py | 4 +- tests/e2e/test_parameterized_queries.py | 207 +++++++++++++++++++--- 2 files changed, 184 insertions(+), 27 deletions(-) diff --git a/src/databricks/sql/backend/sea/backend.py b/src/databricks/sql/backend/sea/backend.py index f4ccdaf8..0bcc14a9 100644 --- a/src/databricks/sql/backend/sea/backend.py +++ b/src/databricks/sql/backend/sea/backend.py @@ -439,7 +439,9 @@ def execute_command( sea_parameters.append( StatementParameter( name=param.name, - value=param.value, + value=( + None if param.value is None else param.value.stringValue + ), type=param.type, ) ) diff --git a/tests/e2e/test_parameterized_queries.py b/tests/e2e/test_parameterized_queries.py index 79def9b7..453e5944 100644 --- a/tests/e2e/test_parameterized_queries.py +++ b/tests/e2e/test_parameterized_queries.py @@ -168,7 +168,13 @@ def patch_server_supports_native_params(self, supports_native_params: bool = Tru finally: pass - def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column): + def _inline_roundtrip( + self, + params: dict, + paramstyle: ParamStyle, + target_column, + extra_params: dict = {}, + ): """This INSERT, SELECT, DELETE dance is necessary because simply selecting ``` "SELECT %(param)s" @@ -183,7 +189,9 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column) SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1" DELETE_QUERY = "DELETE FROM pysql_e2e_inline_param_test_table" - with self.connection(extra_params={"use_inline_params": True}) as conn: + with self.connection( + extra_params={"use_inline_params": True, **extra_params} + ) as conn: with conn.cursor() as cursor: cursor.execute(INSERT_QUERY, parameters=params) with conn.cursor() as cursor: @@ -198,6 +206,7 @@ def _native_roundtrip( parameters: Union[Dict, List[Dict]], paramstyle: ParamStyle, parameter_structure: ParameterStructure, + extra_params: dict = {}, ): if parameter_structure == ParameterStructure.POSITIONAL: _query = self.POSITIONAL_PARAMSTYLE_QUERY @@ -205,7 +214,9 @@ def _native_roundtrip( _query = self.NAMED_PARAMSTYLE_QUERY elif paramstyle == ParamStyle.PYFORMAT: _query = self.PYFORMAT_PARAMSTYLE_QUERY - with self.connection(extra_params={"use_inline_params": False}) as conn: + with self.connection( + extra_params={"use_inline_params": False, **extra_params} + ) as conn: with conn.cursor() as cursor: cursor.execute(_query, parameters=parameters) return cursor.fetchone() @@ -216,6 +227,7 @@ def _get_one_result( approach: ParameterApproach = ParameterApproach.NONE, paramstyle: ParamStyle = ParamStyle.NONE, parameter_structure: ParameterStructure = ParameterStructure.NONE, + extra_params: dict = {}, ): """When approach is INLINE then we use %(param)s paramstyle and a connection with use_inline_params=True When approach is NATIVE then we use :param paramstyle and a connection with use_inline_params=False @@ -228,12 +240,16 @@ def _get_one_result( params, paramstyle=ParamStyle.PYFORMAT, target_column=self._get_inline_table_column(params.get("p")), + extra_params=extra_params, ) elif approach == ParameterApproach.NATIVE: # native mode can use either ParamStyle.NAMED or ParamStyle.PYFORMAT # native mode can use either ParameterStructure.NAMED or ParameterStructure.POSITIONAL return self._native_roundtrip( - params, paramstyle=paramstyle, parameter_structure=parameter_structure + params, + paramstyle=paramstyle, + parameter_structure=parameter_structure, + extra_params=extra_params, ) def _quantize(self, input: Union[float, int], place_value=2) -> Decimal: @@ -379,7 +395,20 @@ def test_dbsqlparameter_single( assert self._eq(result.col, primitive) @pytest.mark.parametrize("use_inline_params", (True, False, "silent")) - def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_use_inline_off_by_default_with_warning( + self, use_inline_params, caplog, extra_params + ): """ use_inline_params should be False by default. If a user explicitly sets use_inline_params, don't warn them about it. @@ -389,7 +418,7 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) {"use_inline_params": use_inline_params} if use_inline_params else {} ) - with self.connection(extra_params=extra_args) as conn: + with self.connection(extra_params={**extra_args, **extra_params}) as conn: with conn.cursor() as cursor: with self.patch_server_supports_native_params( supports_native_params=True @@ -404,9 +433,20 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog) "Consider using native parameters." not in caplog.text ), "Log message should not be supressed" - def test_positional_native_params_with_defaults(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_positional_native_params_with_defaults(self, extra_params): query = "SELECT ? col" - with self.cursor() as cursor: + with self.cursor(extra_params=extra_params) as cursor: result = cursor.execute(query, parameters=[1]).fetchone() assert result.col == 1 @@ -422,10 +462,23 @@ def test_positional_native_params_with_defaults(self): ["foo", "bar", "baz"], ), ) - def test_positional_native_multiple(self, params): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_positional_native_multiple(self, params, extra_params): query = "SELECT ? `foo`, ? `bar`, ? `baz`" - with self.cursor(extra_params={"use_inline_params": False}) as cursor: + with self.cursor( + extra_params={"use_inline_params": False, **extra_params} + ) as cursor: result = cursor.execute(query, params).fetchone() expected = [i.value if isinstance(i, DbsqlParameterBase) else i for i in params] @@ -433,8 +486,19 @@ def test_positional_native_multiple(self, params): assert set(outcome) == set(expected) - def test_readme_example(self): - with self.cursor() as cursor: + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_readme_example(self, extra_params): + with self.cursor(extra_params=extra_params) as cursor: result = cursor.execute( "SELECT :param `p`, * FROM RANGE(10)", {"param": "foo"} ).fetchall() @@ -498,11 +562,24 @@ def test_native_recursive_complex_type( class TestInlineParameterSyntax(PySQLPytestTestCase): """The inline parameter approach uses pyformat markers""" - def test_params_as_dict(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_params_as_dict(self, extra_params): query = "SELECT %(foo)s foo, %(bar)s bar, %(baz)s baz" params = {"foo": 1, "bar": 2, "baz": 3} - with self.connection(extra_params={"use_inline_params": True}) as conn: + with self.connection( + extra_params={"use_inline_params": True, **extra_params} + ) as conn: with conn.cursor() as cursor: result = cursor.execute(query, parameters=params).fetchone() @@ -510,7 +587,18 @@ def test_params_as_dict(self): assert result.bar == 2 assert result.baz == 3 - def test_params_as_sequence(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_params_as_sequence(self, extra_params): """One side-effect of ParamEscaper using Python string interpolation to inline the values is that it can work with "ordinal" parameters, but only if a user writes parameter markers that are not defined with PEP-249. This test exists to prove that it works in the ideal case. @@ -520,27 +608,53 @@ def test_params_as_sequence(self): query = "SELECT %s foo, %s bar, %s baz" params = (1, 2, 3) - with self.connection(extra_params={"use_inline_params": True}) as conn: + with self.connection( + extra_params={"use_inline_params": True, **extra_params} + ) as conn: with conn.cursor() as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.foo == 1 assert result.bar == 2 assert result.baz == 3 - def test_inline_ordinals_can_break_sql(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_inline_ordinals_can_break_sql(self, extra_params): """With inline mode, ordinal parameters can break the SQL syntax because `%` symbols are used to wildcard match within LIKE statements. This test just proves that's the case. """ query = "SELECT 'samsonite', %s WHERE 'samsonite' LIKE '%sonite'" params = ["luggage"] - with self.cursor(extra_params={"use_inline_params": True}) as cursor: + with self.cursor( + extra_params={"use_inline_params": True, **extra_params} + ) as cursor: with pytest.raises( TypeError, match="not enough arguments for format string" ): cursor.execute(query, parameters=params) - def test_inline_named_dont_break_sql(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_inline_named_dont_break_sql(self, extra_params): """With inline mode, ordinal parameters can break the SQL syntax because `%` symbols are used to wildcard match within LIKE statements. This test just proves that's the case. @@ -550,39 +664,80 @@ def test_inline_named_dont_break_sql(self): SELECT col_1 FROM base WHERE col_1 LIKE CONCAT(%(one)s, 'onite') """ params = {"one": "%(one)s"} - with self.cursor(extra_params={"use_inline_params": True}) as cursor: + with self.cursor( + extra_params={"use_inline_params": True, **extra_params} + ) as cursor: result = cursor.execute(query, parameters=params).fetchone() print("hello") - def test_native_ordinals_dont_break_sql(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_native_ordinals_dont_break_sql(self, extra_params): """This test accompanies test_inline_ordinals_can_break_sql to prove that ordinal parameters work in native mode for the exact same query, if we use the right marker `?` """ query = "SELECT 'samsonite', ? WHERE 'samsonite' LIKE '%sonite'" params = ["luggage"] - with self.cursor(extra_params={"use_inline_params": False}) as cursor: + with self.cursor( + extra_params={"use_inline_params": False, **extra_params} + ) as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.samsonite == "samsonite" assert result.luggage == "luggage" - def test_inline_like_wildcard_breaks(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_inline_like_wildcard_breaks(self, extra_params): """One flaw with the ParameterEscaper is that it fails if a query contains a SQL LIKE wildcard %. This test proves that's the case. """ query = "SELECT 1 `col` WHERE 'foo' LIKE '%'" params = {"param": "bar"} - with self.cursor(extra_params={"use_inline_params": True}) as cursor: + with self.cursor( + extra_params={"use_inline_params": True, **extra_params} + ) as cursor: with pytest.raises(ValueError, match="unsupported format character"): result = cursor.execute(query, parameters=params).fetchone() - def test_native_like_wildcard_works(self): + @pytest.mark.parametrize( + "extra_params", + [ + {}, + { + "use_sea": True, + "use_cloud_fetch": False, + "enable_query_result_lz4_compression": False, + }, + ], + ) + def test_native_like_wildcard_works(self, extra_params): """This is a mirror of test_inline_like_wildcard_breaks that proves that LIKE wildcards work under the native approach. """ query = "SELECT 1 `col` WHERE 'foo' LIKE '%'" params = {"param": "bar"} - with self.cursor(extra_params={"use_inline_params": False}) as cursor: + with self.cursor( + extra_params={"use_inline_params": False, **extra_params} + ) as cursor: result = cursor.execute(query, parameters=params).fetchone() assert result.col == 1