Skip to content

Commit

Permalink
Merge pull request #173 from HENNGE/flattened-conditions
Browse files Browse the repository at this point in the history
Flatten AND and OR conditions
  • Loading branch information
ojii authored Jan 19, 2024
2 parents 4f63cb1 + 371a1dc commit b628c31
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 8 deletions.
57 changes: 49 additions & 8 deletions src/aiodynamo/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Expand Down Expand Up @@ -377,10 +379,24 @@ def encode(self, params: Parameters) -> str:

class Condition(metaclass=abc.ABCMeta):
def __and__(self, other: Condition) -> Condition:
return AndCondition(self, other)
if isinstance(self, AndCondition):
if isinstance(other, AndCondition):
return AndCondition(self.children.extending(other.children))
else:
return AndCondition(self.children.appending(other))
elif isinstance(other, AndCondition):
return AndCondition(other.children.prepending(self))
return AndCondition(SubConditions.create(self, other))

def __or__(self, other: Condition) -> Condition:
return OrCondition(self, other)
if isinstance(self, OrCondition):
if isinstance(other, OrCondition):
return OrCondition(self.children.extending(other.children))
else:
return OrCondition(self.children.appending(other))
elif isinstance(other, OrCondition):
return OrCondition(other.children.prepending(self))
return OrCondition(SubConditions.create(self, other))

def __invert__(self) -> Condition:
return NotCondition(self)
Expand Down Expand Up @@ -408,20 +424,18 @@ def encode(self, params: Parameters) -> str:

@dataclass(frozen=True)
class AndCondition(Condition):
lhs: Condition
rhs: Condition
children: SubConditions

def encode(self, params: Parameters) -> str:
return f"({self.lhs.encode(params)} AND {self.rhs.encode(params)})"
return "(" + " AND ".join(child.encode(params) for child in self.children) + ")"


@dataclass(frozen=True)
class OrCondition(Condition):
lhs: Condition
rhs: Condition
children: SubConditions

def encode(self, params: Parameters) -> str:
return f"({self.lhs.encode(params)} OR {self.rhs.encode(params)})"
return "(" + " OR ".join(child.encode(params) for child in self.children) + ")"


@dataclass(frozen=True)
Expand Down Expand Up @@ -659,3 +673,30 @@ def __and__(self, field: F) -> ProjectionExpression:

def encode(self, params: Parameters) -> str:
return ",".join(params.encode_path(field.path) for field in self.fields)


@dataclass(frozen=True)
class SubConditions:
first: Condition
second: Condition
rest: tuple[Condition, ...]

@classmethod
def create(
cls, first: Condition, second: Condition, *rest: Condition
) -> SubConditions:
return cls(first, second, rest)

def prepending(self, value: Condition) -> SubConditions:
return SubConditions(value, self.first, (self.second, *self.rest))

def appending(self, value: Condition) -> SubConditions:
return SubConditions(self.first, self.second, (*self.rest, value))

def extending(self, values: Iterable[Condition]) -> SubConditions:
return SubConditions(self.first, self.second, (*self.rest, *values))

def __iter__(self) -> Generator[Condition, None, None]:
yield self.first
yield self.second
yield from self.rest
115 changes: 115 additions & 0 deletions tests/unit/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
import pytest

from aiodynamo.expressions import (
AndCondition,
Comparison,
Condition,
F,
HashKey,
OrCondition,
Parameters,
ProjectionExpression,
SubConditions,
UpdateExpression,
)

Expand Down Expand Up @@ -115,6 +119,10 @@ def test_f_repr(f: F, r: str) -> None:
[
(F("a").equals(True) & F("b").gt(1), "(a = True AND b > 1)"),
(F("a", 1).begins_with("foo"), "begins_with(a[1], 'foo')"),
(
F("a").equals("a") & F("b").equals("b") & F("c").equals("c"),
"(a = 'a' AND b = 'b' AND c = 'c')",
),
],
)
def test_condition_debug(expr: Condition, expected: str) -> None:
Expand All @@ -133,3 +141,110 @@ def test_condition_debug(expr: Condition, expected: str) -> None:
)
def test_update_expression_debug(expr: UpdateExpression, expected: str) -> None:
assert expr.debug(int) == expected


@pytest.mark.parametrize(
"expr,expected",
[
(
F("a").equals("a") & F("b").equals("b"),
AndCondition(
SubConditions.create(
Comparison(F("a"), "=", "a"), Comparison(F("b"), "=", "b")
)
),
),
(
(F("a").equals("a") & F("b").equals("b")) & F("c").equals("c"),
AndCondition(
SubConditions.create(
Comparison(F("a"), "=", "a"),
Comparison(F("b"), "=", "b"),
Comparison(F("c"), "=", "c"),
)
),
),
(
F("a").equals("a") & (F("b").equals("b") & F("c").equals("c")),
AndCondition(
SubConditions.create(
Comparison(F("a"), "=", "a"),
Comparison(F("b"), "=", "b"),
Comparison(F("c"), "=", "c"),
)
),
),
(
(F("a").equals("a") & F("b").equals("b"))
& (F("c").equals("c") & F("d").equals("d")),
AndCondition(
SubConditions.create(
Comparison(F("a"), "=", "a"),
Comparison(F("b"), "=", "b"),
Comparison(F("c"), "=", "c"),
Comparison(F("d"), "=", "d"),
)
),
),
(
F("a").equals("a") | F("b").equals("b"),
OrCondition(
SubConditions.create(
Comparison(F("a"), "=", "a"), Comparison(F("b"), "=", "b")
)
),
),
(
(F("a").equals("a") | F("b").equals("b")) | F("c").equals("c"),
OrCondition(
SubConditions.create(
Comparison(F("a"), "=", "a"),
Comparison(F("b"), "=", "b"),
Comparison(F("c"), "=", "c"),
)
),
),
(
F("a").equals("a") | (F("b").equals("b") | F("c").equals("c")),
OrCondition(
SubConditions.create(
Comparison(F("a"), "=", "a"),
Comparison(F("b"), "=", "b"),
Comparison(F("c"), "=", "c"),
)
),
),
(
(F("a").equals("a") | F("b").equals("b"))
| (F("c").equals("c") | F("d").equals("d")),
OrCondition(
SubConditions.create(
Comparison(F("a"), "=", "a"),
Comparison(F("b"), "=", "b"),
Comparison(F("c"), "=", "c"),
Comparison(F("d"), "=", "d"),
)
),
),
(
(F("a").equals("a") | F("b").equals("b"))
& (F("c").equals("c") | F("d").equals("d")),
AndCondition(
SubConditions.create(
OrCondition(
SubConditions.create(
Comparison(F("a"), "=", "a"), Comparison(F("b"), "=", "b")
)
),
OrCondition(
SubConditions.create(
Comparison(F("c"), "=", "c"), Comparison(F("d"), "=", "d")
)
),
)
),
),
],
)
def test_condition_flattening(expr: Condition, expected: Condition) -> None:
assert expr == expected

0 comments on commit b628c31

Please sign in to comment.