diff --git a/docs/source/conflict_handling.rst b/docs/source/conflict_handling.rst
index 89d1a0c..cb9423a 100644
--- a/docs/source/conflict_handling.rst
+++ b/docs/source/conflict_handling.rst
@@ -232,6 +232,42 @@ Alternatively, with Django 3.1 or newer, :class:`~django:django.db.models.Q` obj
 
     Q(name__gt=ExcludedCol('priority'))
 
+Update values
+"""""""""""""
+
+Optionally, the fields to update can be overridden. The default is to update the same fields that were specified in the rows to insert.
+
+Refer to the insert values using the :class:`psqlextra.expressions.ExcludedCol` expression, which translates to PostgreSQL's ``EXCLUDED.`` expression. All expressions and features that can be used with Django's :meth:`~django:django.db.models.query.QuerySet.update` can be used here.
+
+.. warning::
+
+    Specifying an empty ``update_values`` (``{}``) will transform the query into :attr:`~psqlextra.types.ConflictAction.NOTHING`. Only ``None`` retains the default behaviour of updating all fields that were specified in the rows.
+
+.. code-block:: python
+
+    from django.db.models import F
+
+    from psqlextra.expressions import ExcludedCol
+
+    (
+        MyModel
+        .objects
+        .on_conflict(
+            ['name'],
+            ConflictAction.UPDATE,
+            update_values=dict(
+                name=ExcludedCol('name'),
+                count=F('count') + 1,
+            ),
+        )
+        .insert(
+            name='henk',
+            count=0,
+        )
+    )
+
+
+
 ConflictAction.NOTHING
 **********************
diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py
index 88a65e9..36aad20 100644
--- a/psqlextra/compiler.py
+++ b/psqlextra/compiler.py
@@ -3,7 +3,7 @@
 import sys
 
 from collections.abc import Iterable
-from typing import Tuple, Union
+from typing import TYPE_CHECKING, Tuple, Union, cast
 
 import django
 
@@ -12,11 +12,13 @@ from django.db.models import Expression, Model, Q
 from django.db.models.fields.related import RelatedField
 from django.db.models.sql import compiler as django_compiler
-from django.db.utils import ProgrammingError
 
 from .expressions import HStoreValue
 from .types import ConflictAction
 
+if TYPE_CHECKING:
+    from .sql import PostgresInsertQuery
+
 
 def append_caller_to_sql(sql):
     """Append the caller to SQL queries.
@@ -104,8 +106,7 @@ def as_sql(self, *args, **kwargs):
         return append_caller_to_sql(sql), params
 
     def _prepare_query_values(self):
-        """Extra prep on query values by converting dictionaries into.
-
+        """Extra prep on query values by converting dictionaries into
         :see:HStoreValue expressions.
 
         This allows putting expressions in a dictionary.
The @@ -162,6 +163,8 @@ def as_sql(self, *args, **kwargs): class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined] """Compiler for SQL INSERT statements.""" + query: "PostgresInsertQuery" + def __init__(self, *args, **kwargs): """Initializes a new instance of :see:PostgresInsertOnConflictCompiler.""" @@ -170,6 +173,7 @@ def __init__(self, *args, **kwargs): def as_sql(self, return_id=False, *args, **kwargs): """Builds the SQL INSERT statement.""" + queries = [ self._rewrite_insert(sql, params, return_id) for sql, params in super().as_sql(*args, **kwargs) @@ -177,28 +181,6 @@ def as_sql(self, return_id=False, *args, **kwargs): return queries - def execute_sql(self, return_id=False): - # execute all the generate queries - with self.connection.cursor() as cursor: - rows = [] - for sql, params in self.as_sql(return_id): - cursor.execute(sql, params) - try: - rows.extend(cursor.fetchall()) - except ProgrammingError: - pass - description = cursor.description - - # create a mapping between column names and column value - return [ - { - column.name: row[column_index] - for column_index, column in enumerate(description) - if row - } - for row in rows - ] - def _rewrite_insert(self, sql, params, return_id=False): """Rewrites a formed SQL INSERT query to include the ON CONFLICT clause. @@ -210,9 +192,9 @@ def _rewrite_insert(self, sql, params, return_id=False): params: The parameters passed to the query. - returning: - What to put in the `RETURNING` clause - of the resulting query. + return_id: + Whether to only return the ID or all + columns. Returns: A tuple of the rewritten SQL query and new params. @@ -234,13 +216,6 @@ def _rewrite_insert_on_conflict( """Rewrites a normal SQL INSERT query to add the 'ON CONFLICT' clause.""" - update_columns = ", ".join( - [ - "{0} = EXCLUDED.{0}".format(self.qn(field.column)) - for field in self.query.update_fields # type: ignore[attr-defined] - ] - ) - # build the conflict target, the columns to watch # for conflicts on_conflict_clause = self._build_on_conflict_clause() @@ -254,10 +229,21 @@ def _rewrite_insert_on_conflict( rewritten_sql += f" WHERE {expr_sql}" params += tuple(expr_params) + # Fallback in case the user didn't specify any update values. We can still + # make the query work if we switch to ConflictAction.NOTHING + if ( + conflict_action == ConflictAction.UPDATE.value + and not self.query.update_values + ): + conflict_action = ConflictAction.NOTHING + rewritten_sql += f" DO {conflict_action}" - if conflict_action == "UPDATE": - rewritten_sql += f" SET {update_columns}" + if conflict_action == ConflictAction.UPDATE.value: + set_sql, sql_params = self._build_set_statement() + + rewritten_sql += f" SET {set_sql}" + params += sql_params if update_condition: expr_sql, expr_params = self._compile_expression( @@ -270,6 +256,23 @@ def _rewrite_insert_on_conflict( return (rewritten_sql, params) + def _build_set_statement(self) -> Tuple[str, tuple]: + """Builds the SET statement for the ON CONFLICT DO UPDATE clause. + + This uses the update compiler to provide full compatibility with + the standard Django's `update(...)`. + """ + + # Local import to work around the circular dependency between + # the compiler and the queries. 
+ from .sql import PostgresUpdateQuery + + query = cast(PostgresUpdateQuery, self.query.chain(PostgresUpdateQuery)) + query.add_update_values(self.query.update_values) + + sql, params = query.get_compiler(self.connection.alias).as_sql() + return sql.split("SET")[1].split(" WHERE")[0], tuple(params) + def _build_on_conflict_clause(self): if django.VERSION >= (2, 2): from django.db.models.constraints import BaseConstraint diff --git a/psqlextra/introspect/__init__.py b/psqlextra/introspect/__init__.py new file mode 100644 index 0000000..bd85935 --- /dev/null +++ b/psqlextra/introspect/__init__.py @@ -0,0 +1,8 @@ +from .fields import inspect_model_local_concrete_fields +from .models import model_from_cursor, models_from_cursor + +__all__ = [ + "models_from_cursor", + "model_from_cursor", + "inspect_model_local_concrete_fields", +] diff --git a/psqlextra/introspect/fields.py b/psqlextra/introspect/fields.py new file mode 100644 index 0000000..27ef28f --- /dev/null +++ b/psqlextra/introspect/fields.py @@ -0,0 +1,21 @@ +from typing import List, Type + +from django.db.models import Field, Model + + +def inspect_model_local_concrete_fields(model: Type[Model]) -> List[Field]: + """Gets a complete list of local and concrete fields on a model, these are + fields that directly map to a database colmn directly on the table backing + the model. + + This is similar to Django's `Meta.local_concrete_fields`, which is a + private API. This method utilizes only public APIs. + """ + + local_concrete_fields = [] + + for field in model._meta.get_fields(include_parents=False): + if isinstance(field, Field) and field.column and not field.many_to_many: + local_concrete_fields.append(field) + + return local_concrete_fields diff --git a/psqlextra/introspect/models.py b/psqlextra/introspect/models.py new file mode 100644 index 0000000..cad84e9 --- /dev/null +++ b/psqlextra/introspect/models.py @@ -0,0 +1,170 @@ +from typing import ( + Any, + Dict, + Generator, + Iterable, + List, + Optional, + Type, + TypeVar, + cast, +) + +from django.core.exceptions import FieldDoesNotExist +from django.db import connection, models +from django.db.models import Field, Model +from django.db.models.expressions import Expression + +from .fields import inspect_model_local_concrete_fields + +TModel = TypeVar("TModel", bound=models.Model) + + +def _construct_model( + model: Type[TModel], + columns: Iterable[str], + values: Iterable[Any], + *, + apply_converters: bool = True +) -> TModel: + fields_by_name_and_column = {} + for field in inspect_model_local_concrete_fields(model): + fields_by_name_and_column[field.attname] = field + + if field.db_column: + fields_by_name_and_column[field.db_column] = field + + indexable_columns = list(columns) + + row = {} + + for index, value in enumerate(values): + column = indexable_columns[index] + try: + field = cast(Field, model._meta.get_field(column)) + except FieldDoesNotExist: + field = fields_by_name_and_column[column] + + field_column_expression = field.get_col(model._meta.db_table) + + if apply_converters: + converters = cast(Expression, field).get_db_converters( + connection + ) + connection.ops.get_db_converters(field_column_expression) + + converted_value = value + for converter in converters: + converted_value = converter( + converted_value, + field_column_expression, + connection, + ) + else: + converted_value = value + + row[field.attname] = converted_value + + instance = model(**row) + instance._state.adding = False + instance._state.db = connection.alias + + return instance + + +def 
models_from_cursor( + model: Type[TModel], cursor, *, related_fields: List[str] = [] +) -> Generator[TModel, None, None]: + """Fetches all rows from a cursor and converts the values into model + instances. + + This is roughly what Django does internally when you do queries. This + goes further than `Model.from_db` as it also applies converters to make + sure that values are converted into their Python equivalent. + + Use this when you've outgrown the ORM and you are writing performant + queries yourself and you need to map the results back into ORM objects. + + Arguments: + model: + Model to construct. + + cursor: + Cursor to read the rows from. + + related_fields: + List of ForeignKey/OneToOneField names that were joined + into the raw query. Use this to achieve the same thing + that Django's `.select_related()` does. + + Field names should be specified in the order that they + are SELECT'd in. + """ + + columns = [col[0] for col in cursor.description] + field_offset = len(inspect_model_local_concrete_fields(model)) + + rows = cursor.fetchmany() + + while rows: + for values in rows: + instance = _construct_model( + model, columns[:field_offset], values[:field_offset] + ) + + for index, related_field_name in enumerate(related_fields): + related_model = model._meta.get_field( + related_field_name + ).related_model + if not related_model: + continue + + related_field_count = len( + inspect_model_local_concrete_fields(related_model) + ) + + # autopep8: off + related_columns = columns[ + field_offset : field_offset + related_field_count # noqa + ] + related_values = values[ + field_offset : field_offset + related_field_count # noqa + ] + # autopep8: one + + if ( + not related_columns + or not related_values + or all([value is None for value in related_values]) + ): + continue + + related_instance = _construct_model( + cast(Type[Model], related_model), + related_columns, + related_values, + ) + instance._state.fields_cache[related_field_name] = related_instance # type: ignore + + field_offset += len( + inspect_model_local_concrete_fields(related_model) + ) + + yield instance + + rows = cursor.fetchmany() + + +def model_from_cursor( + model: Type[TModel], cursor, *, related_fields: List[str] = [] +) -> Optional[TModel]: + return next( + models_from_cursor(model, cursor, related_fields=related_fields), None + ) + + +def model_from_dict( + model: Type[TModel], row: Dict[str, Any], *, apply_converters: bool = True +) -> TModel: + return _construct_model( + model, row.keys(), row.values(), apply_converters=apply_converters + ) diff --git a/psqlextra/query.py b/psqlextra/query.py index b3feec1..a014d9a 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -2,6 +2,7 @@ from itertools import chain from typing import ( TYPE_CHECKING, + Any, Dict, Generic, Iterable, @@ -13,10 +14,13 @@ ) from django.core.exceptions import SuspiciousOperation -from django.db import connections, models, router +from django.db import models, router +from django.db.backends.utils import CursorWrapper from django.db.models import Expression, Q, QuerySet from django.db.models.fields import NOT_PROVIDED +from .expressions import ExcludedCol +from .introspect import model_from_cursor, models_from_cursor from .sql import PostgresInsertQuery, PostgresQuery from .types import ConflictAction @@ -51,6 +55,7 @@ def __init__(self, model=None, query=None, using=None, hints=None): self.conflict_action = None self.conflict_update_condition = None self.index_predicate = None + self.update_values = None def annotate(self, 
**annotations) -> "Self": # type: ignore[valid-type, override] """Custom version of the standard annotate function that allows using @@ -108,6 +113,7 @@ def on_conflict( action: ConflictAction, index_predicate: Optional[Union[Expression, Q, str]] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Sets the action to take when conflicts arise when attempting to insert/create a new row. @@ -125,18 +131,24 @@ def on_conflict( update_condition: Only update if this SQL expression evaluates to true. + + update_values: + Optionally, values/expressions to use when rows + conflict. If not specified, all columns specified + in the rows are updated with the values you specified. """ self.conflict_target = fields self.conflict_action = action self.conflict_update_condition = update_condition self.index_predicate = index_predicate + self.update_values = update_values return self def bulk_insert( self, - rows: Iterable[dict], + rows: Iterable[Dict[str, Any]], return_model: bool = False, using: Optional[str] = None, ): @@ -195,14 +207,17 @@ def is_empty(r): deduped_rows.append(row) compiler = self._build_insert_compiler(deduped_rows, using=using) - objs = compiler.execute_sql(return_id=not return_model) - if return_model: - return [ - self._create_model_instance(dict(row, **obj), compiler.using) - for row, obj in zip(deduped_rows, objs) - ] - return [dict(row, **obj) for row, obj in zip(deduped_rows, objs)] + with compiler.connection.cursor() as cursor: + for sql, params in compiler.as_sql(return_id=not return_model): + cursor.execute(sql, params) + + if return_model: + return list(models_from_cursor(self.model, cursor)) + + return self._consume_cursor_as_dicts( + cursor, original_rows=deduped_rows + ) def insert(self, using: Optional[str] = None, **fields): """Creates a new record in the database. @@ -223,17 +238,20 @@ def insert(self, using: Optional[str] = None, **fields): """ if self.conflict_target or self.conflict_action: - compiler = self._build_insert_compiler([fields], using=using) - rows = compiler.execute_sql(return_id=True) - if not self.model or not self.model.pk: return None - _, pk_db_column = self.model._meta.pk.get_attname_column() # type: ignore[union-attr] - if not rows or len(rows) == 0: - return None + compiler = self._build_insert_compiler([fields], using=using) - return rows[0][pk_db_column] + with compiler.connection.cursor() as cursor: + for sql, params in compiler.as_sql(return_id=True): + cursor.execute(sql, params) + + row = cursor.fetchone() + if not row: + return None + + return row[0] # no special action required, use the standard Django create(..) 
return super().create(**fields).pk @@ -261,30 +279,12 @@ def insert_and_get(self, using: Optional[str] = None, **fields): return super().create(**fields) compiler = self._build_insert_compiler([fields], using=using) - rows = compiler.execute_sql(return_id=False) - - if not rows: - return None - - columns = rows[0] - # get a list of columns that are officially part of the model and - # preserve the fact that the attribute name - # might be different than the database column name - model_columns = {} - for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] - model_columns[field.column] = field.attname + with compiler.connection.cursor() as cursor: + for sql, params in compiler.as_sql(return_id=False): + cursor.execute(sql, params) - # strip out any columns/fields returned by the db that - # are not present in the model - model_init_fields = {} - for column_name, column_value in columns.items(): - try: - model_init_fields[model_columns[column_name]] = column_value - except KeyError: - pass - - return self._create_model_instance(model_init_fields, compiler.using) + return model_from_cursor(self.model, cursor) def upsert( self, @@ -293,6 +293,7 @@ def upsert( index_predicate: Optional[Union[Expression, Q, str]] = None, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ) -> int: """Creates a new record or updates the existing one with the specified data. @@ -315,6 +316,11 @@ def upsert( update_condition: Only update if this SQL expression evaluates to true. + update_values: + Optionally, values/expressions to use when rows + conflict. If not specified, all columns specified + in the rows are updated with the values you specified. + Returns: The primary key of the row that was created/updated. """ @@ -324,6 +330,7 @@ def upsert( ConflictAction.UPDATE, index_predicate=index_predicate, update_condition=update_condition, + update_values=update_values, ) kwargs = {**fields, "using": using} @@ -336,6 +343,7 @@ def upsert_and_get( index_predicate: Optional[Union[Expression, Q, str]] = None, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Creates a new record or updates the existing one with the specified data and then gets the row. @@ -358,6 +366,11 @@ def upsert_and_get( update_condition: Only update if this SQL expression evaluates to true. + update_values: + Optionally, values/expressions to use when rows + conflict. If not specified, all columns specified + in the rows are updated with the values you specified. + Returns: The model instance representing the row that was created/updated. @@ -368,6 +381,7 @@ def upsert_and_get( ConflictAction.UPDATE, index_predicate=index_predicate, update_condition=update_condition, + update_values=update_values, ) kwargs = {**fields, "using": using} @@ -381,6 +395,7 @@ def bulk_upsert( return_model: bool = False, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Creates a set of new records or updates the existing ones with the specified data. @@ -407,6 +422,11 @@ def bulk_upsert( update_condition: Only update if this SQL expression evaluates to true. + update_values: + Optionally, values/expressions to use when rows + conflict. 
If not specified, all columns specified + in the rows are updated with the values you specified. + Returns: A list of either the dicts of the rows upserted, including the pk or the models of the rows upserted @@ -417,46 +437,28 @@ def bulk_upsert( ConflictAction.UPDATE, index_predicate=index_predicate, update_condition=update_condition, + update_values=update_values, ) - return self.bulk_insert(rows, return_model, using=using) - - def _create_model_instance( - self, field_values: dict, using: str, apply_converters: bool = True - ): - """Creates a new instance of the model with the specified field. - - Use this after the row was inserted into the database. The new - instance will marked as "saved". - """ - - converted_field_values = field_values.copy() - - if apply_converters: - connection = connections[using] - for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] - if field.attname not in converted_field_values: - continue - - # converters can be defined on the field, or by - # the database back-end we're using - field_column = field.get_col(self.model._meta.db_table) - converters = field.get_db_converters( - connection - ) + connection.ops.get_db_converters(field_column) - - for converter in converters: - converted_field_values[field.attname] = converter( - converted_field_values[field.attname], - field_column, - connection, - ) - - instance = self.model(**converted_field_values) - instance._state.db = using - instance._state.adding = False + return self.bulk_insert(rows, return_model, using=using) - return instance + @staticmethod + def _consume_cursor_as_dicts( + cursor: CursorWrapper, *, original_rows: Iterable[Dict[str, Any]] + ) -> List[dict]: + cursor_description = cursor.description + + return [ + { + **original_row, + **{ + column.name: row[column_index] + for column_index, column in enumerate(cursor_description) + if row + }, + } + for original_row, row in zip(original_rows, cursor) + ] def _build_insert_compiler( self, rows: Iterable[Dict], using: Optional[str] = None @@ -500,12 +502,17 @@ def _build_insert_compiler( ).format(index) ) - objs.append( - self._create_model_instance(row, using, apply_converters=False) - ) + obj = self.model(**row.copy()) + obj._state.db = using + obj._state.adding = False + objs.append(obj) # get the fields to be used during update/insert - insert_fields, update_fields = self._get_upsert_fields(first_row) + insert_fields, update_values = self._get_upsert_fields(first_row) + + # allow the user to override what should happen on update + if self.update_values is not None: + update_values = self.update_values # build a normal insert query query = PostgresInsertQuery(self.model) @@ -513,7 +520,7 @@ def _build_insert_compiler( query.conflict_target = self.conflict_target query.conflict_update_condition = self.conflict_update_condition query.index_predicate = self.index_predicate - query.values(objs, insert_fields, update_fields) + query.insert_on_conflict_values(objs, insert_fields, update_values) compiler = query.get_compiler(using) return compiler @@ -578,13 +585,13 @@ def _get_upsert_fields(self, kwargs): model_instance = self.model(**kwargs) insert_fields = [] - update_fields = [] + update_values = {} for field in model_instance._meta.local_concrete_fields: has_default = field.default != NOT_PROVIDED if field.name in kwargs or field.column in kwargs: insert_fields.append(field) - update_fields.append(field) + update_values[field.name] = ExcludedCol(field.column) continue elif has_default: insert_fields.append(field) @@ 
-595,13 +602,13 @@ def _get_upsert_fields(self, kwargs): # instead of a concrete field, we have to handle that if field.primary_key is True and "pk" in kwargs: insert_fields.append(field) - update_fields.append(field) + update_values[field.name] = ExcludedCol(field.column) continue if self._is_magical_field(model_instance, field, is_insert=True): insert_fields.append(field) if self._is_magical_field(model_instance, field, is_insert=False): - update_fields.append(field) + update_values[field.name] = ExcludedCol(field.column) - return insert_fields, update_fields + return insert_fields, update_values diff --git a/psqlextra/sql.py b/psqlextra/sql.py index 2a5b418..3ceb596 100644 --- a/psqlextra/sql.py +++ b/psqlextra/sql.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import django @@ -154,10 +154,14 @@ def __init__(self, *args, **kwargs): self.conflict_action = ConflictAction.UPDATE self.conflict_update_condition = None self.index_predicate = None - - self.update_fields = [] - - def values(self, objs, insert_fields, update_fields=[]): + self.update_values = {} + + def insert_on_conflict_values( + self, + objs: List, + insert_fields: List, + update_values: Dict[str, Union[Any, Expression]] = {}, + ): """Sets the values to be used in this query. Insert fields are fields that are definitely @@ -176,12 +180,13 @@ def values(self, objs, insert_fields, update_fields=[]): insert_fields: The fields to use in the INSERT statement - update_fields: - The fields to only use in the UPDATE statement. + update_values: + Expressions/values to use when a conflict + occurs and an UPDATE is performed. """ self.insert_values(insert_fields, objs, raw=False) - self.update_fields = update_fields + self.update_values = update_values def get_compiler(self, using=None, connection=None): if using: diff --git a/psqlextra/types.py b/psqlextra/types.py index a325fd9..f111807 100644 --- a/psqlextra/types.py +++ b/psqlextra/types.py @@ -28,6 +28,9 @@ class ConflictAction(Enum): def all(cls) -> List["ConflictAction"]: return [choice for choice in cls] + def __str__(self) -> str: + return self.value + class PostgresPartitioningMethod(StrEnum): """Methods of partitioning supported by PostgreSQL 11.x native support for diff --git a/setup.py b/setup.py index dd2ef9b..b3217fb 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,8 @@ def run(self): "pytest-benchmark==3.4.1", "pytest-django==4.4.0", "pytest-cov==3.0.0", + "pytest-lazy-fixture==0.6.3", + "pytest-freezegun==0.4.2", "tox==3.24.4", "freezegun==1.1.0", "coveralls==3.3.0", diff --git a/tests/test_introspect.py b/tests/test_introspect.py new file mode 100644 index 0000000..21e842a --- /dev/null +++ b/tests/test_introspect.py @@ -0,0 +1,417 @@ +import django +import pytest + +from django.contrib.postgres.fields import ArrayField +from django.db import connection, models +from django.test.utils import CaptureQueriesContext +from django.utils import timezone + +from psqlextra.introspect import model_from_cursor, models_from_cursor + +from .fake_model import get_fake_model + +django_31_skip_reason = "Django < 3.1 does not support JSON fields which are required for these tests" + + +@pytest.fixture +def mocked_model_varying_fields(): + return get_fake_model( + { + "title": models.TextField(null=True), + "updated_at": models.DateTimeField(null=True), + "content": models.JSONField(null=True), + "items": ArrayField(models.TextField(), null=True), + } + ) + + +@pytest.fixture +def 
mocked_model_single_field(): + return get_fake_model( + { + "name": models.TextField(), + } + ) + + +@pytest.fixture +def mocked_model_foreign_keys( + mocked_model_varying_fields, mocked_model_single_field +): + return get_fake_model( + { + "varying_fields": models.ForeignKey( + mocked_model_varying_fields, null=True, on_delete=models.CASCADE + ), + "single_field": models.ForeignKey( + mocked_model_single_field, null=True, on_delete=models.CASCADE + ), + } + ) + + +@pytest.fixture +def mocked_model_varying_fields_instance(freezer, mocked_model_varying_fields): + return mocked_model_varying_fields.objects.create( + title="hello world", + updated_at=timezone.now(), + content={"a": 1}, + items=["a", "b"], + ) + + +@pytest.fixture +def models_from_cursor_wrapper_multiple(): + def _wrapper(*args, **kwargs): + return list(models_from_cursor(*args, **kwargs))[0] + + return _wrapper + + +@pytest.fixture +def models_from_cursor_wrapper_single(): + return model_from_cursor + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_applies_converters( + mocked_model_varying_fields, + mocked_model_varying_fields_instance, + models_from_cursor_wrapper, +): + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_varying_fields.objects.all().query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == mocked_model_varying_fields_instance.id + assert queried_instance.title == mocked_model_varying_fields_instance.title + assert ( + queried_instance.updated_at + == mocked_model_varying_fields_instance.updated_at + ) + assert ( + queried_instance.content == mocked_model_varying_fields_instance.content + ) + assert queried_instance.items == mocked_model_varying_fields_instance.items + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_handles_field_order( + mocked_model_varying_fields, + mocked_model_varying_fields_instance, + models_from_cursor_wrapper, +): + with connection.cursor() as cursor: + cursor.execute( + f'SELECT content, items, id, title, updated_at FROM "{mocked_model_varying_fields._meta.db_table}"', + tuple(), + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == mocked_model_varying_fields_instance.id + assert queried_instance.title == mocked_model_varying_fields_instance.title + assert ( + queried_instance.updated_at + == mocked_model_varying_fields_instance.updated_at + ) + assert ( + queried_instance.content == mocked_model_varying_fields_instance.content + ) + assert queried_instance.items == mocked_model_varying_fields_instance.items + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_handles_partial_fields( + mocked_model_varying_fields, + 
mocked_model_varying_fields_instance, + models_from_cursor_wrapper, +): + with connection.cursor() as cursor: + cursor.execute( + f'SELECT id FROM "{mocked_model_varying_fields._meta.db_table}"', + tuple(), + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == mocked_model_varying_fields_instance.id + assert queried_instance.title is None + assert queried_instance.updated_at is None + assert queried_instance.content is None + assert queried_instance.items is None + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_handles_null( + mocked_model_varying_fields, models_from_cursor_wrapper +): + instance = mocked_model_varying_fields.objects.create() + + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_varying_fields.objects.all().query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == instance.id + assert queried_instance.title is None + assert queried_instance.updated_at is None + assert queried_instance.content is None + assert queried_instance.items is None + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_foreign_key( + mocked_model_single_field, + mocked_model_foreign_keys, + models_from_cursor_wrapper, +): + instance = mocked_model_foreign_keys.objects.create( + varying_fields=None, + single_field=mocked_model_single_field.objects.create(name="test"), + ) + + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_foreign_keys.objects.all().query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_foreign_keys, cursor + ) + + with CaptureQueriesContext(connection) as ctx: + assert queried_instance.id == instance.id + assert queried_instance.varying_fields_id is None + assert queried_instance.varying_fields is None + assert queried_instance.single_field_id == instance.single_field_id + assert queried_instance.single_field.id == instance.single_field.id + assert queried_instance.single_field.name == instance.single_field.name + + assert len(ctx.captured_queries) == 1 + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_related_fields( + mocked_model_varying_fields, + mocked_model_single_field, + mocked_model_foreign_keys, + models_from_cursor_wrapper, +): + instance = mocked_model_foreign_keys.objects.create( + varying_fields=mocked_model_varying_fields.objects.create( + title="test", updated_at=timezone.now() + ), + single_field=mocked_model_single_field.objects.create(name="test"), + ) + + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_foreign_keys.objects.select_related( + "varying_fields", "single_field" + ) + .all() + .query.sql_with_params() + ) + queried_instance 
= models_from_cursor_wrapper( + mocked_model_foreign_keys, + cursor, + related_fields=["varying_fields", "single_field"], + ) + + with CaptureQueriesContext(connection) as ctx: + assert queried_instance.id == instance.id + + assert queried_instance.varying_fields_id == instance.varying_fields_id + assert queried_instance.varying_fields.id == instance.varying_fields.id + assert ( + queried_instance.varying_fields.title + == instance.varying_fields.title + ) + assert ( + queried_instance.varying_fields.updated_at + == instance.varying_fields.updated_at + ) + assert ( + queried_instance.varying_fields.content + == instance.varying_fields.content + ) + assert ( + queried_instance.varying_fields.items + == instance.varying_fields.items + ) + + assert queried_instance.single_field_id == instance.single_field_id + assert queried_instance.single_field.id == instance.single_field.id + assert queried_instance.single_field.name == instance.single_field.name + + assert len(ctx.captured_queries) == 0 + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +@pytest.mark.parametrize( + "selected", [True, False], ids=["selected", "not_selected"] +) +def test_models_from_cursor_related_fields_optional( + mocked_model_varying_fields, + mocked_model_foreign_keys, + models_from_cursor_wrapper, + selected, +): + instance = mocked_model_foreign_keys.objects.create( + varying_fields=mocked_model_varying_fields.objects.create( + title="test", updated_at=timezone.now() + ), + single_field=None, + ) + + with connection.cursor() as cursor: + select_related = ["varying_fields"] + if selected: + select_related.append("single_field") + + cursor.execute( + *mocked_model_foreign_keys.objects.select_related(*select_related) + .all() + .query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_foreign_keys, + cursor, + related_fields=["varying_fields", "single_field"], + ) + + assert queried_instance.id == instance.id + assert queried_instance.varying_fields_id == instance.varying_fields_id + assert queried_instance.single_field_id == instance.single_field_id + + with CaptureQueriesContext(connection) as ctx: + assert queried_instance.varying_fields.id == instance.varying_fields.id + assert ( + queried_instance.varying_fields.title + == instance.varying_fields.title + ) + assert ( + queried_instance.varying_fields.updated_at + == instance.varying_fields.updated_at + ) + assert ( + queried_instance.varying_fields.content + == instance.varying_fields.content + ) + assert ( + queried_instance.varying_fields.items + == instance.varying_fields.items + ) + + assert queried_instance.single_field is None + + assert len(ctx.captured_queries) == 0 + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +def test_models_from_cursor_generator_efficiency( + mocked_model_varying_fields, mocked_model_single_field +): + mocked_model_single_field.objects.create(name="a") + mocked_model_single_field.objects.create(name="b") + + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_single_field.objects.all().query.sql_with_params() + ) + + instances_generator = models_from_cursor( + mocked_model_single_field, cursor + ) + assert cursor.rownumber == 0 + + next(instances_generator) + assert cursor.rownumber == 1 + + next(instances_generator) 
+ assert cursor.rownumber == 2 + + assert not next(instances_generator, None) + assert cursor.rownumber == 2 diff --git a/tests/test_upsert.py b/tests/test_upsert.py index b9176da..3aa6207 100644 --- a/tests/test_upsert.py +++ b/tests/test_upsert.py @@ -1,8 +1,8 @@ import django import pytest -from django.db import models -from django.db.models import Q +from django.db import connection, models +from django.db.models import F, Q from django.db.models.expressions import CombinedExpression, Value from psqlextra.expressions import ExcludedCol @@ -144,6 +144,54 @@ def test_upsert_with_update_condition(): assert obj1.active +def test_upsert_with_update_values(): + """Tests that the default update values can be overriden with custom + expressions.""" + + model = get_fake_model( + { + "name": models.TextField(unique=True), + "count": models.IntegerField(default=0), + } + ) + + obj1 = model.objects.create(name="joe") + + model.objects.upsert( + conflict_target=["name"], + fields=dict(name="joe"), + update_values=dict( + count=F("count") + 1, + ), + ) + + obj1.refresh_from_db() + assert obj1.count == 1 + + +def test_upsert_with_update_values_empty(): + """Tests that an upsert with an empty dict turns into ON CONFLICT DO + NOTHING.""" + + model = get_fake_model( + { + "name": models.TextField(unique=True), + "count": models.IntegerField(default=0), + } + ) + + obj1 = model.objects.create(name="joe") + + model.objects.upsert( + conflict_target=["name"], + fields=dict(name="joe"), + update_values={}, + ) + + obj1.refresh_from_db() + assert obj1.count == 0 + + @pytest.mark.skipif( django.VERSION < (3, 1), reason="requires django 3.1 or newer" ) @@ -200,7 +248,7 @@ def from_db_value(self, value, expression, connection): assert obj.title == "bye" -def test_upsert_bulk(): +def test_bulk_upsert(): """Tests whether bulk_upsert works properly.""" model = get_fake_model( @@ -337,3 +385,93 @@ def __iter__(self): for index, obj in enumerate(objs, 1): assert isinstance(obj, model) assert obj.id == index + + +def test_bulk_upsert_update_values(): + model = get_fake_model( + { + "name": models.CharField(max_length=255, unique=True), + "count": models.IntegerField(default=0), + } + ) + + model.objects.bulk_create( + [ + model(name="joe"), + model(name="john"), + ] + ) + + objs = model.objects.bulk_upsert( + conflict_target=["name"], + rows=[], + return_model=True, + update_values=dict(count=F("count") + 1), + ) + + assert all([obj for obj in objs if obj.count == 1]) + + +@pytest.mark.parametrize("return_model", [True]) +def test_bulk_upsert_extra_columns_in_schema(return_model): + """Tests that extra columns being returned by the database that aren't + known by Django don't make the bulk upsert crash.""" + + model = get_fake_model( + { + "name": models.CharField(max_length=255, unique=True), + } + ) + + with connection.cursor() as cursor: + cursor.execute( + f"ALTER TABLE {model._meta.db_table} ADD COLUMN new_name text NOT NULL DEFAULT %s", + ("newjoe",), + ) + + objs = model.objects.bulk_upsert( + conflict_target=["name"], + rows=[ + dict(name="joe"), + ], + return_model=return_model, + ) + + assert len(objs) == 1 + + if return_model: + assert objs[0].name == "joe" + else: + assert objs[0]["name"] == "joe" + assert sorted(list(objs[0].keys())) == ["id", "name"] + + +def test_upsert_extra_columns_in_schema(): + """Tests that extra columns being returned by the database that aren't + known by Django don't make the upsert crash.""" + + model = get_fake_model( + { + "name": models.CharField(max_length=255, 
unique=True), + } + ) + + with connection.cursor() as cursor: + cursor.execute( + f"ALTER TABLE {model._meta.db_table} ADD COLUMN new_name text NOT NULL DEFAULT %s", + ("newjoe",), + ) + + obj_id = model.objects.upsert( + conflict_target=["name"], + fields=dict(name="joe"), + ) + + assert obj_id == 1 + + obj = model.objects.upsert_and_get( + conflict_target=["name"], + fields=dict(name="joe"), + ) + + assert obj.name == "joe"
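
A minimal usage sketch of the new ``psqlextra.introspect.models_from_cursor`` helper introduced above. The model ``MyModel``, its ``name`` field and the raw SQL are hypothetical and only illustrate how hand-written query results map back onto model instances:

.. code-block:: python

    from django.db import connection

    from psqlextra.introspect import models_from_cursor

    with connection.cursor() as cursor:
        # Hand-written query; columns are matched to MyModel's local
        # concrete fields by name, so the SELECT order does not matter.
        cursor.execute("SELECT * FROM mymodel WHERE name = %s", ["henk"])

        # Rows are lazily converted into MyModel instances with database
        # converters applied and marked as already saved.
        for instance in models_from_cursor(MyModel, cursor):
            print(instance.pk, instance.name)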
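
Similarly, a sketch of the new ``update_values`` argument on the existing ``upsert()`` helper, mirroring ``test_upsert_with_update_values`` above. ``MyModel`` with a unique ``name`` and an integer ``count`` field is again hypothetical:

.. code-block:: python

    from django.db.models import F

    from psqlextra.expressions import ExcludedCol

    # On conflict, take the incoming name from the proposed row
    # (EXCLUDED.name) and increment the existing counter instead of
    # overwriting it with the inserted value.
    MyModel.objects.upsert(
        conflict_target=["name"],
        fields=dict(name="henk", count=1),
        update_values=dict(
            name=ExcludedCol("name"),
            count=F("count") + 1,
        ),
    )

    # Passing update_values={} degrades the statement to
    # ON CONFLICT ... DO NOTHING, while None (the default) updates all
    # fields present in the row.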