diff --git a/src/aiodynamo/errors.py b/src/aiodynamo/errors.py index fb6896a..3bf81df 100644 --- a/src/aiodynamo/errors.py +++ b/src/aiodynamo/errors.py @@ -1,5 +1,10 @@ import json -from typing import Any, Dict +from typing import Any, Dict, List, Optional, TypedDict + + +class CancellationReason(TypedDict): + Code: Optional[str] + Message: Optional[str] class AIODynamoError(Exception): @@ -110,7 +115,12 @@ class PointInTimeRecoveryUnavailable(AIODynamoError): class TransactionCanceled(AIODynamoError): - pass + def __init__(self, body: Dict[str, Any]): + self.body = body + self.cancellation_reasons: List[CancellationReason] = self.body[ + "CancellationReasons" + ] + super().__init__(body) class TransactionEmpty(AIODynamoError): @@ -207,17 +217,7 @@ def exception_from_response(status: int, body: bytes) -> Exception: return ServiceUnavailable() try: data = json.loads(body) - error = ERRORS[data["__type"].split("#", 1)[-1]](data) - if isinstance(error, TransactionCanceled): - error = extract_error_from_transaction_canceled(data) + error: Exception = ERRORS[data["__type"].split("#", 1)[-1]](data) return error except Exception: return UnknownError(status, body) - - -def extract_error_from_transaction_canceled(data: Dict[str, Any]) -> AIODynamoError: - try: - error = data["CancellationReasons"][0] - return ERRORS[f"{error['Code']}Exception"](error["Message"]) - except Exception: - return ERRORS[data["__type"].split("#", 1)[-1]](data) diff --git a/tests/integration/test_client.py b/tests/integration/test_client.py index 15fa250..43cf070 100644 --- a/tests/integration/test_client.py +++ b/tests/integration/test_client.py @@ -624,7 +624,7 @@ async def test_transact_write_items_put(client: Client, table: TableName) -> Non await client.transact_write_items(items=puts) assert len([item async for item in client.query(table, HashKey("h", "h"))]) == 2 - with pytest.raises(errors.ConditionalCheckFailed): + with pytest.raises(errors.TransactionCanceled) as excinfo: put = Put( table=table, item={"h": "h", "r": "0", "s": "initial"}, @@ -632,6 +632,26 @@ async def test_transact_write_items_put(client: Client, table: TableName) -> Non ) await client.transact_write_items(items=[put]) + assert len(excinfo.value.cancellation_reasons) == 1 + assert excinfo.value.cancellation_reasons[0]["Code"] == "ConditionalCheckFailed" + + with pytest.raises(errors.TransactionCanceled) as excinfo: + put = Put( + table=table, + item={"h": "h", "r": "3", "s": "initial"}, + condition=F("h").does_not_exist(), + ) + put_fail = Put( + table=table, + item={"h": "h", "r": "0", "s": "initial"}, + condition=F("h").does_not_exist(), + ) + await client.transact_write_items(items=[put, put_fail]) + + assert len(excinfo.value.cancellation_reasons) == 2 + assert excinfo.value.cancellation_reasons[0]["Code"] == "None" + assert excinfo.value.cancellation_reasons[1]["Code"] == "ConditionalCheckFailed" + @pytest.mark.usefixtures("supports_transactions") async def test_transact_write_items_update(client: Client, table: TableName) -> None: @@ -647,7 +667,7 @@ async def test_transact_write_items_update(client: Client, table: TableName) -> query = await client.query_single_page(table, HashKey("h", "h")) assert query.items[0]["s"] == "changed" - with pytest.raises(errors.ConditionalCheckFailed): + with pytest.raises(errors.TransactionCanceled) as excinfo: update = Update( table=table, key={"h": "h", "r": "1"}, @@ -656,6 +676,9 @@ async def test_transact_write_items_update(client: Client, table: TableName) -> ) await client.transact_write_items(items=[update]) + assert len(excinfo.value.cancellation_reasons) == 1 + assert excinfo.value.cancellation_reasons[0]["Code"] == "ConditionalCheckFailed" + @pytest.mark.usefixtures("supports_transactions") async def test_transact_write_items_delete(client: Client, table: TableName) -> None: @@ -670,7 +693,8 @@ async def test_transact_write_items_delete(client: Client, table: TableName) -> assert len([item async for item in client.query(table, HashKey("h", "h"))]) == 0 await client.put_item(table=table, item={"h": "h", "r": "1", "s": "initial"}) - with pytest.raises(errors.ConditionalCheckFailed): + + with pytest.raises(errors.TransactionCanceled) as excinfo: delete = Delete( table=table, key={"h": "h", "r": "1"}, @@ -678,6 +702,9 @@ async def test_transact_write_items_delete(client: Client, table: TableName) -> ) await client.transact_write_items(items=[delete]) + assert len(excinfo.value.cancellation_reasons) == 1 + assert excinfo.value.cancellation_reasons[0]["Code"] == "ConditionalCheckFailed" + @pytest.mark.usefixtures("supports_transactions") async def test_transact_write_items_condition_check( @@ -687,9 +714,12 @@ async def test_transact_write_items_condition_check( condition = ConditionCheck( table=table, key={"h": "h", "r": "1"}, condition=F("s").not_equals("initial") ) - with pytest.raises(errors.ConditionalCheckFailed): + with pytest.raises(errors.TransactionCanceled) as excinfo: await client.transact_write_items(items=[condition]) + assert len(excinfo.value.cancellation_reasons) == 1 + assert excinfo.value.cancellation_reasons[0]["Code"] == "ConditionalCheckFailed" + condition = ConditionCheck( table=table, key={"h": "h", "r": "1"}, condition=F("s").equals("initial") )