diff --git a/psqlextra/query.py b/psqlextra/query.py index 2d24b5a..65a20c5 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -327,7 +327,9 @@ def upsert( self.on_conflict( conflict_target, - ConflictAction.UPDATE, + ConflictAction.UPDATE + if (update_condition or update_condition is None) + else ConflictAction.NOTHING, index_predicate=index_predicate, update_condition=update_condition, update_values=update_values, diff --git a/tests/test_upsert.py b/tests/test_upsert.py index 3aa6207..a9e567b 100644 --- a/tests/test_upsert.py +++ b/tests/test_upsert.py @@ -4,6 +4,7 @@ from django.db import connection, models from django.db.models import F, Q from django.db.models.expressions import CombinedExpression, Value +from django.test.utils import CaptureQueriesContext from psqlextra.expressions import ExcludedCol from psqlextra.fields import HStoreField @@ -144,6 +145,35 @@ def test_upsert_with_update_condition(): assert obj1.active +@pytest.mark.parametrize("update_condition_value", [0, False]) +def test_upsert_with_update_condition_false(update_condition_value): + """Tests that an expression can be used as an upsert update condition.""" + + model = get_fake_model( + { + "name": models.TextField(unique=True), + "priority": models.IntegerField(), + "active": models.BooleanField(), + } + ) + + obj1 = model.objects.create(name="joe", priority=1, active=False) + + with CaptureQueriesContext(connection) as ctx: + upsert_result = model.objects.upsert( + conflict_target=["name"], + update_condition=update_condition_value, + fields=dict(name="joe", priority=2, active=True), + ) + assert upsert_result is None + assert len(ctx) == 1 + assert 'ON CONFLICT ("name") DO NOTHING' in ctx[0]["sql"] + + obj1.refresh_from_db() + assert obj1.priority == 1 + assert not obj1.active + + def test_upsert_with_update_values(): """Tests that the default update values can be overriden with custom expressions."""