diff --git a/src/aiodynamo/expressions.py b/src/aiodynamo/expressions.py index 305d799..c7e6fc8 100644 --- a/src/aiodynamo/expressions.py +++ b/src/aiodynamo/expressions.py @@ -19,7 +19,7 @@ from .errors import CannotAddToNestedField from .types import AttributeType, Numeric, ParametersDict -from .utils import deparametetrize, low_level_serialize +from .utils import MinLen2AppendOnlyList, deparametetrize, low_level_serialize _ParametersCache = Dict[Tuple[Any, Any], str] @@ -377,10 +377,20 @@ def encode(self, params: Parameters) -> str: class Condition(metaclass=abc.ABCMeta): def __and__(self, other: Condition) -> Condition: - return AndCondition(self, other) + if isinstance(self, AndCondition): + if isinstance(other, AndCondition): + return AndCondition(self.children.extending(other.children)) + else: + return AndCondition(self.children.appending(other)) + return AndCondition(MinLen2AppendOnlyList.create(self, other)) def __or__(self, other: Condition) -> Condition: - return OrCondition(self, other) + if isinstance(self, OrCondition): + if isinstance(other, OrCondition): + return OrCondition(self.children.extending(other.children)) + else: + return OrCondition(self.children.appending(other)) + return OrCondition(MinLen2AppendOnlyList.create(self, other)) def __invert__(self) -> Condition: return NotCondition(self) @@ -408,20 +418,18 @@ def encode(self, params: Parameters) -> str: @dataclass(frozen=True) class AndCondition(Condition): - lhs: Condition - rhs: Condition + children: MinLen2AppendOnlyList[Condition] def encode(self, params: Parameters) -> str: - return f"({self.lhs.encode(params)} AND {self.rhs.encode(params)})" + return "(" + " AND ".join(child.encode(params) for child in self.children) + ")" @dataclass(frozen=True) class OrCondition(Condition): - lhs: Condition - rhs: Condition + children: MinLen2AppendOnlyList[Condition] def encode(self, params: Parameters) -> str: - return f"({self.lhs.encode(params)} OR {self.rhs.encode(params)})" + return "(" + " OR ".join(child.encode(params) for child in self.children) + ")" @dataclass(frozen=True) diff --git a/src/aiodynamo/utils.py b/src/aiodynamo/utils.py index 0258857..f70248e 100644 --- a/src/aiodynamo/utils.py +++ b/src/aiodynamo/utils.py @@ -5,6 +5,7 @@ import decimal import logging from collections import abc as collections_abc +from dataclasses import dataclass from functools import reduce from typing import ( TYPE_CHECKING, @@ -12,8 +13,12 @@ Awaitable, Callable, Dict, + Generator, + Generic, + Iterable, List, Mapping, + Self, Set, Tuple, TypeVar, @@ -197,3 +202,43 @@ def deparametetrize( for key, value in params.names.items(): expression = expression.replace(key, value) return expression + + +@dataclass(frozen=True) +class MinLen2AppendOnlyList(Generic[T]): + first: T + second: T + rest: tuple[T, ...] + + @classmethod + def create(cls, first: T, second: T, *rest: T) -> Self: + return cls(first, second, rest) + + def appending(self, value: T) -> MinLen2AppendOnlyList[T]: + return MinLen2AppendOnlyList(self.first, self.second, (*self.rest, value)) + + def extending(self, values: Iterable[T]) -> MinLen2AppendOnlyList[T]: + return MinLen2AppendOnlyList(self.first, self.second, (*self.rest, *values)) + + def __contains__(self, item: Any) -> bool: + return item == self.first or item == self.second or item in self.rest + + def __getitem__(self, index: int) -> T: + if index == 0: + return self.first + elif index == 1: + return self.second + return self.rest[index - 2] + + def __len__(self) -> int: + return len(self.rest) + 2 + + def __iter__(self) -> Generator[T, None, None]: + yield self.first + yield self.second + yield from self.rest + + def __reversed__(self) -> Generator[T, None, None]: + yield from reversed(self.rest) + yield self.second + yield self.first diff --git a/tests/unit/test_expressions.py b/tests/unit/test_expressions.py index f6a2b39..1662bd5 100644 --- a/tests/unit/test_expressions.py +++ b/tests/unit/test_expressions.py @@ -3,13 +3,17 @@ import pytest from aiodynamo.expressions import ( + AndCondition, + Comparison, Condition, F, HashKey, + OrCondition, Parameters, ProjectionExpression, UpdateExpression, ) +from aiodynamo.utils import MinLen2AppendOnlyList @pytest.mark.parametrize( @@ -115,6 +119,10 @@ def test_f_repr(f: F, r: str) -> None: [ (F("a").equals(True) & F("b").gt(1), "(a = True AND b > 1)"), (F("a", 1).begins_with("foo"), "begins_with(a[1], 'foo')"), + ( + F("a").equals("a") & F("b").equals("b") & F("c").equals("c"), + "(a = 'a' AND b = 'b' AND c = 'c')", + ), ], ) def test_condition_debug(expr: Condition, expected: str) -> None: @@ -133,3 +141,90 @@ def test_condition_debug(expr: Condition, expected: str) -> None: ) def test_update_expression_debug(expr: UpdateExpression, expected: str) -> None: assert expr.debug(int) == expected + + +@pytest.mark.parametrize( + "expr,expected", + [ + ( + F("a").equals("a") & F("b").equals("b"), + AndCondition( + MinLen2AppendOnlyList.create( + Comparison(F("a"), "=", "a"), Comparison(F("b"), "=", "b") + ) + ), + ), + ( + (F("a").equals("a") & F("b").equals("b")) & F("c").equals("c"), + AndCondition( + MinLen2AppendOnlyList.create( + Comparison(F("a"), "=", "a"), + Comparison(F("b"), "=", "b"), + Comparison(F("c"), "=", "c"), + ) + ), + ), + ( + (F("a").equals("a") & F("b").equals("b")) + & (F("c").equals("c") & F("d").equals("d")), + AndCondition( + MinLen2AppendOnlyList.create( + Comparison(F("a"), "=", "a"), + Comparison(F("b"), "=", "b"), + Comparison(F("c"), "=", "c"), + Comparison(F("d"), "=", "d"), + ) + ), + ), + ( + F("a").equals("a") | F("b").equals("b"), + OrCondition( + MinLen2AppendOnlyList.create( + Comparison(F("a"), "=", "a"), Comparison(F("b"), "=", "b") + ) + ), + ), + ( + (F("a").equals("a") | F("b").equals("b")) | F("c").equals("c"), + OrCondition( + MinLen2AppendOnlyList.create( + Comparison(F("a"), "=", "a"), + Comparison(F("b"), "=", "b"), + Comparison(F("c"), "=", "c"), + ) + ), + ), + ( + (F("a").equals("a") | F("b").equals("b")) + | (F("c").equals("c") | F("d").equals("d")), + OrCondition( + MinLen2AppendOnlyList.create( + Comparison(F("a"), "=", "a"), + Comparison(F("b"), "=", "b"), + Comparison(F("c"), "=", "c"), + Comparison(F("d"), "=", "d"), + ) + ), + ), + ( + (F("a").equals("a") | F("b").equals("b")) + & (F("c").equals("c") | F("d").equals("d")), + AndCondition( + MinLen2AppendOnlyList.create( + OrCondition( + MinLen2AppendOnlyList.create( + Comparison(F("a"), "=", "a"), Comparison(F("b"), "=", "b") + ) + ), + OrCondition( + MinLen2AppendOnlyList.create( + Comparison(F("c"), "=", "c"), Comparison(F("d"), "=", "d") + ) + ), + ) + ), + ), + ], +) +def test_condition_flattening(expr: Condition, expected: Condition) -> None: + assert expr == expected