22from itertools import chain
33from typing import (
44 TYPE_CHECKING ,
5+ Any ,
56 Dict ,
67 Generic ,
78 Iterable ,
1718from django .db .models import Expression , Q , QuerySet
1819from django .db .models .fields import NOT_PROVIDED
1920
21+ from .expressions import ExcludedCol
2022from .sql import PostgresInsertQuery , PostgresQuery
2123from .types import ConflictAction
2224
@@ -51,6 +53,7 @@ def __init__(self, model=None, query=None, using=None, hints=None):
5153 self .conflict_action = None
5254 self .conflict_update_condition = None
5355 self .index_predicate = None
56+ self .update_values = None
5457
5558 def annotate (self , ** annotations ) -> "Self" : # type: ignore[valid-type, override]
5659 """Custom version of the standard annotate function that allows using
@@ -108,6 +111,7 @@ def on_conflict(
108111 action : ConflictAction ,
109112 index_predicate : Optional [Union [Expression , Q , str ]] = None ,
110113 update_condition : Optional [Union [Expression , Q , str ]] = None ,
114+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
111115 ):
112116 """Sets the action to take when conflicts arise when attempting to
113117 insert/create a new row.
@@ -125,12 +129,18 @@ def on_conflict(
125129
126130 update_condition:
127131 Only update if this SQL expression evaluates to true.
132+
133+ update_values:
134+ Optionally, values/expressions to use when rows
135+ conflict. If not specified, all columns specified
136+ in the rows are updated with the values you specified.
128137 """
129138
130139 self .conflict_target = fields
131140 self .conflict_action = action
132141 self .conflict_update_condition = update_condition
133142 self .index_predicate = index_predicate
143+ self .update_values = update_values
134144
135145 return self
136146
@@ -293,6 +303,7 @@ def upsert(
293303 index_predicate : Optional [Union [Expression , Q , str ]] = None ,
294304 using : Optional [str ] = None ,
295305 update_condition : Optional [Union [Expression , Q , str ]] = None ,
306+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
296307 ) -> int :
297308 """Creates a new record or updates the existing one with the specified
298309 data.
@@ -315,6 +326,11 @@ def upsert(
315326 update_condition:
316327 Only update if this SQL expression evaluates to true.
317328
329+ update_values:
330+ Optionally, values/expressions to use when rows
331+ conflict. If not specified, all columns specified
332+ in the rows are updated with the values you specified.
333+
318334 Returns:
319335 The primary key of the row that was created/updated.
320336 """
@@ -324,6 +340,7 @@ def upsert(
324340 ConflictAction .UPDATE ,
325341 index_predicate = index_predicate ,
326342 update_condition = update_condition ,
343+ update_values = update_values ,
327344 )
328345
329346 kwargs = {** fields , "using" : using }
@@ -336,6 +353,7 @@ def upsert_and_get(
336353 index_predicate : Optional [Union [Expression , Q , str ]] = None ,
337354 using : Optional [str ] = None ,
338355 update_condition : Optional [Union [Expression , Q , str ]] = None ,
356+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
339357 ):
340358 """Creates a new record or updates the existing one with the specified
341359 data and then gets the row.
@@ -358,6 +376,11 @@ def upsert_and_get(
358376 update_condition:
359377 Only update if this SQL expression evaluates to true.
360378
379+ update_values:
380+ Optionally, values/expressions to use when rows
381+ conflict. If not specified, all columns specified
382+ in the rows are updated with the values you specified.
383+
361384 Returns:
362385 The model instance representing the row
363386 that was created/updated.
@@ -368,6 +391,7 @@ def upsert_and_get(
368391 ConflictAction .UPDATE ,
369392 index_predicate = index_predicate ,
370393 update_condition = update_condition ,
394+ update_values = update_values ,
371395 )
372396
373397 kwargs = {** fields , "using" : using }
@@ -381,6 +405,7 @@ def bulk_upsert(
381405 return_model : bool = False ,
382406 using : Optional [str ] = None ,
383407 update_condition : Optional [Union [Expression , Q , str ]] = None ,
408+ update_values : Optional [Dict [str , Union [Any , Expression ]]] = None ,
384409 ):
385410 """Creates a set of new records or updates the existing ones with the
386411 specified data.
@@ -407,6 +432,11 @@ def bulk_upsert(
407432 update_condition:
408433 Only update if this SQL expression evaluates to true.
409434
435+ update_values:
436+ Optionally, values/expressions to use when rows
437+ conflict. If not specified, all columns specified
438+ in the rows are updated with the values you specified.
439+
410440 Returns:
411441 A list of either the dicts of the rows upserted, including the pk or
412442 the models of the rows upserted
@@ -417,7 +447,9 @@ def bulk_upsert(
417447 ConflictAction .UPDATE ,
418448 index_predicate = index_predicate ,
419449 update_condition = update_condition ,
450+ update_values = update_values ,
420451 )
452+
421453 return self .bulk_insert (rows , return_model , using = using )
422454
423455 def _create_model_instance (
@@ -505,15 +537,19 @@ def _build_insert_compiler(
505537 )
506538
507539 # get the fields to be used during update/insert
508- insert_fields , update_fields = self ._get_upsert_fields (first_row )
540+ insert_fields , update_values = self ._get_upsert_fields (first_row )
541+
542+ # allow the user to override what should happen on update
543+ if self .update_values is not None :
544+ update_values = self .update_values
509545
510546 # build a normal insert query
511547 query = PostgresInsertQuery (self .model )
512548 query .conflict_action = self .conflict_action
513549 query .conflict_target = self .conflict_target
514550 query .conflict_update_condition = self .conflict_update_condition
515551 query .index_predicate = self .index_predicate
516- query .values (objs , insert_fields , update_fields )
552+ query .insert_on_conflict_values (objs , insert_fields , update_values )
517553
518554 compiler = query .get_compiler (using )
519555 return compiler
@@ -578,13 +614,13 @@ def _get_upsert_fields(self, kwargs):
578614
579615 model_instance = self .model (** kwargs )
580616 insert_fields = []
581- update_fields = []
617+ update_values = {}
582618
583619 for field in model_instance ._meta .local_concrete_fields :
584620 has_default = field .default != NOT_PROVIDED
585621 if field .name in kwargs or field .column in kwargs :
586622 insert_fields .append (field )
587- update_fields . append (field )
623+ update_values [ field . name ] = ExcludedCol (field . column )
588624 continue
589625 elif has_default :
590626 insert_fields .append (field )
@@ -595,13 +631,13 @@ def _get_upsert_fields(self, kwargs):
595631 # instead of a concrete field, we have to handle that
596632 if field .primary_key is True and "pk" in kwargs :
597633 insert_fields .append (field )
598- update_fields . append (field )
634+ update_values [ field . name ] = ExcludedCol (field . column )
599635 continue
600636
601637 if self ._is_magical_field (model_instance , field , is_insert = True ):
602638 insert_fields .append (field )
603639
604640 if self ._is_magical_field (model_instance , field , is_insert = False ):
605- update_fields . append (field )
641+ update_values [ field . name ] = ExcludedCol (field . column )
606642
607- return insert_fields , update_fields
643+ return insert_fields , update_values
0 commit comments