From 11aa5993ad45f2fd3dae38ded62b90c6f065f0f0 Mon Sep 17 00:00:00 2001 From: Andrew Bloom Date: Mon, 22 Apr 2024 09:07:01 -0500 Subject: [PATCH 1/2] allow pinot table aliases --- pinotdb/sqlalchemy.py | 3 --- tests/unit/test_sqlalchemy.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/pinotdb/sqlalchemy.py b/pinotdb/sqlalchemy.py index d964d97..6047a4b 100644 --- a/pinotdb/sqlalchemy.py +++ b/pinotdb/sqlalchemy.py @@ -21,9 +21,6 @@ def visit_select(self, select, **kwargs): return super().visit_select(select, **kwargs) def visit_column(self, column, result_map=None, **kwargs): - # Pinot does not support table aliases - if column.table is not None: - column.table.named_with_column = False result_map = result_map or kwargs.pop("add_to_result_map", None) # This is a hack to modify the original column, but how do I clone it ? column.is_literal = True diff --git a/tests/unit/test_sqlalchemy.py b/tests/unit/test_sqlalchemy.py index 31b7c4f..d5b4ce7 100644 --- a/tests/unit/test_sqlalchemy.py +++ b/tests/unit/test_sqlalchemy.py @@ -344,7 +344,7 @@ def test_can_select_table_directly(self): self.assertEqual( str(compiler), - 'SELECT some_column \nFROM some_table', + 'SELECT some_table.some_column \nFROM some_table', ) From f81e5c42eab990905784f91cd4af76e2904622f2 Mon Sep 17 00:00:00 2001 From: Andrew Bloom Date: Mon, 22 Apr 2024 15:57:50 -0500 Subject: [PATCH 2/2] Fix types --- pinotdb/sqlalchemy.py | 4 ++-- tests/unit/test_sqlalchemy.py | 26 +++++++++++++------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/pinotdb/sqlalchemy.py b/pinotdb/sqlalchemy.py index 6047a4b..b5486c1 100644 --- a/pinotdb/sqlalchemy.py +++ b/pinotdb/sqlalchemy.py @@ -58,7 +58,7 @@ def visit_REAL(self, type_, **kwargs): return "DOUBLE" def visit_NUMERIC(self, type_, **kwargs): - return "LONG" + return "NUMERIC" visit_DECIMAL = visit_NUMERIC visit_INTEGER = visit_NUMERIC @@ -69,7 +69,7 @@ def visit_NUMERIC(self, type_, **kwargs): visit_DATE = visit_NUMERIC def visit_CHAR(self, type_, **kwargs): - return "STRING" + return "VARCHAR" visit_NCHAR = visit_CHAR visit_VARCHAR = visit_CHAR diff --git a/tests/unit/test_sqlalchemy.py b/tests/unit/test_sqlalchemy.py index d5b4ce7..1393e7e 100644 --- a/tests/unit/test_sqlalchemy.py +++ b/tests/unit/test_sqlalchemy.py @@ -357,44 +357,44 @@ def test_compiles_real(self): self.assertEqual(self.compiler.visit_REAL(None), 'DOUBLE') def test_compiles_numeric(self): - self.assertEqual(self.compiler.visit_NUMERIC(None), 'LONG') + self.assertEqual(self.compiler.visit_NUMERIC(None), 'NUMERIC') def test_compiles_decimal(self): - self.assertEqual(self.compiler.visit_DECIMAL(None), 'LONG') + self.assertEqual(self.compiler.visit_DECIMAL(None), 'NUMERIC') def test_compiles_integer(self): - self.assertEqual(self.compiler.visit_INTEGER(None), 'LONG') + self.assertEqual(self.compiler.visit_INTEGER(None), 'NUMERIC') def test_compiles_smallint(self): - self.assertEqual(self.compiler.visit_SMALLINT(None), 'LONG') + self.assertEqual(self.compiler.visit_SMALLINT(None), 'NUMERIC') def test_compiles_bigint(self): - self.assertEqual(self.compiler.visit_BIGINT(None), 'LONG') + self.assertEqual(self.compiler.visit_BIGINT(None), 'NUMERIC') # TODO: Check if this is correct (seems strange to have boolean as long). def test_compiles_boolean(self): - self.assertEqual(self.compiler.visit_BOOLEAN(None), 'LONG') + self.assertEqual(self.compiler.visit_BOOLEAN(None), 'NUMERIC') def test_compiles_timestamp(self): - self.assertEqual(self.compiler.visit_TIMESTAMP(None), 'LONG') + self.assertEqual(self.compiler.visit_TIMESTAMP(None), 'NUMERIC') def test_compiles_date(self): - self.assertEqual(self.compiler.visit_DATE(None), 'LONG') + self.assertEqual(self.compiler.visit_DATE(None), 'NUMERIC') def test_compiles_char(self): - self.assertEqual(self.compiler.visit_CHAR(None), 'STRING') + self.assertEqual(self.compiler.visit_CHAR(None), 'VARCHAR') def test_compiles_nchar(self): - self.assertEqual(self.compiler.visit_NCHAR(None), 'STRING') + self.assertEqual(self.compiler.visit_NCHAR(None), 'VARCHAR') def test_compiles_varchar(self): - self.assertEqual(self.compiler.visit_VARCHAR(None), 'STRING') + self.assertEqual(self.compiler.visit_VARCHAR(None), 'VARCHAR') def test_compiles_nvarchar(self): - self.assertEqual(self.compiler.visit_NVARCHAR(None), 'STRING') + self.assertEqual(self.compiler.visit_NVARCHAR(None), 'VARCHAR') def test_compiles_text(self): - self.assertEqual(self.compiler.visit_TEXT(None), 'STRING') + self.assertEqual(self.compiler.visit_TEXT(None), 'VARCHAR') def test_compiles_binary(self): self.assertEqual(self.compiler.visit_BINARY(None), 'BYTES')