diff --git a/docs/changelog.rst b/docs/changelog.rst index 3f4252a55..e666a8fe8 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,7 @@ Development =========== - (Fill this out as you fix issues and develop your features). - EnumField improvements: now `choices` limits the values of an enum to allow +- renamed Doc._get_changed_fields into Doc._get_updated_fields Changes in 0.23.1 =========== diff --git a/mongoengine/base/__init__.py b/mongoengine/base/__init__.py index dca0c4bb7..c9197d102 100644 --- a/mongoengine/base/__init__.py +++ b/mongoengine/base/__init__.py @@ -13,6 +13,7 @@ __all__ = ( # common "UPDATE_OPERATORS", + "UNSET_SENTINEL", "_document_registry", "get_document", # datastructures diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py index 85897324f..3d9e17071 100644 --- a/mongoengine/base/common.py +++ b/mongoengine/base/common.py @@ -1,6 +1,6 @@ from mongoengine.errors import NotRegistered -__all__ = ("UPDATE_OPERATORS", "get_document", "_document_registry") +__all__ = ("UPDATE_OPERATORS", "get_document", "_document_registry", "UNSET_SENTINEL") UPDATE_OPERATORS = { @@ -21,6 +21,7 @@ "rename", } +UNSET_SENTINEL = object() _document_registry = {} diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index a32f6040a..0acaaaeb3 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -38,21 +38,36 @@ def wrapper(self, key, *args, **kwargs): return wrapper +def mark_key_as_unset_wrapper(parent_method): + """Decorator that ensures _mark_as_unset method gets called with the key argument""" + + def wrapper(self, key, *args, **kwargs): + # Can't use super() in the decorator. + result = parent_method(self, key, *args, **kwargs) + self._mark_as_unset(key) + return result + + return wrapper + + class BaseDict(dict): """A special dict so we can watch any changes.""" _dereferenced = False - _instance = None - _name = None def __init__(self, dict_items, instance, name): BaseDocument = _import_class("BaseDocument") if isinstance(instance, BaseDocument): self._instance = weakref.proxy(instance) + else: + self._instance = instance self._name = name super().__init__(dict_items) + def _get_resolved_key(self, key=None): + return f"{self._name}.{key}" if key else self._name + def get(self, key, default=None): # get does not use __getitem__ by default so we must override it as well try: @@ -63,17 +78,18 @@ def get(self, key, default=None): def __getitem__(self, key): value = super().__getitem__(key) + resolved_key = self._get_resolved_key(key) + EmbeddedDocument = _import_class("EmbeddedDocument") if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance elif isinstance(value, dict) and not isinstance(value, BaseDict): - value = BaseDict(value, None, f"{self._name}.{key}") + value = BaseDict(value, None, resolved_key) super().__setitem__(key, value) value._instance = self._instance elif isinstance(value, list) and not isinstance(value, BaseList): - value = BaseList(value, None, f"{self._name}.{key}") + value = BaseList(value, self._instance, resolved_key) super().__setitem__(key, value) - value._instance = self._instance return value def __getstate__(self): @@ -86,37 +102,47 @@ def __setstate__(self, state): return self __setitem__ = mark_key_as_changed_wrapper(dict.__setitem__) - __delattr__ = mark_key_as_changed_wrapper(dict.__delattr__) - __delitem__ = mark_key_as_changed_wrapper(dict.__delitem__) + __delattr__ = mark_key_as_unset_wrapper(dict.__delattr__) + __delitem__ = mark_key_as_unset_wrapper(dict.__delitem__) pop = mark_as_changed_wrapper(dict.pop) clear = mark_as_changed_wrapper(dict.clear) update = mark_as_changed_wrapper(dict.update) popitem = mark_as_changed_wrapper(dict.popitem) setdefault = mark_as_changed_wrapper(dict.setdefault) + def _mark_as_unset(self, key): + resolved_key = self._get_resolved_key(key) + if hasattr(self._instance, "_mark_as_unset"): + self._instance._mark_as_unset(resolved_key) + def _mark_as_changed(self, key=None): + resolved_key = self._get_resolved_key(key) if hasattr(self._instance, "_mark_as_changed"): - if key: - self._instance._mark_as_changed(f"{self._name}.{key}") - else: - self._instance._mark_as_changed(self._name) + self._instance._mark_as_changed(resolved_key) class BaseList(list): """A special list so we can watch any changes.""" _dereferenced = False - _instance = None - _name = None def __init__(self, list_items, instance, name): BaseDocument = _import_class("BaseDocument") if isinstance(instance, BaseDocument): self._instance = weakref.proxy(instance) + else: + self._instance = instance + self._name = name super().__init__(list_items) + def _get_resolved_key(self, key=None): + if key is not None: + return f"{self._name}.{key % len(self)}" + else: + return self._name + def __getitem__(self, key): # change index to positive value because MongoDB does not support negative one if isinstance(key, int) and key < 0: @@ -128,19 +154,20 @@ def __getitem__(self, key): # to parent's instance. This is buggy for now but would require more work to be handled properly return value + resolved_key = self._get_resolved_key(key) + EmbeddedDocument = _import_class("EmbeddedDocument") if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance elif isinstance(value, dict) and not isinstance(value, BaseDict): # Replace dict by BaseDict - value = BaseDict(value, None, f"{self._name}.{key}") + value = BaseDict(value, None, resolved_key) super().__setitem__(key, value) value._instance = self._instance elif isinstance(value, list) and not isinstance(value, BaseList): # Replace list by BaseList - value = BaseList(value, None, f"{self._name}.{key}") + value = BaseList(value, self._instance, resolved_key) super().__setitem__(key, value) - value._instance = self._instance return value def __iter__(self): @@ -178,11 +205,9 @@ def __setitem__(self, key, value): __imul__ = mark_as_changed_wrapper(list.__imul__) def _mark_as_changed(self, key=None): + resolved_key = self._get_resolved_key(key) if hasattr(self._instance, "_mark_as_changed"): - if key is not None: - self._instance._mark_as_changed(f"{self._name}.{key % len(self)}") - else: - self._instance._mark_as_changed(self._name) + self._instance._mark_as_changed(resolved_key) class EmbeddedDocumentList(BaseList): diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 46935c1b8..fc0d660de 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -1,12 +1,11 @@ import copy -import numbers from functools import partial import pymongo from bson import SON, DBRef, ObjectId, json_util from mongoengine import signals -from mongoengine.base.common import get_document +from mongoengine.base.common import UNSET_SENTINEL, get_document from mongoengine.base.datastructures import ( BaseDict, BaseList, @@ -34,14 +33,16 @@ class BaseDocument: # Currently, handling of `_changed_fields` seems unnecessarily convoluted: # 1. `BaseDocument` defines `_changed_fields` in its `__slots__`, yet it's # not setting it to `[]` (or any other value) in `__init__`. - # 2. `EmbeddedDocument` sets `_changed_fields` to `[]` it its overloaded + # 2. `EmbeddedDocument` sets `_changed_fields` to `[]` in its overloaded # `__init__`. # 3. `Document` does NOT set `_changed_fields` upon initialization. The # field is primarily set via `_from_son` or `_clear_changed_fields`, # though there are also other methods that manipulate it. + # This is done to avoid tracking changes on un-saved Documents # 4. The codebase is littered with `hasattr` calls for `_changed_fields`. __slots__ = ( "_changed_fields", + "_unset_fields", "_initialised", "_created", "_data", @@ -144,14 +145,24 @@ def __delattr__(self, *args, **kwargs): """Handle deletions of fields""" field_name = args[0] if field_name in self._fields: - default = self._fields[field_name].default - if callable(default): - default = default() - setattr(self, field_name, default) + setattr(self, field_name, UNSET_SENTINEL) else: super().__delattr__(*args, **kwargs) def __setattr__(self, name, value): + unset = value is UNSET_SENTINEL + if unset: + if name in self._fields: + default = self._fields[name].default + value = default() if callable(default) else default + else: + # dynamic field + value = None + self._mark_as_unset(name) + else: + # unmark anyway + self._unmark_as_unset(name) + # Handle dynamic data only if an initialised dynamic document if self._dynamic and not self._dynamic_lock: @@ -168,8 +179,7 @@ def __setattr__(self, name, value): # Handle marking data as changed if name in self._dynamic_fields: self._data[name] = value - if hasattr(self, "_changed_fields"): - self._mark_as_changed(name) + self._mark_as_changed(name) try: self__created = self._created except AttributeError: @@ -204,6 +214,7 @@ def __getstate__(self): data = {} for k in ( "_changed_fields", + "_unset_fields", "_initialised", "_created", "_dynamic_fields", @@ -219,6 +230,7 @@ def __setstate__(self, data): data["_data"] = self.__class__._from_son(data["_data"])._data for k in ( "_changed_fields", + "_unset_fields", "_initialised", "_created", "_data", @@ -490,23 +502,55 @@ def __expand_dynamic_values(self, name, value): return value - def _mark_as_changed(self, key): - """Mark a key as explicitly changed by the user.""" - if not key: - return - - if not hasattr(self, "_changed_fields"): - return - + def _resolve_key(self, key): + """Resolve key based on actual db field""" if "." in key: key, rest = key.split(".", 1) key = self._db_field_map.get(key, key) key = f"{key}.{rest}" else: key = self._db_field_map.get(key, key) + return key + + def _unmark_as_unset(self, key): + if not key or not hasattr(self, "_unset_fields"): + return + + key = self._resolve_key(key) + if key in self._unset_fields: + self._unset_fields.remove(key) + + def _mark_as_unset(self, key): + if not key or not hasattr(self, "_unset_fields"): + return + + key = self._resolve_key(key) + + if key not in self._unset_fields: + levels = key.split(".") + idx = 1 + while idx <= len(levels): + if ".".join(levels[:idx]) in self._unset_fields: + break + idx += 1 + else: + self._unset_fields.append(key) + # remove lower level changed fields + level = ".".join(levels[:idx]) + "." + for field in self._unset_fields[:]: + if field.startswith(level): + self._unset_fields.remove(field) + + def _mark_as_changed(self, key): + """Mark a key as explicitly changed by the user.""" + if not key or not hasattr(self, "_changed_fields"): + return + + key = self._resolve_key(key) if key not in self._changed_fields: - levels, idx = key.split("."), 1 + levels = key.split(".") + idx = 1 while idx <= len(levels): if ".".join(levels[:idx]) in self._changed_fields: break @@ -515,20 +559,21 @@ def _mark_as_changed(self, key): self._changed_fields.append(key) # remove lower level changed fields level = ".".join(levels[:idx]) + "." - remove = self._changed_fields.remove for field in self._changed_fields[:]: if field.startswith(level): - remove(field) + self._changed_fields.remove(field) def _clear_changed_fields(self): - """Using _get_changed_fields iterate and remove any fields that + """Using _get_updated_fields iterate and remove any fields that are marked as changed. """ ReferenceField = _import_class("ReferenceField") GenericReferenceField = _import_class("GenericReferenceField") - for changed in self._get_changed_fields(): - parts = changed.split(".") + changed_, unset = self._get_updated_fields() + updated_fields = changed_ + unset + for updated_field in updated_fields: + parts = updated_field.split(".") data = self for part in parts: if isinstance(data, list): @@ -549,6 +594,7 @@ def _clear_changed_fields(self): continue data._changed_fields = [] + data._unset_fields = [] elif isinstance(data, (list, tuple, dict)): if hasattr(data, "field") and isinstance( data.field, (ReferenceField, GenericReferenceField) @@ -557,6 +603,7 @@ def _clear_changed_fields(self): BaseDocument._nestable_types_clear_changed_fields(data) self._changed_fields = [] + self._unset_fields = [] @staticmethod def _nestable_types_clear_changed_fields(data): @@ -574,18 +621,20 @@ def _nestable_types_clear_changed_fields(data): iterator = data.items() for _index_or_key, value in iterator: - if hasattr(value, "_get_changed_fields") and not isinstance( + if hasattr(value, "_get_updated_fields") and not isinstance( value, Document ): # don't follow references value._clear_changed_fields() elif isinstance(value, (list, tuple, dict)): BaseDocument._nestable_types_clear_changed_fields(value) - @staticmethod - def _nestable_types_changed_fields(changed_fields, base_key, data): + def _nestable_types_changed_fields( + self, changed_fields, unset_fields, base_key, data + ): """Inspect nested data for changed fields :param changed_fields: Previously collected changed fields + :param unset_fields: Previously collected unset fields :param base_key: The base key that must be used to prepend changes to this data :param data: data to inspect for changes """ @@ -600,18 +649,19 @@ def _nestable_types_changed_fields(changed_fields, base_key, data): item_key = f"{base_key}{index_or_key}." # don't check anything lower if this key is already marked # as changed. - if item_key[:-1] in changed_fields: + if item_key[:-1] in changed_fields or item_key[:-1] in unset_fields: continue - if hasattr(value, "_get_changed_fields"): - changed = value._get_changed_fields() + if hasattr(value, "_get_updated_fields"): + changed, unset = value._get_updated_fields() changed_fields += [f"{item_key}{k}" for k in changed if k] + unset_fields += [f"{item_key}{k}" for k in unset if k] elif isinstance(value, (list, tuple, dict)): - BaseDocument._nestable_types_changed_fields( - changed_fields, item_key, value + self._nestable_types_changed_fields( + changed_fields, unset_fields, item_key, value ) - def _get_changed_fields(self): + def _get_updated_fields(self): """Return a list of all fields that have explicitly been changed.""" EmbeddedDocument = _import_class("EmbeddedDocument") LazyReferenceField = _import_class("LazyReferenceField") @@ -620,8 +670,10 @@ def _get_changed_fields(self): GenericReferenceField = _import_class("GenericReferenceField") SortedListField = _import_class("SortedListField") - changed_fields = [] - changed_fields += getattr(self, "_changed_fields", []) + unset_fields = list( + getattr(self, "_unset_fields", []) + ) # cast to list to use a copy of the original + changed_fields = list(getattr(self, "_changed_fields", [])) for field_name in self._fields_ordered: db_field_name = self._db_field_map.get(field_name, field_name) @@ -629,17 +681,20 @@ def _get_changed_fields(self): data = self._data.get(field_name, None) field = self._fields.get(field_name) - if db_field_name in changed_fields: + if db_field_name in (changed_fields + unset_fields): # Whole field already marked as changed, no need to go further continue - if isinstance(field, ReferenceField): # Don't follow referenced documents + if isinstance(field, ReferenceField): + # Don't follow referenced documents + # as it is tracked separately continue if isinstance(data, EmbeddedDocument): # Find all embedded fields that have been changed - changed = data._get_changed_fields() + changed, unset = data._get_updated_fields() changed_fields += [f"{key}{k}" for k in changed if k] + unset_fields += [f"{key}{k}" for k in unset if k] elif isinstance(data, (list, tuple, dict)): if hasattr(field, "field") and isinstance( field.field, @@ -650,6 +705,7 @@ def _get_changed_fields(self): GenericReferenceField, ), ): + # Don't follow list(referenced documents) continue elif isinstance(field, SortedListField) and field._ordering: # if ordering is affected whole list is changed @@ -657,8 +713,13 @@ def _get_changed_fields(self): changed_fields.append(db_field_name) continue - self._nestable_types_changed_fields(changed_fields, key, data) - return changed_fields + self._nestable_types_changed_fields( + changed_fields, unset_fields, key, data + ) + + # unset fields are also marked as changed by design + changed_fields = [f for f in changed_fields if f not in unset_fields] + return changed_fields, unset_fields def _delta(self): """Returns the delta (set, unset) of the changes for a document. @@ -667,8 +728,8 @@ def _delta(self): # Handles cases where not loaded from_son but has _id doc = self.to_mongo() - set_fields = self._get_changed_fields() - unset_data = {} + set_fields, unset_fields = self._get_updated_fields() + if hasattr(self, "_changed_fields"): set_data = {} # Fetch each set item from its path @@ -694,53 +755,13 @@ def _delta(self): if "_id" in set_data: del set_data["_id"] - # Determine if any changed items were actually unset. - for path, value in list(set_data.items()): - if value or isinstance( - value, (numbers.Number, bool) - ): # Account for 0 and True that are truthy - continue - - parts = path.split(".") - - if self._dynamic and len(parts) and parts[0] in self._dynamic_fields: - del set_data[path] + unset_data = {} + if hasattr(self, "_unset_fields"): + for path in unset_fields: + if path in set_data: + del set_data[path] + # raise Exception('Should not occur') unset_data[path] = 1 - continue - - # If we've set a value that ain't the default value don't unset it. - default = None - if path in self._fields: - default = self._fields[path].default - else: # Perform a full lookup for lists / embedded lookups - d = self - db_field_name = parts.pop() - for p in parts: - if isinstance(d, list) and p.isdigit(): - d = d[int(p)] - elif hasattr(d, "__getattribute__") and not isinstance(d, dict): - real_path = d._reverse_db_field_map.get(p, p) - d = getattr(d, real_path) - else: - d = d.get(p) - - if hasattr(d, "_fields"): - field_name = d._reverse_db_field_map.get( - db_field_name, db_field_name - ) - if field_name in d._fields: - default = d._fields.get(field_name).default - else: - default = None - - if default is not None: - default = default() if callable(default) else default - - if value != default: - continue - - del set_data[path] - unset_data[path] = 1 return set_data, unset_data @classmethod @@ -810,6 +831,7 @@ def _from_son(cls, son, _auto_dereference=True, created=False): obj = cls(__auto_convert=False, _created=created, **data) obj._changed_fields = [] + obj._unset_fields = [] if not _auto_dereference: obj._fields = fields diff --git a/mongoengine/document.py b/mongoengine/document.py index fa8960fa6..c5493a4a2 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -6,6 +6,7 @@ from mongoengine import signals from mongoengine.base import ( + UNSET_SENTINEL, BaseDict, BaseDocument, BaseList, @@ -75,7 +76,7 @@ class EmbeddedDocument(BaseDocument, metaclass=DocumentMetaclass): :attr:`meta` dictionary. """ - __slots__ = ("_instance",) + __slots__ = ("_instance", "_changed_fields", "_unset_fields") # my_metaclass is defined so that metaclass can be queried in Python 2 & 3 my_metaclass = DocumentMetaclass @@ -90,6 +91,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._instance = None self._changed_fields = [] + self._unset_fields = [] def __eq__(self, other): if isinstance(other, self.__class__): @@ -721,6 +723,7 @@ def reload(self, *fields, **kwargs): :param fields: (optional) args list of fields to reload :param max_depth: (optional) depth of dereferencing to follow """ + # max depth can be provided as first arg OR as a kwarg max_depth = 1 if fields and isinstance(fields[0], int): max_depth = fields[0] @@ -731,6 +734,9 @@ def reload(self, *fields, **kwargs): if self.pk is None: raise self.DoesNotExist("Document does not exist") + # The way reload works is by fetching another instance + # of the original Document and then re-attach the reloaded attribute + # to the original Document instance obj = ( self._qs.read_preference(ReadPreference.PRIMARY) .filter(**self._object_key) @@ -763,12 +769,18 @@ def reload(self, *fields, **kwargs): if fields else obj._changed_fields ) + self._unset_fields = ( + list(set(self._unset_fields) - set(fields)) if fields else obj._unset_fields + ) self._created = False return self def _reload(self, key, value): """Used by :meth:`~mongoengine.Document.reload` to ensure the correct instance is linked to self. + + In fact .reload() is fetching another instance of the source document + and then re-attach special fields to the source document """ if isinstance(value, BaseDict): value = [(k, self._reload(k, v)) for k, v in value.items()] @@ -781,7 +793,6 @@ def _reload(self, key, value): value = BaseList(value, self, key) elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)): value._instance = None - value._changed_fields = [] return value def to_dbref(self): @@ -1035,7 +1046,7 @@ def __delattr__(self, *args, **kwargs): """ field_name = args[0] if field_name in self._dynamic_fields: - setattr(self, field_name, None) + setattr(self, field_name, UNSET_SENTINEL) self._dynamic_fields[field_name].null = False else: super().__delattr__(*args, **kwargs) @@ -1058,10 +1069,7 @@ def __delattr__(self, *args, **kwargs): """ field_name = args[0] if field_name in self._fields: - default = self._fields[field_name].default - if callable(default): - default = default() - setattr(self, field_name, default) + super().__delattr__(*args, **kwargs) else: setattr(self, field_name, None) diff --git a/tests/document/test_delta.py b/tests/document/test_delta.py index 68c698b64..a7f226b70 100644 --- a/tests/document/test_delta.py +++ b/tests/document/test_delta.py @@ -25,8 +25,10 @@ def tearDown(self): for collection in list_collection_names(self.db): self.db.drop_collection(collection) - def test_delta(self): + def test_delta_on_document_class(self): self.delta(Document) + + def test_delta_on_dynamic_document_class(self): self.delta(DynamicDocument) @staticmethod @@ -42,45 +44,52 @@ class Doc(DocClass): doc.save() doc = Doc.objects.first() - assert doc._get_changed_fields() == [] + assert doc._get_updated_fields() == ([], []) assert doc._delta() == ({}, {}) - doc.string_field = "hello" - assert doc._get_changed_fields() == ["string_field"] - assert doc._delta() == ({"string_field": "hello"}, {}) - + for attr_name, new_value in [ + ("string_field", "hello"), + ("int_field", 1), + ("dict_field", {"hello": "world", "ping": "pong"}), + ("list_field", ["1", 2, {"hello": "world"}]), + ]: + doc._changed_fields = [] + setattr(doc, attr_name, new_value) + assert doc._get_updated_fields() == ([attr_name], []) + assert doc._delta() == ({attr_name: new_value}, {}) + + # Test emptying dict/list fields gets marked as changed (and not unset) doc._changed_fields = [] - doc.int_field = 1 - assert doc._get_changed_fields() == ["int_field"] - assert doc._delta() == ({"int_field": 1}, {}) - - doc._changed_fields = [] - dict_value = {"hello": "world", "ping": "pong"} - doc.dict_field = dict_value - assert doc._get_changed_fields() == ["dict_field"] - assert doc._delta() == ({"dict_field": dict_value}, {}) + doc.dict_field = {} + assert doc._get_updated_fields() == (["dict_field"], []) + assert doc._delta() == ({"dict_field": {}}, {}) doc._changed_fields = [] - list_value = ["1", 2, {"hello": "world"}] - doc.list_field = list_value - assert doc._get_changed_fields() == ["list_field"] - assert doc._delta() == ({"list_field": list_value}, {}) + doc.list_field = [] + assert doc._get_updated_fields() == (["list_field"], []) + assert doc._delta() == ({"list_field": []}, {}) # Test unsetting - doc._changed_fields = [] - doc.dict_field = {} - assert doc._get_changed_fields() == ["dict_field"] - assert doc._delta() == ({}, {"dict_field": 1}) + for attr_name in ("int_field", "string_field", "list_field", "dict_field"): + doc._changed_fields = [] + doc._unset_fields = [] - doc._changed_fields = [] - doc.list_field = [] - assert doc._get_changed_fields() == ["list_field"] - assert doc._delta() == ({}, {"list_field": 1}) + assert doc._get_updated_fields() == ([], []) + delattr(doc, attr_name) + # del doc.int_field + assert doc._get_updated_fields() == ([], [attr_name]) + assert doc._delta() == ({}, {attr_name: 1}) - def test_delta_recursive(self): + def test_delta_recursive_document_embeddeddoc(self): self.delta_recursive(Document, EmbeddedDocument) + + def test_delta_recursive_dynamicdocument_embeddeddoc(self): self.delta_recursive(DynamicDocument, EmbeddedDocument) + + def test_delta_recursive_dynamicdocument_dynamicembeddeddoc(self): self.delta_recursive(Document, DynamicEmbeddedDocument) + + def test_delta_recursive_document_dynamicembeddeddoc(self): self.delta_recursive(DynamicDocument, DynamicEmbeddedDocument) def delta_recursive(self, DocClass, EmbeddedClass): @@ -103,7 +112,7 @@ class Doc(DocClass): doc.save() doc = Doc.objects.first() - assert doc._get_changed_fields() == [] + assert doc._get_updated_fields() == ([], []) assert doc._delta() == ({}, {}) embedded_1 = Embedded() @@ -114,7 +123,7 @@ class Doc(DocClass): embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 - assert doc._get_changed_fields() == ["embedded_field"] + assert doc._get_updated_fields() == (["embedded_field"], []) embedded_delta = { "id": "010101", @@ -130,17 +139,18 @@ class Doc(DocClass): doc = doc.reload(10) doc.embedded_field.dict_field = {} - assert doc._get_changed_fields() == ["embedded_field.dict_field"] - assert doc.embedded_field._delta() == ({}, {"dict_field": 1}) - assert doc._delta() == ({}, {"embedded_field.dict_field": 1}) + assert doc._get_updated_fields() == (["embedded_field.dict_field"], []) + + assert doc.embedded_field._delta() == ({"dict_field": {}}, {}) + assert doc._delta() == ({"embedded_field.dict_field": {}}, {}) doc.save() doc = doc.reload(10) assert doc.embedded_field.dict_field == {} doc.embedded_field.list_field = [] - assert doc._get_changed_fields() == ["embedded_field.list_field"] - assert doc.embedded_field._delta() == ({}, {"list_field": 1}) - assert doc._delta() == ({}, {"embedded_field.list_field": 1}) + assert doc._get_updated_fields() == (["embedded_field.list_field"], []) + assert doc.embedded_field._delta() == ({"list_field": []}, {}) + assert doc._delta() == ({"embedded_field.list_field": []}, {}) doc.save() doc = doc.reload(10) assert doc.embedded_field.list_field == [] @@ -152,7 +162,7 @@ class Doc(DocClass): embedded_2.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field.list_field = ["1", 2, embedded_2] - assert doc._get_changed_fields() == ["embedded_field.list_field"] + assert doc._get_updated_fields() == (["embedded_field.list_field"], []) assert doc.embedded_field._delta() == ( { @@ -196,7 +206,10 @@ class Doc(DocClass): assert doc.embedded_field.list_field[2][k] == embedded_2[k] doc.embedded_field.list_field[2].string_field = "world" - assert doc._get_changed_fields() == ["embedded_field.list_field.2.string_field"] + assert doc._get_updated_fields() == ( + ["embedded_field.list_field.2.string_field"], + [], + ) assert doc.embedded_field._delta() == ( {"list_field.2.string_field": "world"}, {}, @@ -212,7 +225,7 @@ class Doc(DocClass): # Test multiple assignments doc.embedded_field.list_field[2].string_field = "hello world" doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - assert doc._get_changed_fields() == ["embedded_field.list_field.2"] + assert doc._get_updated_fields() == (["embedded_field.list_field.2"], []) assert doc.embedded_field._delta() == ( { "list_field.2": { @@ -264,7 +277,8 @@ class Doc(DocClass): doc = doc.reload(10) assert doc.embedded_field.list_field[2].list_field == [1, 2, {"hello": "world"}] - del doc.embedded_field.list_field[2].list_field[2]["hello"] + key = doc.embedded_field.list_field[2].list_field[2] + del key["hello"] assert doc._delta() == ( {}, {"embedded_field.list_field.2.list_field.2.hello": 1}, @@ -283,7 +297,7 @@ class Doc(DocClass): doc = doc.reload(10) doc.dict_field["Embedded"].string_field = "Hello World" - assert doc._get_changed_fields() == ["dict_field.Embedded.string_field"] + assert doc._get_updated_fields() == (["dict_field.Embedded.string_field"], []) assert doc._delta() == ({"dict_field.Embedded.string_field": "Hello World"}, {}) def test_circular_reference_deltas(self): @@ -362,11 +376,14 @@ class Organization(DocClass2): return person, organization, employee - def test_delta_db_field(self): + def test_delta_db_field_document(self): self.delta_db_field(Document) + + def test_delta_db_field_dynamic_document(self): self.delta_db_field(DynamicDocument) - def delta_db_field(self, DocClass): + @staticmethod + def delta_db_field(DocClass): class Doc(DocClass): string_field = StringField(db_field="db_string_field") int_field = IntField(db_field="db_int_field") @@ -378,40 +395,50 @@ class Doc(DocClass): doc.save() doc = Doc.objects.first() - assert doc._get_changed_fields() == [] + assert doc._get_updated_fields() == ([], []) assert doc._delta() == ({}, {}) doc.string_field = "hello" - assert doc._get_changed_fields() == ["db_string_field"] + assert doc._get_updated_fields() == (["db_string_field"], []) assert doc._delta() == ({"db_string_field": "hello"}, {}) doc._changed_fields = [] doc.int_field = 1 - assert doc._get_changed_fields() == ["db_int_field"] + assert doc._get_updated_fields() == (["db_int_field"], []) assert doc._delta() == ({"db_int_field": 1}, {}) doc._changed_fields = [] dict_value = {"hello": "world", "ping": "pong"} doc.dict_field = dict_value - assert doc._get_changed_fields() == ["db_dict_field"] + assert doc._get_updated_fields() == (["db_dict_field"], []) assert doc._delta() == ({"db_dict_field": dict_value}, {}) doc._changed_fields = [] list_value = ["1", 2, {"hello": "world"}] doc.list_field = list_value - assert doc._get_changed_fields() == ["db_list_field"] + assert doc._get_updated_fields() == (["db_list_field"], []) assert doc._delta() == ({"db_list_field": list_value}, {}) - # Test unsetting doc._changed_fields = [] doc.dict_field = {} - assert doc._get_changed_fields() == ["db_dict_field"] - assert doc._delta() == ({}, {"db_dict_field": 1}) + assert doc._get_updated_fields() == (["db_dict_field"], []) + assert doc._delta() == ({"db_dict_field": {}}, {}) doc._changed_fields = [] doc.list_field = [] - assert doc._get_changed_fields() == ["db_list_field"] - assert doc._delta() == ({}, {"db_list_field": 1}) + assert doc._get_updated_fields() == (["db_list_field"], []) + assert doc._delta() == ({"db_list_field": []}, {}) + + # Test unsetting + for attr_name in ("int_field", "string_field", "list_field", "dict_field"): + db_attr_name = "db_" + attr_name + doc._changed_fields = [] + doc._unset_fields = [] + + assert doc._get_updated_fields() == ([], []) + delattr(doc, attr_name) + assert doc._get_updated_fields() == ([], [db_attr_name]) + assert doc._delta() == ({}, {db_attr_name: 1}) # Test it saves that data doc = Doc() @@ -441,8 +468,7 @@ def test_delta_recursive_db_field_on_dynamicdoc_and_embeddeddoc(self): def test_delta_recursive_db_field_on_dynamicdoc_and_dynamicembeddeddoc(self): self.delta_recursive_db_field(DynamicDocument, DynamicEmbeddedDocument) - @staticmethod - def delta_recursive_db_field(DocClass, EmbeddedClass): + def delta_recursive_db_field(self, DocClass, EmbeddedClass): class Embedded(EmbeddedClass): string_field = StringField(db_field="db_string_field") int_field = IntField(db_field="db_int_field") @@ -463,7 +489,7 @@ class Doc(DocClass): doc.save() doc = Doc.objects.first() - assert doc._get_changed_fields() == [] + assert doc._get_updated_fields() == ([], []) assert doc._delta() == ({}, {}) embedded_1 = Embedded() @@ -473,7 +499,7 @@ class Doc(DocClass): embedded_1.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field = embedded_1 - assert doc._get_changed_fields() == ["db_embedded_field"] + assert doc._get_updated_fields() == (["db_embedded_field"], []) embedded_delta = { "db_string_field": "hello", @@ -488,18 +514,18 @@ class Doc(DocClass): doc = doc.reload(10) doc.embedded_field.dict_field = {} - assert doc._get_changed_fields() == ["db_embedded_field.db_dict_field"] - assert doc.embedded_field._delta() == ({}, {"db_dict_field": 1}) - assert doc._delta() == ({}, {"db_embedded_field.db_dict_field": 1}) + assert doc._get_updated_fields() == (["db_embedded_field.db_dict_field"], []) + assert doc.embedded_field._delta() == ({"db_dict_field": {}}, {}) + assert doc._delta() == ({"db_embedded_field.db_dict_field": {}}, {}) doc.save() doc = doc.reload(10) assert doc.embedded_field.dict_field == {} - assert doc._get_changed_fields() == [] + assert doc._get_updated_fields() == ([], []) doc.embedded_field.list_field = [] - assert doc._get_changed_fields() == ["db_embedded_field.db_list_field"] - assert doc.embedded_field._delta() == ({}, {"db_list_field": 1}) - assert doc._delta() == ({}, {"db_embedded_field.db_list_field": 1}) + assert doc._get_updated_fields() == (["db_embedded_field.db_list_field"], []) + assert doc.embedded_field._delta() == ({"db_list_field": []}, {}) + assert doc._delta() == ({"db_embedded_field.db_list_field": []}, {}) doc.save() doc = doc.reload(10) assert doc.embedded_field.list_field == [] @@ -511,7 +537,7 @@ class Doc(DocClass): embedded_2.list_field = ["1", 2, {"hello": "world"}] doc.embedded_field.list_field = ["1", 2, embedded_2] - assert doc._get_changed_fields() == ["db_embedded_field.db_list_field"] + assert doc._get_updated_fields() == (["db_embedded_field.db_list_field"], []) assert doc.embedded_field._delta() == ( { "db_list_field": [ @@ -546,7 +572,7 @@ class Doc(DocClass): {}, ) doc.save() - assert doc._get_changed_fields() == [] + assert doc._get_updated_fields() == ([], []) doc = doc.reload(10) assert doc.embedded_field.list_field[0] == "1" @@ -555,9 +581,10 @@ class Doc(DocClass): assert doc.embedded_field.list_field[2][k] == embedded_2[k] doc.embedded_field.list_field[2].string_field = "world" - assert doc._get_changed_fields() == [ - "db_embedded_field.db_list_field.2.db_string_field" - ] + assert doc._get_updated_fields() == ( + ["db_embedded_field.db_list_field.2.db_string_field"], + [], + ) assert doc.embedded_field._delta() == ( {"db_list_field.2.db_string_field": "world"}, {}, @@ -573,7 +600,7 @@ class Doc(DocClass): # Test multiple assignments doc.embedded_field.list_field[2].string_field = "hello world" doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] - assert doc._get_changed_fields() == ["db_embedded_field.db_list_field.2"] + assert doc._get_updated_fields() == (["db_embedded_field.db_list_field.2"], []) assert doc.embedded_field._delta() == ( { "db_list_field.2": { @@ -681,13 +708,13 @@ class Person(DynamicDocument): p.age = 24 assert p.age == 24 - assert p._get_changed_fields() == ["age"] + assert p._get_updated_fields() == (["age"], []) assert p._delta() == ({"age": 24}, {}) p = Person.objects(age=22).get() p.age = 24 assert p.age == 24 - assert p._get_changed_fields() == ["age"] + assert p._get_updated_fields() == (["age"], []) assert p._delta() == ({"age": 24}, {}) p.save() @@ -702,40 +729,39 @@ class Doc(DynamicDocument): doc.save() doc = Doc.objects.first() - assert doc._get_changed_fields() == [] + assert doc._get_updated_fields() == ([], []) assert doc._delta() == ({}, {}) doc.string_field = "hello" - assert doc._get_changed_fields() == ["string_field"] + assert doc._get_updated_fields() == (["string_field"], []) assert doc._delta() == ({"string_field": "hello"}, {}) doc._changed_fields = [] doc.int_field = 1 - assert doc._get_changed_fields() == ["int_field"] + assert doc._get_updated_fields() == (["int_field"], []) assert doc._delta() == ({"int_field": 1}, {}) doc._changed_fields = [] dict_value = {"hello": "world", "ping": "pong"} doc.dict_field = dict_value - assert doc._get_changed_fields() == ["dict_field"] + assert doc._get_updated_fields() == (["dict_field"], []) assert doc._delta() == ({"dict_field": dict_value}, {}) doc._changed_fields = [] list_value = ["1", 2, {"hello": "world"}] doc.list_field = list_value - assert doc._get_changed_fields() == ["list_field"] + assert doc._get_updated_fields() == (["list_field"], []) assert doc._delta() == ({"list_field": list_value}, {}) - # Test unsetting doc._changed_fields = [] doc.dict_field = {} - assert doc._get_changed_fields() == ["dict_field"] - assert doc._delta() == ({}, {"dict_field": 1}) + assert doc._get_updated_fields() == (["dict_field"], []) + assert doc._delta() == ({"dict_field": {}}, {}) doc._changed_fields = [] doc.list_field = [] - assert doc._get_changed_fields() == ["list_field"] - assert doc._delta() == ({}, {"list_field": 1}) + assert doc._get_updated_fields() == (["list_field"], []) + assert doc._delta() == ({"list_field": []}, {}) def test_delta_with_dbref_true(self): person, organization, employee = self.circular_reference_deltas_2( @@ -743,7 +769,7 @@ def test_delta_with_dbref_true(self): ) employee.name = "test" - assert organization._get_changed_fields() == [] + assert organization._get_updated_fields() == ([], []) updates, removals = organization._delta() assert removals == {} @@ -760,7 +786,7 @@ def test_delta_with_dbref_false(self): ) employee.name = "test" - assert organization._get_changed_fields() == [] + assert organization._get_updated_fields() == ([], []) updates, removals = organization._delta() assert removals == {} @@ -787,11 +813,11 @@ class MyDoc(Document): subdoc = mydoc.subs["a"]["b"] subdoc.name = "bar" - assert subdoc._get_changed_fields() == ["name"] - assert mydoc._get_changed_fields() == ["subs.a.b.name"] + assert subdoc._get_updated_fields() == (["name"], []) + assert mydoc._get_updated_fields() == (["subs.a.b.name"], []) mydoc._clear_changed_fields() - assert mydoc._get_changed_fields() == [] + assert mydoc._get_updated_fields() == ([], []) def test_nested_nested_fields_db_field_set__gets_mark_as_changed_and_cleaned(self): class EmbeddedDoc(EmbeddedDocument): @@ -808,19 +834,19 @@ class MyDoc(Document): mydoc = MyDoc.objects.first() mydoc.embed.name = "foo1" - assert mydoc.embed._get_changed_fields() == ["db_name"] - assert mydoc._get_changed_fields() == ["db_embed.db_name"] + assert mydoc.embed._get_updated_fields() == (["db_name"], []) + assert mydoc._get_updated_fields() == (["db_embed.db_name"], []) mydoc = MyDoc.objects.first() embed = EmbeddedDoc(name="foo2") embed.name = "bar" mydoc.embed = embed - assert embed._get_changed_fields() == ["db_name"] - assert mydoc._get_changed_fields() == ["db_embed"] + assert embed._get_updated_fields() == (["db_name"], []) + assert mydoc._get_updated_fields() == (["db_embed"], []) mydoc._clear_changed_fields() - assert mydoc._get_changed_fields() == [] + assert mydoc._get_updated_fields() == ([], []) def test_lower_level_mark_as_changed(self): class EmbeddedDoc(EmbeddedDocument): @@ -835,17 +861,17 @@ class MyDoc(Document): mydoc = MyDoc.objects.first() mydoc.subs["a"] = EmbeddedDoc() - assert mydoc._get_changed_fields() == ["subs.a"] + assert mydoc._get_updated_fields() == (["subs.a"], []) subdoc = mydoc.subs["a"] subdoc.name = "bar" - assert subdoc._get_changed_fields() == ["name"] - assert mydoc._get_changed_fields() == ["subs.a"] + assert subdoc._get_updated_fields() == (["name"], []) + assert mydoc._get_updated_fields() == (["subs.a"], []) mydoc.save() mydoc._clear_changed_fields() - assert mydoc._get_changed_fields() == [] + assert mydoc._get_updated_fields() == ([], []) def test_upper_level_mark_as_changed(self): class EmbeddedDoc(EmbeddedDocument): @@ -862,15 +888,15 @@ class MyDoc(Document): subdoc = mydoc.subs["a"] subdoc.name = "bar" - assert subdoc._get_changed_fields() == ["name"] - assert mydoc._get_changed_fields() == ["subs.a.name"] + assert subdoc._get_updated_fields() == (["name"], []) + assert mydoc._get_updated_fields() == (["subs.a.name"], []) mydoc.subs["a"] = EmbeddedDoc() - assert mydoc._get_changed_fields() == ["subs.a"] + assert mydoc._get_updated_fields() == (["subs.a"], []) mydoc.save() mydoc._clear_changed_fields() - assert mydoc._get_changed_fields() == [] + assert mydoc._get_updated_fields() == ([], []) def test_referenced_object_changed_attributes(self): """Ensures that when you save a new reference to a field, the referenced object isn't altered""" diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 21f2c82ef..a4bf20f66 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -554,25 +554,119 @@ class Animal(Document): Animal.drop_collection() - def test_reload_with_changed_fields(self): + def test__get_updated_fields_various_case(self): """Ensures reloading will not affect changed fields""" + class VitalSigns(EmbeddedDocument): + blood_pressure = FloatField() + + class Car(Document): + brand = StringField() + + class User(Document): + name = StringField() + car = ReferenceField(Car) + vital_signs = EmbeddedDocumentField(VitalSigns) + + User.drop_collection() + Car.drop_collection() + + car = Car(brand="Lamborghini").save() + user = User( + name="Bob", vital_signs=VitalSigns(blood_pressure=0.99), car=car + ).save() + + def reload(user): + user.reload() + assert user._get_updated_fields() == ([], []) + + # Alter top level attr + user.name = "John" + user.car = Car(brand="VW").save() + user.vital_signs = VitalSigns(blood_pressure=0.1) + assert user._get_updated_fields() == (["name", "car", "vital_signs"], []) + reload(user) + + # unset top level attr + del user.name + del user.car + del user.vital_signs + assert user._get_updated_fields() == ([], ["name", "car", "vital_signs"]) + reload(user) + + # alter reference field attribute + # must not get tracked in main doc + user.car.brand = "garbage" + assert user._get_updated_fields() == ([], []) + reload(user) + + # alter embedded doc attribute + user.vital_signs.blood_pressure = 0.123 + assert user._get_updated_fields() == (["vital_signs.blood_pressure"], []) + reload(user) + + # unset embedded doc attribute + user = User( + name="Foo", vital_signs=VitalSigns(blood_pressure=0.95), car=car + ).save() + del user.vital_signs.blood_pressure + assert user._get_updated_fields() == ([], ["vital_signs.blood_pressure"]) + + def test_reload_with_changed_fields_document(self): + """Ensures reloading will not affect changed fields""" + + class VitalSigns(EmbeddedDocument): + blood_pressure = FloatField() + class User(Document): name = StringField() number = IntField() + phone = StringField() + vital_signs = EmbeddedDocumentField(VitalSigns) User.drop_collection() - user = User(name="Bob", number=1).save() + user = User( + name="Bob", + number=1, + phone="01234", + vital_signs=VitalSigns(blood_pressure=0.99), + ).save() + user.name = "John" user.number = 2 + user.vital_signs.blood_pressure = 0.11 + del user.phone - assert user._get_changed_fields() == ["name", "number"] + assert user._get_updated_fields() == ( + ["name", "number", "vital_signs.blood_pressure"], + ["phone"], + ) user.reload("number") - assert user._get_changed_fields() == ["name"] + assert user._get_updated_fields() == ( + ["name", "vital_signs.blood_pressure"], + ["phone"], + ) + user.reload("vital_signs") + assert user._get_updated_fields() == (["name"], ["phone"]) + user.reload("phone") + assert user._get_updated_fields() == (["name"], []) user.save() + assert user._get_updated_fields() == ([], []) + + raw_doc = get_as_pymongo(user) + assert raw_doc == { + "name": "John", + "_id": user.id, + "number": 1, + "phone": "01234", + "vital_signs": {"blood_pressure": 0.99}, + } + user.reload() assert user.name == "John" + assert user.number == 1 + assert user.phone == "01234" def test_reload_referencing(self): """Ensures reloading updates weakrefs correctly.""" @@ -603,18 +697,19 @@ class Doc(Document): doc.embedded_field.list_field.append(1) doc.embedded_field.dict_field["woot"] = "woot" - changed = doc._get_changed_fields() + changed, unset = doc._get_updated_fields() assert changed == [ "list_field", "dict_field.woot", "embedded_field.list_field", "embedded_field.dict_field.woot", ] + assert unset == [] doc.save() assert len(doc.list_field) == 4 doc = doc.reload(10) - assert doc._get_changed_fields() == [] + assert doc._get_updated_fields() == ([], []) assert len(doc.list_field) == 4 assert len(doc.dict_field) == 2 assert len(doc.embedded_field.list_field) == 4 @@ -624,7 +719,7 @@ class Doc(Document): doc.save() doc.dict_field["extra"] = 1 doc = doc.reload(10, "list_field") - assert doc._get_changed_fields() == ["dict_field.extra"] + assert doc._get_updated_fields() == (["dict_field.extra"], []) assert len(doc.list_field) == 5 assert len(doc.dict_field) == 3 assert len(doc.embedded_field.list_field) == 4 @@ -965,7 +1060,7 @@ def test_modify_update(self): del doc_copy.job.years assert doc.to_json() == doc_copy.to_json() - assert doc._get_changed_fields() == [] + assert doc._get_updated_fields() == ([], []) self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())]) @@ -1626,7 +1721,7 @@ class User(self.Person): assert person.age == 21 assert person.active is False - def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_embedded_doc( + def test__get_updated_fields_same_ids_reference_field_does_not_enters_infinite_loop_embedded_doc( self, ): # Refers to Issue #1685 @@ -1637,10 +1732,11 @@ class ParentModel(Document): child = EmbeddedDocumentField(EmbeddedChildModel) emb = EmbeddedChildModel(id={"1": [1]}) - changed_fields = ParentModel(child=emb)._get_changed_fields() + changed_fields, unset_fields = ParentModel(child=emb)._get_updated_fields() assert changed_fields == [] + assert unset_fields == [] - def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop_different_doc( + def test__get_updated_fields_same_ids_reference_field_does_not_enters_infinite_loop_different_doc( self, ): # Refers to Issue #1685 @@ -1659,10 +1755,10 @@ class Message(Document): message = Message(id=1, author=user).save() message.author.name = "tutu" - assert message._get_changed_fields() == [] - assert user._get_changed_fields() == ["name"] + assert message._get_updated_fields() == ([], []) + assert user._get_updated_fields() == (["name"], []) - def test__get_changed_fields_same_ids_embedded(self): + def test__get_updated_fields_same_ids_embedded(self): # Refers to Issue #1768 class User(EmbeddedDocument): id = IntField() @@ -1679,7 +1775,7 @@ class Message(Document): message = Message(id=1, author=user).save() message.author.name = "tutu" - assert message._get_changed_fields() == ["author.name"] + assert message._get_updated_fields() == (["author.name"], []) message.save() message_fetched = Message.objects.with_id(message.id) diff --git a/tests/fields/test_fields.py b/tests/fields/test_fields.py index 69fe14712..0bcadb43c 100644 --- a/tests/fields/test_fields.py +++ b/tests/fields/test_fields.py @@ -1840,10 +1840,10 @@ class Doc2(Document): doc2 = Doc2(ref=doc1, refs=[doc11]).save() doc2.ref.name = "garbage2" - assert doc2._get_changed_fields() == [] + assert doc2._get_updated_fields() == ([], []) doc2.refs[0].name = "garbage3" - assert doc2._get_changed_fields() == [] + assert doc2._get_updated_fields() == ([], []) assert doc2._delta() == ({}, {}) def test_generic_reference_field(self): diff --git a/tests/queryset/test_field_list.py b/tests/queryset/test_field_list.py index 96bd804d9..b2cb9b1ce 100644 --- a/tests/queryset/test_field_list.py +++ b/tests/queryset/test_field_list.py @@ -4,6 +4,7 @@ from mongoengine import * from mongoengine.queryset import QueryFieldList +from tests.utils import MongoDBTestCase, get_as_pymongo class TestQueryFieldList: @@ -66,6 +67,103 @@ def test_using_a_slice(self): assert q.as_dict() == {"a": {"$slice": 5}} +class TestListField(MongoDBTestCase): + def test_list_field_empty(self): + class BlogPost(Document): + authors = ListField(default=[]) + + BlogPost.drop_collection() + + blog = BlogPost().save() + + assert get_as_pymongo(blog) == { + "_id": blog.id, + "authors": [], + } + + blog.authors = [] + blog.save() + assert get_as_pymongo(blog) == { + "_id": blog.id, + "authors": [], + } + + blog.authors = [1] + blog.save() + assert get_as_pymongo(blog) == {"_id": blog.id, "authors": [1]} + + del blog.authors + blog.save() + assert get_as_pymongo(blog) == { + "_id": blog.id, + } + + # set empty list in constructor + blog2 = BlogPost(authors=[]).save() + assert get_as_pymongo(blog2) == { + "_id": blog2.id, + "authors": [], + } + + # set None in constructor + blog3 = BlogPost(authors=None).save() + assert get_as_pymongo(blog3) == { + "_id": blog3.id, + "authors": [], + } + + def test_only_on_list_field_without_key_return_default(self): + # Ensure no regression of #938 + class A(Document): + my_list = ListField(IntField()) + + A.drop_collection() + + app = A(my_list=[]).save() + app.save() + + del app.my_list + app.save() + + assert get_as_pymongo(app) == { + "_id": app.id, + } + a = A.objects(id=app.id).only("my_list").get() + assert a.my_list == [] + + def test_item_frequencies_with_empty_list_edge_cases(self): + class TestDocument(Document): + fruit = ListField(StringField()) + + TestDocument.drop_collection() + + doc1 = TestDocument(fruit=["a", "a", "b"]).save() + doc2 = TestDocument(fruit=["b", "c"]).save() + + assert TestDocument.objects.item_frequencies("fruit") == { + "a": 2, + "b": 2, + "c": 1, + } + + doc2.delete() + assert TestDocument.objects.item_frequencies("fruit") == {"a": 2, "b": 1} + + doc1.fruit = [] + doc1.save() + assert TestDocument.objects.item_frequencies("fruit") == {} + + # delete the fruit field from db + # this creates weird item_frequencies result + # but somehow it is consistent + del doc1.fruit + doc1.save() + assert get_as_pymongo(doc1) == { + "_id": doc1.id, + } + assert TestDocument.objects.item_frequencies("fruit") == {None: 1} + + class TestOnlyExcludeAll(unittest.TestCase): def setUp(self): connect(db="mongoenginetest") diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 1aa4f32a3..20234ee03 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -815,20 +815,18 @@ class TestOrganization(Document): o.owner = p p.name = "p2" - assert o._get_changed_fields() == ["owner"] - assert p._get_changed_fields() == ["name"] + assert o._get_updated_fields() == (["owner"], []) + assert p._get_updated_fields() == (["name"], []) o.save() - assert o._get_changed_fields() == [] - assert p._get_changed_fields() == ["name"] # Fails; it's empty + assert o._get_updated_fields() == ([], []) + assert p._get_updated_fields() == (["name"], []) - # This will do NOTHING at all, even though we changed the name p.save() - p.reload() - assert p.name == "p2" # Fails; it's still `p1` + assert p.name == "p2" def test_upsert(self): self.Person.drop_collection() @@ -1080,7 +1078,7 @@ class Comment(Document): with pytest.raises(NotUniqueError): Comment.objects.insert(com1) - def test_get_changed_fields_query_count(self): + def test_get_updated_fields_query_count(self): """Make sure we don't perform unnecessary db operations when none of document's fields were updated. """ @@ -1118,7 +1116,7 @@ class Project(Document): # Checking changed fields of a newly fetched document should not # result in a query. - org._get_changed_fields() + org._get_updated_fields() assert q == 1 # Saving a doc without changing any of its fields should not result diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 42ce42a1d..3572aed9b 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -76,7 +76,8 @@ def test_clear_calls_mark_as_changed(self): def test___delitem___calls_mark_as_changed(self): base_dict = self._get_basedict({"k": "v"}) del base_dict["k"] - assert base_dict._instance._changed_fields == ["my_name.k"] + assert base_dict._instance._changed_fields == [] + assert base_dict._instance._unset_fields == ["my_name.k"] assert base_dict == {} def test___getitem____KeyError(self): @@ -155,7 +156,8 @@ def test___delattr____tracked_by_changes(self): base_dict = self._get_basedict({}) base_dict.a_new_attr = "test" del base_dict.a_new_attr - assert base_dict._instance._changed_fields == ["my_name.a_new_attr"] + assert base_dict._instance._changed_fields == [] + assert base_dict._instance._unset_fields == ["my_name.a_new_attr"] class TestBaseList: @@ -163,10 +165,7 @@ class TestBaseList: def _get_baselist(list_items): """Get a BaseList bound to a fake document instance""" fake_doc = DocumentStub() - base_list = BaseList(list_items, instance=None, name="my_name") - base_list._instance = ( - fake_doc # hack to inject the mock, it does not work in the constructor - ) + base_list = BaseList(list_items, instance=fake_doc, name="my_name") return base_list def test___init___(self): @@ -185,7 +184,7 @@ def test___iter__(self): base_list = BaseList(values, instance=None, name="my_name") assert values == list(base_list) - def test___iter___allow_modification_while_iterating_withou_error(self): + def test___iter___allow_modification_while_iterating_without_error(self): # regular list allows for this, thus this subclass must comply to that base_list = BaseList([True, False, True, False], instance=None, name="my_name") for idx, val in enumerate(base_list): diff --git a/tests/test_dereference.py b/tests/test_dereference.py index bddcc5432..0ea702426 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -4,17 +4,10 @@ from mongoengine import * from mongoengine.context_managers import query_counter +from tests.utils import MongoDBTestCase -class FieldTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.db = connect(db="mongoenginetest") - - @classmethod - def tearDownClass(cls): - cls.db.drop_database("mongoenginetest") - +class FieldTest(MongoDBTestCase): def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced.""" @@ -416,7 +409,7 @@ def __repr__(self): daughter.relations.append(mother) daughter.relations.append(daughter) - assert daughter._get_changed_fields() == ["relations"] + assert daughter._get_updated_fields() == (["relations"], []) daughter.save() assert "[, ]" == "%s" % Person.objects()