Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions docs/source/conflict_handling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,42 @@ Alternatively, with Django 3.1 or newer, :class:`~django:django.db.models.Q` obj
Q(name__gt=ExcludedCol('priority'))


Update values
"""""""""""""

Optionally, the fields to update can be overridden. The default is to update the same fields that were specified in the rows to insert.

Refer to the insert values using the :class:`psqlextra.expressions.ExcludedCol` expression which translates to PostgreSQL's ``EXCLUDED.<column>`` expression. All expressions and features that can be used with Django's :meth:`~django:django.db.models.query.QuerySet.update` can be used here.

.. warning::

Specifying an empty ``update_values`` (``{}``) will transform the query into :attr:`~psqlextra.types.ConflictAction.NOTHING`. Only ``None`` triggers the default behaviour of updating all fields that were specified.

.. code-block:: python

from django.db.models import F

from psqlextra.expressions import ExcludedCol

(
MyModel
.objects
.on_conflict(
['name'],
ConflictAction.UPDATE,
update_values=dict(
name=ExcludedCol('name'),
count=F('count') + 1,
),
)
.insert(
name='henk',
count=0,
)
)



ConflictAction.NOTHING
**********************

Expand Down
79 changes: 41 additions & 38 deletions psqlextra/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import sys

from collections.abc import Iterable
from typing import Tuple, Union
from typing import TYPE_CHECKING, Tuple, Union, cast

import django

Expand All @@ -12,11 +12,13 @@
from django.db.models import Expression, Model, Q
from django.db.models.fields.related import RelatedField
from django.db.models.sql import compiler as django_compiler
from django.db.utils import ProgrammingError

from .expressions import HStoreValue
from .types import ConflictAction

if TYPE_CHECKING:
from .sql import PostgresInsertQuery


def append_caller_to_sql(sql):
"""Append the caller to SQL queries.
Expand Down Expand Up @@ -104,8 +106,7 @@ def as_sql(self, *args, **kwargs):
return append_caller_to_sql(sql), params

def _prepare_query_values(self):
"""Extra prep on query values by converting dictionaries into.

"""Extra prep on query values by converting dictionaries into
:see:HStoreValue expressions.

This allows putting expressions in a dictionary. The
Expand Down Expand Up @@ -162,6 +163,8 @@ def as_sql(self, *args, **kwargs):
class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined]
"""Compiler for SQL INSERT statements."""

query: "PostgresInsertQuery"

def __init__(self, *args, **kwargs):
"""Initializes a new instance of
:see:PostgresInsertOnConflictCompiler."""
Expand All @@ -170,35 +173,14 @@ def __init__(self, *args, **kwargs):

def as_sql(self, return_id=False, *args, **kwargs):
    """Builds the SQL INSERT statement.

    Every `(sql, params)` pair produced by the parent compiler is
    passed through `_rewrite_insert` together with `return_id`.

    Returns:
        A list with the rewritten queries, in the order the
        parent compiler produced them.
    """

    rewritten_queries = []

    for sql, params in super().as_sql(*args, **kwargs):
        rewritten_queries.append(
            self._rewrite_insert(sql, params, return_id)
        )

    return rewritten_queries

def execute_sql(self, return_id=False):
    """Executes all generated queries and returns the resulting rows.

    Returns:
        A list with one dictionary per fetched row, mapping each
        column name to the value in that row. The column names
        are taken from the description of the last executed
        statement.
    """

    fetched_rows = []

    # run every generated query on a single cursor
    with self.connection.cursor() as cursor:
        for sql, params in self.as_sql(return_id):
            cursor.execute(sql, params)

            try:
                fetched_rows.extend(cursor.fetchall())
            except ProgrammingError:
                # NOTE(review): presumably raised when the statement
                # did not produce a result set -- confirm
                pass

            description = cursor.description

    # map the column names onto the values of each row; empty
    # rows end up as empty dictionaries
    results = []
    for row in fetched_rows:
        results.append(
            {
                column.name: row[position]
                for position, column in enumerate(description)
                if row
            }
        )

    return results

def _rewrite_insert(self, sql, params, return_id=False):
"""Rewrites a formed SQL INSERT query to include the ON CONFLICT
clause.
Expand All @@ -210,9 +192,9 @@ def _rewrite_insert(self, sql, params, return_id=False):
params:
The parameters passed to the query.

returning:
What to put in the `RETURNING` clause
of the resulting query.
return_id:
Whether to only return the ID or all
columns.

Returns:
A tuple of the rewritten SQL query and new params.
Expand All @@ -234,13 +216,6 @@ def _rewrite_insert_on_conflict(
"""Rewrites a normal SQL INSERT query to add the 'ON CONFLICT'
clause."""

update_columns = ", ".join(
[
"{0} = EXCLUDED.{0}".format(self.qn(field.column))
for field in self.query.update_fields # type: ignore[attr-defined]
]
)

# build the conflict target, the columns to watch
# for conflicts
on_conflict_clause = self._build_on_conflict_clause()
Expand All @@ -254,10 +229,21 @@ def _rewrite_insert_on_conflict(
rewritten_sql += f" WHERE {expr_sql}"
params += tuple(expr_params)

# Fallback in case the user didn't specify any update values. We can still
# make the query work if we switch to ConflictAction.NOTHING
if (
conflict_action == ConflictAction.UPDATE.value
and not self.query.update_values
):
conflict_action = ConflictAction.NOTHING

rewritten_sql += f" DO {conflict_action}"

if conflict_action == "UPDATE":
rewritten_sql += f" SET {update_columns}"
if conflict_action == ConflictAction.UPDATE.value:
set_sql, sql_params = self._build_set_statement()

rewritten_sql += f" SET {set_sql}"
params += sql_params

if update_condition:
expr_sql, expr_params = self._compile_expression(
Expand All @@ -270,6 +256,23 @@ def _rewrite_insert_on_conflict(

return (rewritten_sql, params)

def _build_set_statement(self) -> Tuple[str, tuple]:
    """Builds the SET statement for the ON CONFLICT DO UPDATE clause.

    This uses the update compiler to provide full compatibility with
    the standard Django's `update(...)`.

    Returns:
        A tuple of the assignments that follow the `SET` keyword
        in the compiled UPDATE statement and the parameters of
        that statement.
    """

    # Local import to work around the circular dependency between
    # the compiler and the queries.
    from .sql import PostgresUpdateQuery

    query = cast(PostgresUpdateQuery, self.query.chain(PostgresUpdateQuery))
    query.add_update_values(self.query.update_values)

    sql, params = query.get_compiler(self.connection.alias).as_sql()

    # Cut only once and only on the stand-alone keyword. Splitting on
    # every occurrence of the bare letters "SET" truncates the
    # assignments as soon as an identifier contains them, for
    # example a column named "ASSET_ID".
    set_sql = sql.split(" SET ", 1)[1]

    # NOTE(review): cutting at the first " WHERE" would also truncate
    # an assignment whose expression contains a WHERE of its own,
    # such as a subquery -- confirm whether that can occur here.
    return set_sql.split(" WHERE")[0], tuple(params)

def _build_on_conflict_clause(self):
if django.VERSION >= (2, 2):
from django.db.models.constraints import BaseConstraint
Expand Down
8 changes: 8 additions & 0 deletions psqlextra/introspect/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .fields import inspect_model_local_concrete_fields
from .models import model_from_cursor, models_from_cursor

__all__ = [
"models_from_cursor",
"model_from_cursor",
"inspect_model_local_concrete_fields",
]
21 changes: 21 additions & 0 deletions psqlextra/introspect/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import List, Type

from django.db.models import Field, Model


def inspect_model_local_concrete_fields(model: Type[Model]) -> List[Field]:
    """Gets a complete list of local and concrete fields on a model.

    These are the fields that map directly to a database column on
    the table backing the model.

    This is similar to Django's `Meta.local_concrete_fields`, which is a
    private API. This function utilizes only public APIs.
    """

    return [
        candidate
        for candidate in model._meta.get_fields(include_parents=False)
        if isinstance(candidate, Field)
        and candidate.column
        and not candidate.many_to_many
    ]
170 changes: 170 additions & 0 deletions psqlextra/introspect/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from typing import (
Any,
Dict,
Generator,
Iterable,
List,
Optional,
Type,
TypeVar,
cast,
)

from django.core.exceptions import FieldDoesNotExist
from django.db import connection, models
from django.db.models import Field, Model
from django.db.models.expressions import Expression

from .fields import inspect_model_local_concrete_fields

TModel = TypeVar("TModel", bound=models.Model)


def _construct_model(
    model: Type[TModel],
    columns: Iterable[str],
    values: Iterable[Any],
    *,
    apply_converters: bool = True
) -> TModel:
    """Constructs an instance of a model from raw column values.

    Arguments:
        model:
            Model to construct.

        columns:
            Names identifying the field for each value. A name is
            first looked up through `model._meta.get_field` and
            then among the attribute names and explicit `db_column`
            names of the local concrete fields.

        values:
            Raw values, positionally matching `columns`.

        apply_converters:
            Whether to pass each value through the converters of
            its field and of the database backend before
            assigning it.

    Returns:
        An instance that is marked as not being added and that
        is bound to the alias of `connection`.
    """

    # fallback lookup for names that `get_field` does not know
    field_lookup = {}
    for local_field in inspect_model_local_concrete_fields(model):
        field_lookup[local_field.attname] = local_field

        if local_field.db_column:
            field_lookup[local_field.db_column] = local_field

    column_names = list(columns)
    attributes = {}

    for position, raw_value in enumerate(values):
        column_name = column_names[position]

        try:
            field = cast(Field, model._meta.get_field(column_name))
        except FieldDoesNotExist:
            field = field_lookup[column_name]

        column_expression = field.get_col(model._meta.db_table)

        final_value = raw_value
        if apply_converters:
            field_converters = cast(Expression, field).get_db_converters(
                connection
            )
            backend_converters = connection.ops.get_db_converters(
                column_expression
            )

            for convert in field_converters + backend_converters:
                final_value = convert(
                    final_value, column_expression, connection
                )

        attributes[field.attname] = final_value

    instance = model(**attributes)
    instance._state.adding = False
    instance._state.db = connection.alias

    return instance


def models_from_cursor(
    model: Type[TModel],
    cursor,
    *,
    related_fields: Optional[List[str]] = None
) -> Generator[TModel, None, None]:
    """Fetches all rows from a cursor and converts the values into model
    instances.

    This is roughly what Django does internally when you do queries. This
    goes further than `Model.from_db` as it also applies converters to make
    sure that values are converted into their Python equivalent.

    Use this when you've outgrown the ORM and you are writing performant
    queries yourself and you need to map the results back into ORM objects.

    Arguments:
        model:
            Model to construct.

        cursor:
            Cursor to read the rows from.

        related_fields:
            List of ForeignKey/OneToOneField names that were joined
            into the raw query. Use this to achieve the same thing
            that Django's `.select_related()` does.

            Field names should be specified in the order that they
            are SELECT'd in.

    Yields:
        One model instance per row read from the cursor.
    """

    if related_fields is None:
        related_fields = []

    columns = [col[0] for col in cursor.description]
    model_field_count = len(inspect_model_local_concrete_fields(model))

    rows = cursor.fetchmany()

    while rows:
        for values in rows:
            instance = _construct_model(
                model,
                columns[:model_field_count],
                values[:model_field_count],
            )

            # The columns of the related models start right after the
            # columns of the model itself. Restart from there for every
            # row; carrying the offset over from the previous row would
            # slice every row but the first at the wrong position.
            field_offset = model_field_count

            for related_field_name in related_fields:
                related_model = model._meta.get_field(
                    related_field_name
                ).related_model
                if not related_model:
                    continue

                related_field_count = len(
                    inspect_model_local_concrete_fields(related_model)
                )

                start = field_offset
                end = field_offset + related_field_count

                # Advance before anything can skip this related model
                # so that the next one is read from its own columns
                # even when this one turns out to be empty.
                field_offset = end

                related_columns = columns[start:end]
                related_values = values[start:end]

                # Nothing was joined for this row, leave the relation
                # unpopulated.
                if (
                    not related_columns
                    or not related_values
                    or all(value is None for value in related_values)
                ):
                    continue

                related_instance = _construct_model(
                    cast(Type[Model], related_model),
                    related_columns,
                    related_values,
                )
                instance._state.fields_cache[related_field_name] = related_instance  # type: ignore

            yield instance

        rows = cursor.fetchmany()


def model_from_cursor(
    model: Type[TModel],
    cursor,
    *,
    related_fields: Optional[List[str]] = None
) -> Optional[TModel]:
    """Converts the first row of a cursor into a model instance.

    Arguments:
        model:
            Model to construct.

        cursor:
            Cursor to read the row from.

        related_fields:
            List of field names that is passed on to
            `models_from_cursor`. Defaults to no related fields.

    Returns:
        The first instance produced by `models_from_cursor` or
        `None` when it does not produce any.
    """

    # Always hand over a list so the default is never shared
    # between calls.
    return next(
        models_from_cursor(
            model, cursor, related_fields=related_fields or []
        ),
        None,
    )


def model_from_dict(
    model: Type[TModel], row: Dict[str, Any], *, apply_converters: bool = True
) -> TModel:
    """Constructs an instance of a model from a dictionary.

    Arguments:
        model:
            Model to construct.

        row:
            Dictionary of which the keys identify the fields and
            the values are assigned to them.

        apply_converters:
            Passed on as-is to `_construct_model`.

    Returns:
        The instance built by `_construct_model`.
    """

    column_names = row.keys()
    column_values = row.values()

    return _construct_model(
        model,
        column_names,
        column_values,
        apply_converters=apply_converters,
    )
Loading