From f5fb9e7cea6ed66fa8890b96eaca01b84bc1567f Mon Sep 17 00:00:00 2001 From: Ishmeet Bindra Date: Thu, 11 Jan 2024 18:49:04 +0530 Subject: [PATCH] Introduced DB Config --- pydanticrud/backends/dynamodb.py | 2 +- pydanticrud/backends/sqlite.py | 6 +++--- pydanticrud/main.py | 6 +++--- tests/test_dynamodb.py | 14 +++++++------- tests/test_model.py | 4 ++-- tests/test_sqlite.py | 2 +- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/pydanticrud/backends/dynamodb.py b/pydanticrud/backends/dynamodb.py index 9ee85dd..01db4d0 100644 --- a/pydanticrud/backends/dynamodb.py +++ b/pydanticrud/backends/dynamodb.py @@ -212,7 +212,7 @@ def __init__(self, cls, result, serialized_items): class Backend: def __init__(self, cls): - cfg = cls.model_config + cfg = cls.db_config self.cls = cls self.schema = cls.schema() self.hash_key = cfg.get("hash_key") diff --git a/pydanticrud/backends/sqlite.py b/pydanticrud/backends/sqlite.py index 3f8e5c0..41e58fb 100644 --- a/pydanticrud/backends/sqlite.py +++ b/pydanticrud/backends/sqlite.py @@ -45,7 +45,7 @@ def get_column_data(field_type): class Backend: def __init__(self, cls): - cfg = cls.model_config + cfg = cls.db_config self.hash_key = cfg.get("hash_key") self.table_name = cls.get_table_name() @@ -162,8 +162,8 @@ def get(self, item_key): def save(self, item, condition: Optional[Rule] = None) -> bool: table_name = item.get_table_name() - hash_key = item.model_config.get("hash_key") - key = item.model_config.get("hash_key") + hash_key = item.db_config.get("hash_key") + key = item.db_config.get("hash_key") fields = tuple(self._columns.keys()) item_data = item.dict() diff --git a/pydanticrud/main.py b/pydanticrud/main.py index a3fa417..c1ad5ce 100644 --- a/pydanticrud/main.py +++ b/pydanticrud/main.py @@ -5,8 +5,8 @@ class CrudMetaClass(ModelMetaclass): def __new__(mcs, name, bases, namespace, **kwargs): cls = super().__new__(mcs, name, bases, namespace, **kwargs) - if hasattr(cls, "model_config") and "backend" in cls.model_config: - cls.__backend__ = cls.model_config["backend"](cls) + if hasattr(cls, "db_config") and "backend" in cls.db_config: + cls.__backend__ = cls.db_config["backend"](cls) return cls @@ -43,7 +43,7 @@ def initialize(cls): @classmethod def get_table_name(cls) -> str: - return cls.model_config.get("title").lower() + return cls.db_config.get("title").lower() @classmethod def exists(cls) -> bool: diff --git a/tests/test_dynamodb.py b/tests/test_dynamodb.py index 9f47fb5..4ad3297 100644 --- a/tests/test_dynamodb.py +++ b/tests/test_dynamodb.py @@ -33,7 +33,7 @@ class SimpleKeyModel(BaseModel): data: Dict[int, int] = {} items: List[int] hash: UUID - model_config = ConfigDict(title="ModelTitle123", hash_key="name", ttl="expires", backend=DynamoDbBackend, endpoint="http://localhost:18002", global_indexes={"by-id": ("id",)}) + db_config = ConfigDict(title="ModelTitle123", hash_key="name", ttl="expires", backend=DynamoDbBackend, endpoint="http://localhost:18002", global_indexes={"by-id": ("id",)}) class AliasKeyModel(BaseModel): @@ -48,7 +48,7 @@ def type_from_typ(cls, values): if 'typ' in values: values['type'] = values.pop('typ') return values - model_config = ConfigDict(title="AliasTitle123", hash_key="name", backend=DynamoDbBackend, endpoint="http://localhost:18002") + db_config = ConfigDict(title="AliasTitle123", hash_key="name", backend=DynamoDbBackend, endpoint="http://localhost:18002") class ComplexKeyModel(BaseModel): @@ -59,7 +59,7 @@ class ComplexKeyModel(BaseModel): notification_id: str thread_id: str body: str = "some random string" - model_config = ConfigDict(title="ComplexModelTitle123", hash_key="account", range_key="sort_date_key", backend=DynamoDbBackend, endpoint="http://localhost:18002", local_indexes={ + db_config = ConfigDict(title="ComplexModelTitle123", hash_key="account", range_key="sort_date_key", backend=DynamoDbBackend, endpoint="http://localhost:18002", local_indexes={ "by-category": ("account", "category_id"), "by-notification": ("account", "notification_id"), "by-thread": ("account", "thread_id") @@ -82,7 +82,7 @@ class NestedModel(BaseModel): expires: str ticket: Optional[Ticket] other: Union[Ticket, SomethingElse] - model_config = ConfigDict(title="NestedModelTitle123", hash_key="account", range_key="sort_date_key", backend=DynamoDbBackend, endpoint="http://localhost:18002") + db_config = ConfigDict(title="NestedModelTitle123", hash_key="account", range_key="sort_date_key", backend=DynamoDbBackend, endpoint="http://localhost:18002") def alias_model_data_generator(**kwargs): @@ -231,7 +231,7 @@ def complex_query_data(complex_table): yield data finally: for datum in data: - ComplexKeyModel.delete((datum[ComplexKeyModel.model_config.get("hash_key")], datum[ComplexKeyModel.model_config.get("range_key")])) + ComplexKeyModel.delete((datum[ComplexKeyModel.db_config.get("hash_key")], datum[ComplexKeyModel.db_config.get("range_key")])) @pytest.fixture(scope="module") @@ -258,7 +258,7 @@ def nested_query_data(nested_table): yield data finally: for datum in data: - NestedModel.delete((datum[NestedModel.model_config.get("hash_key")], datum[NestedModel.model_config.get("range_key")])) + NestedModel.delete((datum[NestedModel.db_config.get("hash_key")], datum[NestedModel.db_config.get("range_key")])) @pytest.fixture @@ -271,7 +271,7 @@ def nested_query_data_empty_ticket(nested_table): yield data finally: for datum in data: - NestedModel.delete((datum[NestedModel.model_config.get("hash_key")], datum[NestedModel.model_config.get("range_key")])) + NestedModel.delete((datum[NestedModel.db_config.get("hash_key")], datum[NestedModel.db_config.get("range_key")])) def test_save_get_delete_simple(dynamo, simple_table): diff --git a/tests/test_model.py b/tests/test_model.py index 1184ee8..3ebc59e 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -21,7 +21,7 @@ class Model(BaseModel): id: int name: str total: float - model_config = ConfigDict(title="ModelTitle123", backend=FalseBackend) + db_config = ConfigDict(title="ModelTitle123", backend=FalseBackend) def test_model_has_backend_methods(): @@ -55,4 +55,4 @@ def test_model_backend_query(): def test_model_table_name_from_title(): - assert Model.get_table_name() == Model.model_config.get("title").lower() + assert Model.get_table_name() == Model.db_config.get("title").lower() diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index fef8f7f..bb73669 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -19,7 +19,7 @@ class Model(BaseModel): enabled: bool data: Dict[str, str] items: List[int] - model_config = ConfigDict(title="ModelTitle123", hash_key="id", backend=SqliteBackend, database=":memory:") + db_config = ConfigDict(title="ModelTitle123", hash_key="id", backend=SqliteBackend, database=":memory:") @pytest.fixture()