diff --git a/mongoengine/document.py b/mongoengine/document.py index 0ba5db126..184017922 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -332,6 +332,7 @@ def save( _refs=None, save_condition=None, signal_kwargs=None, + upsert=None, **kwargs, ): """Save the :class:`~mongoengine.Document` to the database. If the @@ -361,6 +362,8 @@ def save( Raises :class:`OperationError` if the conditions are not satisfied :param signal_kwargs: (optional) kwargs dictionary to be passed to the signal calls. + :param upsert: (optional) explicitly forces upsert to value if it is an + update .. versionchanged:: 0.5 In existing documents it only saves changed fields using @@ -407,7 +410,7 @@ def save( object_id = self._save_create(doc, force_insert, write_concern) else: object_id, created = self._save_update( - doc, save_condition, write_concern + doc, save_condition, write_concern, upsert ) if cascade is None: @@ -505,11 +508,15 @@ def _integrate_shard_key(self, doc, select_dict): return select_dict - def _save_update(self, doc, save_condition, write_concern): + def _save_update(self, doc, save_condition, write_concern, upsert=None): """Update an existing document. Helper method, should only be used inside save(). """ + if upsert and (save_condition is not None): + raise ValueError( + "Updating with a save_condition implies upsert is False or None but upsert is True" + ) collection = self._get_collection() object_id = doc["_id"] created = False @@ -524,12 +531,13 @@ def _save_update(self, doc, save_condition, write_concern): update_doc = self._get_update_doc() if update_doc: - upsert = save_condition is None + if upsert is None: + upsert = save_condition is None with set_write_concern(collection, write_concern) as wc_collection: last_error = wc_collection.update_one( select_dict, update_doc, upsert=upsert ).raw_result - if not upsert and last_error["n"] == 0: + if save_condition is not None and last_error["n"] == 0: raise SaveConditionError( "Race condition preventing document update detected" ) diff --git a/tests/document/test_instance.py b/tests/document/test_instance.py index 1469c9bb6..60fa62fd9 100644 --- a/tests/document/test_instance.py +++ b/tests/document/test_instance.py @@ -1457,6 +1457,65 @@ def test_inserts_if_you_set_the_pk(self): assert 2 == self.Person.objects.count() + def test_save_upsert_false_doesnt_insert_when_deleted(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.save() + p2 = Person.objects().first() + p1.delete() + p2.name = " Bob Snr" + p2.save(upsert=False) + + assert Person.objects.count() == 0 + + def test_save_upsert_true_inserts_when_deleted(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.save() + p2 = Person.objects().first() + p1.delete() + p2.name = "Bob Snr" + p2.save(upsert=True) + + assert Person.objects.count() == 1 + + def test_save_upsert_null_inserts_when_deleted(self): + # probably want to remove this as this is bad but preserved for backwards compatibility + # see https://github.com/MongoEngine/mongoengine/issues/564 + class Person(Document): + name = StringField() + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.save() + p2 = Person.objects().first() + p1.delete() + p2.name = "Bob Snr" + p2.save(upsert=None) # default if you dont pass it + + assert Person.objects.count() == 1 + + def test_save_upsert_raises_value_error_when_upsert_and_save_condition_set(self): + class Person(Document): + name = StringField() + + Person.drop_collection() + + p1 = Person(name="Wilson Snr") + p1.save() + p1.name = "Bob Snr" + with pytest.raises(ValueError): + p1.save(save_condition={}, upsert=True) + def test_can_save_if_not_included(self): class EmbeddedDoc(EmbeddedDocument): pass