diff --git a/pyproject.toml b/pyproject.toml index 66a0730b..6289347d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,7 +110,7 @@ optional-dependencies.release = [ optional-dependencies.test = [ "cratedb-toolkit[testing]", "dask[dataframe]", - "pandas<2.3", + "pandas[test]<2.3", "pueblo>=0.0.7", "pytest<9", "pytest-cov<7", diff --git a/src/sqlalchemy_cratedb/compiler.py b/src/sqlalchemy_cratedb/compiler.py index 7b2c5ccd..6ee43a23 100644 --- a/src/sqlalchemy_cratedb/compiler.py +++ b/src/sqlalchemy_cratedb/compiler.py @@ -254,6 +254,30 @@ def visit_TIMESTAMP(self, type_, **kw): """ return "TIMESTAMP %s" % ((type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE",) + def visit_BLOB(self, type_, **kw): + return "STRING" + + def visit_FLOAT(self, type_, **kw): + """ + From `sqlalchemy.sql.sqltypes.Float`. + + When a :paramref:`.Float.precision` is not provided in a + :class:`_types.Float` type some backend may compile this type as + an 8 bytes / 64 bit float datatype. To use a 4 bytes / 32 bit float + datatype a precision <= 24 can usually be provided or the + :class:`_types.REAL` type can be used. + This is known to be the case in the PostgreSQL and MSSQL dialects + that render the type as ``FLOAT`` that's in both an alias of + ``DOUBLE PRECISION``. Other third party dialects may have similar + behavior. + """ + if not type_.precision: + return "FLOAT" + elif type_.precision <= 24: + return "FLOAT" + else: + return "DOUBLE" + class CrateCompiler(compiler.SQLCompiler): def visit_getitem_binary(self, binary, operator, **kw): diff --git a/src/sqlalchemy_cratedb/dialect.py b/src/sqlalchemy_cratedb/dialect.py index 90102a78..4302f7eb 100644 --- a/src/sqlalchemy_cratedb/dialect.py +++ b/src/sqlalchemy_cratedb/dialect.py @@ -34,6 +34,7 @@ ) from .sa_version import SA_1_4, SA_2_0, SA_VERSION from .type import FloatVector, ObjectArray, ObjectType +from .type.binary import LargeBinary TYPES_MAP = { "boolean": sqltypes.Boolean, @@ -158,6 +159,7 @@ def process(value): sqltypes.Date: Date, sqltypes.DateTime: DateTime, sqltypes.TIMESTAMP: DateTime, + sqltypes.LargeBinary: LargeBinary, } diff --git a/src/sqlalchemy_cratedb/type/__init__.py b/src/sqlalchemy_cratedb/type/__init__.py index b524bb39..6d92e0e2 100644 --- a/src/sqlalchemy_cratedb/type/__init__.py +++ b/src/sqlalchemy_cratedb/type/__init__.py @@ -1,4 +1,5 @@ from .array import ObjectArray +from .binary import LargeBinary from .geo import Geopoint, Geoshape from .object import ObjectType from .vector import FloatVector, knn_match @@ -6,6 +7,7 @@ __all__ = [ Geopoint, Geoshape, + LargeBinary, ObjectArray, ObjectType, FloatVector, diff --git a/src/sqlalchemy_cratedb/type/binary.py b/src/sqlalchemy_cratedb/type/binary.py new file mode 100644 index 00000000..04b04073 --- /dev/null +++ b/src/sqlalchemy_cratedb/type/binary.py @@ -0,0 +1,44 @@ +import base64 + +import sqlalchemy as sa + + +class LargeBinary(sa.String): + """A type for large binary byte data. + + The :class:`.LargeBinary` type corresponds to a large and/or unlengthed + binary type for the target platform, such as BLOB on MySQL and BYTEA for + PostgreSQL. It also handles the necessary conversions for the DBAPI. + + """ + + __visit_name__ = "large_binary" + + def bind_processor(self, dialect): + if dialect.dbapi is None: + return None + + # TODO: DBAPIBinary = dialect.dbapi.Binary + + def process(value): + if value is not None: + # TODO: return DBAPIBinary(value) + return base64.b64encode(value).decode() + else: + return None + + return process + + # Python 3 has native bytes() type + # both sqlite3 and pg8000 seem to return it, + # psycopg2 as of 2.5 returns 'memoryview' + def result_processor(self, dialect, coltype): + if dialect.returns_native_bytes: + return None + + def process(value): + if value is not None: + return base64.b64decode(value) + return value + + return process diff --git a/tests/test_support_pandas.py b/tests/test_support_pandas.py index 47fe9c7a..ce8bb6e2 100644 --- a/tests/test_support_pandas.py +++ b/tests/test_support_pandas.py @@ -1,7 +1,9 @@ import re import sys +import pandas as pd import pytest +from pandas._testing import assert_equal from pueblo.testing.pandas import makeTimeDataFrame from sqlalchemy.exc import ProgrammingError @@ -15,6 +17,18 @@ df = makeTimeDataFrame(nper=INSERT_RECORDS, freq="S") df["time"] = df.index +float_double_data = { + "col_1": [19556.88, 629414.27, 51570.0, 2933.52, 20338.98], + "col_2": [ + 15379.920000000002, + 1107140.42, + 8081.999999999999, + 1570.0300000000002, + 29468.539999999997, + ], +} +float_double_df = pd.DataFrame.from_dict(float_double_data) + @pytest.mark.skipif( sys.version_info < (3, 8), reason="Feature not supported on Python 3.7 and earlier" @@ -113,3 +127,33 @@ def test_table_kwargs_unknown(cratedb_service): "passed to [ALTER | CREATE] TABLE statement]" ) ) + + +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="Feature not supported on Python 3.7 and earlier" +) +@pytest.mark.skipif( + SA_VERSION < SA_2_0, reason="Feature not supported on SQLAlchemy 1.4 and earlier" +) +def test_float_double(cratedb_service): + """ + Validate I/O with floating point numbers, specifically DOUBLE types. + + Motto: Do not lose precision when DOUBLE is required. + """ + tablename = "pandas_double" + engine = cratedb_service.database.engine + float_double_df.to_sql( + tablename, + engine, + if_exists="replace", + index=False, + ) + cratedb_service.database.run_sql(f"REFRESH TABLE {tablename}") + df_load = pd.read_sql_table(tablename, engine) + + before = float_double_df.sort_values(by="col_1", ignore_index=True) + after = df_load.sort_values(by="col_1", ignore_index=True) + + pd.options.display.float_format = "{:.12f}".format + assert_equal(before, after, check_exact=True)