Skip to content

Commit

Permalink
Merge pull request #46 from Intreecom/feature/decimals
Browse files Browse the repository at this point in the history
  • Loading branch information
s3rius authored Mar 30, 2024
2 parents 4b56c56 + 86149b8 commit ef36c04
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 7 deletions.
2 changes: 0 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ pyo3 = { version = "0.20.0", features = [
"abi3-py38",
"extension-module",
"chrono",
"rust_decimal",
] }
rust_decimal = "1.0"
pyo3-asyncio = { version = "0.20.0", features = ["tokio-runtime"] }
pyo3-log = "0.9.0"
rustc-hash = "1.1.0"
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,4 @@ convention = "pep257"
ignore-decorators = ["typing.overload"]

[tool.ruff.pylint]
allow-magic-value-types = ["int", "str", "float", "tuple"]
allow-magic-value-types = ["int", "str", "float"]
3 changes: 3 additions & 0 deletions python/tests/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ipaddress
import random
import uuid
from decimal import Decimal
from typing import Any, Callable

import pytest
Expand Down Expand Up @@ -30,6 +31,8 @@
("UUID", uuid.uuid5(uuid.uuid4(), "name")),
("INET", ipaddress.ip_address("192.168.1.1")),
("INET", ipaddress.ip_address("2001:db8::8a2e:370:7334")),
("DECIMAL", Decimal("1.1")),
("DECIMAL", Decimal("1.112e10")),
],
)
async def test_bindings(
Expand Down
34 changes: 30 additions & 4 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ pub enum ScyllaPyCQLDTO {
Counter(i64),
Bool(bool),
Double(eq_float::F64),
Decimal(bigdecimal_04::BigDecimal),
Float(eq_float::F32),
Bytes(Vec<u8>),
Date(chrono::NaiveDate),
Expand Down Expand Up @@ -131,11 +132,12 @@ impl Value for ScyllaPyCQLDTO {
ScyllaPyCQLDTO::Timestamp(timestamp) => {
scylla::frame::value::CqlTimestamp::from(*timestamp).serialize(buf)
}
ScyllaPyCQLDTO::Null => Option::<i16>::None.serialize(buf),
ScyllaPyCQLDTO::Null => Option::<bool>::None.serialize(buf),
ScyllaPyCQLDTO::Udt(udt) => {
buf.extend(udt);
Ok(())
}
ScyllaPyCQLDTO::Decimal(decimal) => decimal.serialize(buf),
ScyllaPyCQLDTO::Unset => scylla::frame::value::Unset.serialize(buf),
}
}
Expand Down Expand Up @@ -247,6 +249,12 @@ pub fn py_to_value(
Ok(ScyllaPyCQLDTO::Time(chrono::NaiveTime::from_str(
item.call_method0("isoformat")?.extract::<&str>()?,
)?))
} else if item.get_type().name()? == "Decimal" {
Ok(ScyllaPyCQLDTO::Decimal(
bigdecimal_04::BigDecimal::from_str(item.str()?.to_str()?).map_err(|err| {
ScyllaPyError::BindingError(format!("Cannot parse decimal {err}"))
})?,
))
} else if item.get_type().name()? == "datetime" {
let milliseconds = item.call_method0("timestamp")?.extract::<f64>()? * 1000f64;
#[allow(clippy::cast_possible_truncation)]
Expand Down Expand Up @@ -576,9 +584,27 @@ pub fn cql_to_py<'a>(
}
Ok(res_map)
}
ColumnType::Custom(_) | ColumnType::Varint | ColumnType::Decimal => Err(
ScyllaPyError::ValueDowncastError(col_name.into(), "Unknown"),
),
ColumnType::Decimal => {
// Because the `as_decimal` method is not implemented for `CqlValue`,
// will make a PR.
let decimal: bigdecimal_04::BigDecimal = match unwrapped_value {
CqlValue::Decimal(inner) => inner.clone().into(),
_ => {
return Err(ScyllaPyError::ValueDowncastError(
col_name.into(),
"Decimal",
))
}
};
Ok(py
.import("decimal")?
.getattr("Decimal")?
.call1((decimal.to_scientific_notation(),))?)
}
ColumnType::Custom(_) | ColumnType::Varint => Err(ScyllaPyError::ValueDowncastError(
col_name.into(),
"Unknown",
)),
}
}

Expand Down

0 comments on commit ef36c04

Please sign in to comment.