Skip to content

SEA: support primitive params #612

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: sea-migration
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/databricks/sql/backend/sea/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
WaitTimeout,
MetadataCommands,
)
from databricks.sql.thrift_api.TCLIService import ttypes

if TYPE_CHECKING:
from databricks.sql.client import Cursor
Expand Down Expand Up @@ -402,7 +403,7 @@ def execute_command(
lz4_compression: bool,
cursor: Cursor,
use_cloud_fetch: bool,
parameters: List[Dict[str, Any]],
parameters: List[ttypes.TSparkParameter],
async_op: bool,
enforce_embedded_schema_correctness: bool,
row_limit: Optional[int] = None,
Expand Down Expand Up @@ -437,9 +438,11 @@ def execute_command(
for param in parameters:
sea_parameters.append(
StatementParameter(
name=param["name"],
value=param["value"],
type=param["type"] if "type" in param else None,
name=param.name,
value=(
None if param.value is None else param.value.stringValue
),
type=param.type,
)
)

Expand Down
207 changes: 181 additions & 26 deletions tests/e2e/test_parameterized_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,13 @@ def patch_server_supports_native_params(self, supports_native_params: bool = Tru
finally:
pass

def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column):
def _inline_roundtrip(
self,
params: dict,
paramstyle: ParamStyle,
target_column,
extra_params: dict = {},
):
"""This INSERT, SELECT, DELETE dance is necessary because simply selecting
```
"SELECT %(param)s"
Expand All @@ -183,7 +189,9 @@ def _inline_roundtrip(self, params: dict, paramstyle: ParamStyle, target_column)
SELECT_QUERY = f"SELECT {target_column} `col` FROM pysql_e2e_inline_param_test_table LIMIT 1"
DELETE_QUERY = "DELETE FROM pysql_e2e_inline_param_test_table"

with self.connection(extra_params={"use_inline_params": True}) as conn:
with self.connection(
extra_params={"use_inline_params": True, **extra_params}
) as conn:
with conn.cursor() as cursor:
cursor.execute(INSERT_QUERY, parameters=params)
with conn.cursor() as cursor:
Expand All @@ -198,14 +206,17 @@ def _native_roundtrip(
parameters: Union[Dict, List[Dict]],
paramstyle: ParamStyle,
parameter_structure: ParameterStructure,
extra_params: dict = {},
):
if parameter_structure == ParameterStructure.POSITIONAL:
_query = self.POSITIONAL_PARAMSTYLE_QUERY
elif paramstyle == ParamStyle.NAMED:
_query = self.NAMED_PARAMSTYLE_QUERY
elif paramstyle == ParamStyle.PYFORMAT:
_query = self.PYFORMAT_PARAMSTYLE_QUERY
with self.connection(extra_params={"use_inline_params": False}) as conn:
with self.connection(
extra_params={"use_inline_params": False, **extra_params}
) as conn:
with conn.cursor() as cursor:
cursor.execute(_query, parameters=parameters)
return cursor.fetchone()
Expand All @@ -216,6 +227,7 @@ def _get_one_result(
approach: ParameterApproach = ParameterApproach.NONE,
paramstyle: ParamStyle = ParamStyle.NONE,
parameter_structure: ParameterStructure = ParameterStructure.NONE,
extra_params: dict = {},
):
"""When approach is INLINE then we use %(param)s paramstyle and a connection with use_inline_params=True
When approach is NATIVE then we use :param paramstyle and a connection with use_inline_params=False
Expand All @@ -228,12 +240,16 @@ def _get_one_result(
params,
paramstyle=ParamStyle.PYFORMAT,
target_column=self._get_inline_table_column(params.get("p")),
extra_params=extra_params,
)
elif approach == ParameterApproach.NATIVE:
# native mode can use either ParamStyle.NAMED or ParamStyle.PYFORMAT
# native mode can use either ParameterStructure.NAMED or ParameterStructure.POSITIONAL
return self._native_roundtrip(
params, paramstyle=paramstyle, parameter_structure=parameter_structure
params,
paramstyle=paramstyle,
parameter_structure=parameter_structure,
extra_params=extra_params,
)

def _quantize(self, input: Union[float, int], place_value=2) -> Decimal:
Expand Down Expand Up @@ -379,7 +395,20 @@ def test_dbsqlparameter_single(
assert self._eq(result.col, primitive)

@pytest.mark.parametrize("use_inline_params", (True, False, "silent"))
def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog):
@pytest.mark.parametrize(
"extra_params",
[
{},
{
"use_sea": True,
"use_cloud_fetch": False,
"enable_query_result_lz4_compression": False,
},
],
)
def test_use_inline_off_by_default_with_warning(
self, use_inline_params, caplog, extra_params
):
"""
use_inline_params should be False by default.
If a user explicitly sets use_inline_params, don't warn them about it.
Expand All @@ -389,7 +418,7 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog)
{"use_inline_params": use_inline_params} if use_inline_params else {}
)

with self.connection(extra_params=extra_args) as conn:
with self.connection(extra_params={**extra_args, **extra_params}) as conn:
with conn.cursor() as cursor:
with self.patch_server_supports_native_params(
supports_native_params=True
Expand All @@ -404,9 +433,20 @@ def test_use_inline_off_by_default_with_warning(self, use_inline_params, caplog)
"Consider using native parameters." not in caplog.text
), "Log message should not be supressed"

def test_positional_native_params_with_defaults(self):
@pytest.mark.parametrize(
"extra_params",
[
{},
{
"use_sea": True,
"use_cloud_fetch": False,
"enable_query_result_lz4_compression": False,
},
],
)
def test_positional_native_params_with_defaults(self, extra_params):
query = "SELECT ? col"
with self.cursor() as cursor:
with self.cursor(extra_params=extra_params) as cursor:
result = cursor.execute(query, parameters=[1]).fetchone()

assert result.col == 1
Expand All @@ -422,19 +462,43 @@ def test_positional_native_params_with_defaults(self):
["foo", "bar", "baz"],
),
)
def test_positional_native_multiple(self, params):
@pytest.mark.parametrize(
"extra_params",
[
{},
{
"use_sea": True,
"use_cloud_fetch": False,
"enable_query_result_lz4_compression": False,
},
],
)
def test_positional_native_multiple(self, params, extra_params):
query = "SELECT ? `foo`, ? `bar`, ? `baz`"

with self.cursor(extra_params={"use_inline_params": False}) as cursor:
with self.cursor(
extra_params={"use_inline_params": False, **extra_params}
) as cursor:
result = cursor.execute(query, params).fetchone()

expected = [i.value if isinstance(i, DbsqlParameterBase) else i for i in params]
outcome = [result.foo, result.bar, result.baz]

assert set(outcome) == set(expected)

def test_readme_example(self):
with self.cursor() as cursor:
@pytest.mark.parametrize(
"extra_params",
[
{},
{
"use_sea": True,
"use_cloud_fetch": False,
"enable_query_result_lz4_compression": False,
},
],
)
def test_readme_example(self, extra_params):
with self.cursor(extra_params=extra_params) as cursor:
result = cursor.execute(
"SELECT :param `p`, * FROM RANGE(10)", {"param": "foo"}
).fetchall()
Expand Down Expand Up @@ -498,19 +562,43 @@ def test_native_recursive_complex_type(
class TestInlineParameterSyntax(PySQLPytestTestCase):
"""The inline parameter approach uses pyformat markers"""

def test_params_as_dict(self):
@pytest.mark.parametrize(
"extra_params",
[
{},
{
"use_sea": True,
"use_cloud_fetch": False,
"enable_query_result_lz4_compression": False,
},
],
)
def test_params_as_dict(self, extra_params):
query = "SELECT %(foo)s foo, %(bar)s bar, %(baz)s baz"
params = {"foo": 1, "bar": 2, "baz": 3}

with self.connection(extra_params={"use_inline_params": True}) as conn:
with self.connection(
extra_params={"use_inline_params": True, **extra_params}
) as conn:
with conn.cursor() as cursor:
result = cursor.execute(query, parameters=params).fetchone()

assert result.foo == 1
assert result.bar == 2
assert result.baz == 3

def test_params_as_sequence(self):
@pytest.mark.parametrize(
"extra_params",
[
{},
{
"use_sea": True,
"use_cloud_fetch": False,
"enable_query_result_lz4_compression": False,
},
],
)
def test_params_as_sequence(self, extra_params):
"""One side-effect of ParamEscaper using Python string interpolation to inline the values
is that it can work with "ordinal" parameters, but only if a user writes parameter markers
that are not defined with PEP-249. This test exists to prove that it works in the ideal case.
Expand All @@ -520,27 +608,53 @@ def test_params_as_sequence(self):
query = "SELECT %s foo, %s bar, %s baz"
params = (1, 2, 3)

with self.connection(extra_params={"use_inline_params": True}) as conn:
with self.connection(
extra_params={"use_inline_params": True, **extra_params}
) as conn:
with conn.cursor() as cursor:
result = cursor.execute(query, parameters=params).fetchone()
assert result.foo == 1
assert result.bar == 2
assert result.baz == 3

def test_inline_ordinals_can_break_sql(self):
@pytest.mark.parametrize(
"extra_params",
[
{},
{
"use_sea": True,
"use_cloud_fetch": False,
"enable_query_result_lz4_compression": False,
},
],
)
def test_inline_ordinals_can_break_sql(self, extra_params):
"""With inline mode, ordinal parameters can break the SQL syntax
because `%` symbols are used to wildcard match within LIKE statements. This test
just proves that's the case.
"""
query = "SELECT 'samsonite', %s WHERE 'samsonite' LIKE '%sonite'"
params = ["luggage"]
with self.cursor(extra_params={"use_inline_params": True}) as cursor:
with self.cursor(
extra_params={"use_inline_params": True, **extra_params}
) as cursor:
with pytest.raises(
TypeError, match="not enough arguments for format string"
):
cursor.execute(query, parameters=params)

def test_inline_named_dont_break_sql(self):
@pytest.mark.parametrize(
"extra_params",
[
{},
{
"use_sea": True,
"use_cloud_fetch": False,
"enable_query_result_lz4_compression": False,
},
],
)
def test_inline_named_dont_break_sql(self, extra_params):
"""With inline mode, ordinal parameters can break the SQL syntax
because `%` symbols are used to wildcard match within LIKE statements. This test
just proves that's the case.
Expand All @@ -550,39 +664,80 @@ def test_inline_named_dont_break_sql(self):
SELECT col_1 FROM base WHERE col_1 LIKE CONCAT(%(one)s, 'onite')
"""
params = {"one": "%(one)s"}
with self.cursor(extra_params={"use_inline_params": True}) as cursor:
with self.cursor(
extra_params={"use_inline_params": True, **extra_params}
) as cursor:
result = cursor.execute(query, parameters=params).fetchone()
print("hello")

def test_native_ordinals_dont_break_sql(self):
@pytest.mark.parametrize(
"extra_params",
[
{},
{
"use_sea": True,
"use_cloud_fetch": False,
"enable_query_result_lz4_compression": False,
},
],
)
def test_native_ordinals_dont_break_sql(self, extra_params):
"""This test accompanies test_inline_ordinals_can_break_sql to prove that ordinal
parameters work in native mode for the exact same query, if we use the right marker `?`
"""
query = "SELECT 'samsonite', ? WHERE 'samsonite' LIKE '%sonite'"
params = ["luggage"]
with self.cursor(extra_params={"use_inline_params": False}) as cursor:
with self.cursor(
extra_params={"use_inline_params": False, **extra_params}
) as cursor:
result = cursor.execute(query, parameters=params).fetchone()

assert result.samsonite == "samsonite"
assert result.luggage == "luggage"

def test_inline_like_wildcard_breaks(self):
@pytest.mark.parametrize(
"extra_params",
[
{},
{
"use_sea": True,
"use_cloud_fetch": False,
"enable_query_result_lz4_compression": False,
},
],
)
def test_inline_like_wildcard_breaks(self, extra_params):
"""One flaw with the ParameterEscaper is that it fails if a query contains
a SQL LIKE wildcard %. This test proves that's the case.
"""
query = "SELECT 1 `col` WHERE 'foo' LIKE '%'"
params = {"param": "bar"}
with self.cursor(extra_params={"use_inline_params": True}) as cursor:
with self.cursor(
extra_params={"use_inline_params": True, **extra_params}
) as cursor:
with pytest.raises(ValueError, match="unsupported format character"):
result = cursor.execute(query, parameters=params).fetchone()

def test_native_like_wildcard_works(self):
@pytest.mark.parametrize(
"extra_params",
[
{},
{
"use_sea": True,
"use_cloud_fetch": False,
"enable_query_result_lz4_compression": False,
},
],
)
def test_native_like_wildcard_works(self, extra_params):
"""This is a mirror of test_inline_like_wildcard_breaks that proves that LIKE
wildcards work under the native approach.
"""
query = "SELECT 1 `col` WHERE 'foo' LIKE '%'"
params = {"param": "bar"}
with self.cursor(extra_params={"use_inline_params": False}) as cursor:
with self.cursor(
extra_params={"use_inline_params": False, **extra_params}
) as cursor:
result = cursor.execute(query, parameters=params).fetchone()

assert result.col == 1
Loading
Loading