From 2adc13fdc99856d23ecc55a53597225024dc6b05 Mon Sep 17 00:00:00 2001 From: Christoph Krybus Date: Sat, 11 Feb 2023 12:51:12 +0100 Subject: [PATCH 01/43] Make time partition suffix customizable --- .../partitioning/current_time_strategy.py | 10 ++++- psqlextra/partitioning/shorthands.py | 10 ++++- psqlextra/partitioning/time_partition.py | 11 +++++- tests/test_partitioning_time.py | 39 +++++++++++++++++++ 4 files changed, 65 insertions(+), 5 deletions(-) diff --git a/psqlextra/partitioning/current_time_strategy.py b/psqlextra/partitioning/current_time_strategy.py index a0268be6..114a1aaf 100644 --- a/psqlextra/partitioning/current_time_strategy.py +++ b/psqlextra/partitioning/current_time_strategy.py @@ -24,6 +24,7 @@ def __init__( size: PostgresTimePartitionSize, count: int, max_age: Optional[relativedelta] = None, + name_format: Optional[str] = None, ) -> None: """Initializes a new instance of :see:PostgresTimePartitioningStrategy. @@ -44,13 +45,16 @@ def __init__( self.size = size self.count = count self.max_age = max_age + self.name_format = name_format def to_create(self) -> Generator[PostgresTimePartition, None, None]: current_datetime = self.size.start(self.get_start_datetime()) for _ in range(self.count): yield PostgresTimePartition( - start_datetime=current_datetime, size=self.size + start_datetime=current_datetime, + size=self.size, + name_format=self.name_format, ) current_datetime += self.size.as_delta() @@ -65,7 +69,9 @@ def to_delete(self) -> Generator[PostgresTimePartition, None, None]: while True: yield PostgresTimePartition( - start_datetime=current_datetime, size=self.size + start_datetime=current_datetime, + size=self.size, + name_format=self.name_format, ) current_datetime -= self.size.as_delta() diff --git a/psqlextra/partitioning/shorthands.py b/psqlextra/partitioning/shorthands.py index 05ce4a34..dab65e4f 100644 --- a/psqlextra/partitioning/shorthands.py +++ b/psqlextra/partitioning/shorthands.py @@ -17,6 +17,7 @@ def partition_by_current_time( weeks: Optional[int] = None, days: Optional[int] = None, max_age: Optional[relativedelta] = None, + name_format: Optional[str] = None, ) -> PostgresPartitioningConfig: """Short-hand for generating a partitioning config that partitions the specified model by time. @@ -48,6 +49,10 @@ def partition_by_current_time( Partitions older than this are deleted when running a delete/cleanup run. + + name_format: + The datetime format which is being passed to datetime.strftime + to generate the partition name. 
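+
+            For example (an illustrative value, taken from the test
+            added by this patch): with monthly partitioning, passing
+            ``name_format="%Y_%m"`` yields partition names such as
+            ``2019_01`` instead of the default ``2019_jan``.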
""" size = PostgresTimePartitionSize( @@ -57,7 +62,10 @@ def partition_by_current_time( return PostgresPartitioningConfig( model=model, strategy=PostgresCurrentTimePartitioningStrategy( - size=size, count=count, max_age=max_age + size=size, + count=count, + max_age=max_age, + name_format=name_format, ), ) diff --git a/psqlextra/partitioning/time_partition.py b/psqlextra/partitioning/time_partition.py index b6be67a1..3c8a4d87 100644 --- a/psqlextra/partitioning/time_partition.py +++ b/psqlextra/partitioning/time_partition.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Optional from .error import PostgresPartitioningError from .range_partition import PostgresRangePartition @@ -22,7 +23,10 @@ class PostgresTimePartition(PostgresRangePartition): } def __init__( - self, size: PostgresTimePartitionSize, start_datetime: datetime + self, + size: PostgresTimePartitionSize, + start_datetime: datetime, + name_format: Optional[str] = None, ) -> None: end_datetime = start_datetime + size.as_delta() @@ -34,9 +38,12 @@ def __init__( self.size = size self.start_datetime = start_datetime self.end_datetime = end_datetime + self.name_format = name_format def name(self) -> str: - name_format = self._unit_name_format.get(self.size.unit) + name_format = self.name_format or self._unit_name_format.get( + self.size.unit + ) if not name_format: raise PostgresPartitioningError("Unknown size/unit") diff --git a/tests/test_partitioning_time.py b/tests/test_partitioning_time.py index 68808324..9f6b5bf1 100644 --- a/tests/test_partitioning_time.py +++ b/tests/test_partitioning_time.py @@ -115,6 +115,45 @@ def test_partitioning_time_monthly_apply(): assert table.partitions[13].name == "2020_feb" +@pytest.mark.postgres_version(lt=110000) +def test_partitioning_time_monthly_with_custom_naming_apply(): + """Tests whether automatically created new partitions are named according + to the specified name_format.""" + + model = define_fake_partitioned_model( + {"timestamp": models.DateTimeField()}, {"key": ["timestamp"]} + ) + + schema_editor = connection.schema_editor() + schema_editor.create_partitioned_model(model) + + # create partitions for the next 12 months (including the current) + with freezegun.freeze_time("2019-1-30"): + manager = PostgresPartitioningManager( + [ + partition_by_current_time( + model, months=1, count=12, name_format="%Y_%m" + ) + ] + ) + manager.plan().apply() + + table = _get_partitioned_table(model) + assert len(table.partitions) == 12 + assert table.partitions[0].name == "2019_01" + assert table.partitions[1].name == "2019_02" + assert table.partitions[2].name == "2019_03" + assert table.partitions[3].name == "2019_04" + assert table.partitions[4].name == "2019_05" + assert table.partitions[5].name == "2019_06" + assert table.partitions[6].name == "2019_07" + assert table.partitions[7].name == "2019_08" + assert table.partitions[8].name == "2019_09" + assert table.partitions[9].name == "2019_10" + assert table.partitions[10].name == "2019_11" + assert table.partitions[11].name == "2019_12" + + @pytest.mark.postgres_version(lt=110000) def test_partitioning_time_weekly_apply(): """Tests whether automatically creating new partitions ahead weekly works From 2e13eb14134e2cb42bdd22ee1b326affb0b58709 Mon Sep 17 00:00:00 2001 From: Selcuk Ayguney Date: Sat, 25 Mar 2023 22:01:38 +1000 Subject: [PATCH 02/43] Fixed typo in partitioning docs `hash` -> `hash` (#201) --- docs/source/table_partitioning.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/docs/source/table_partitioning.rst b/docs/source/table_partitioning.rst index 1869aed2..8ad3d115 100644 --- a/docs/source/table_partitioning.rst +++ b/docs/source/table_partitioning.rst @@ -196,7 +196,7 @@ You can look at :class:`psqlextra.partitioning.PostgresCurrentTimePartitioningSt Manually managing partitions ---------------------------- -If you are using list or has partitioning, you most likely have a fixed amount of partitions that can be created up front using migrations or using the schema editor. +If you are using list or hash partitioning, you most likely have a fixed amount of partitions that can be created up front using migrations or using the schema editor. Using migration operations ************************** From 2156dfe0ea2276507d8d56ac7e82791a05faaf15 Mon Sep 17 00:00:00 2001 From: Christoph Krybus Date: Sat, 25 Mar 2023 13:02:20 +0100 Subject: [PATCH 03/43] Fix typo in CONTRIBUTING.md, `accomponied` -> `accompanied` (#199) --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 04ba94f3..cd0836a9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -16,4 +16,4 @@ If you're unsure whether your change would be a good fit for `django-postgres-ex * PyLint passes. * PEP8 passes. * Features that allow creating custom indexes or fields must also implement the associated migrations. `django-postgres-extra` prides itself on the fact that it integrates smoothly with Django migrations. We'd like to keep it that way for all features. -* Sufficiently complicated changes must be accomponied by tests. +* Sufficiently complicated changes must be accompanied by tests. From 2df1056a8f0e65f18189b70d768482453f020bf8 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Mon, 3 Apr 2023 16:48:34 +0300 Subject: [PATCH 04/43] Fix PostgresIntrospection.get_constraints crashing for PK in PostgreSQL 13.x --- psqlextra/backend/introspection.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/psqlextra/backend/introspection.py b/psqlextra/backend/introspection.py index a85f27cd..5d9e8d5b 100644 --- a/psqlextra/backend/introspection.py +++ b/psqlextra/backend/introspection.py @@ -187,8 +187,14 @@ def get_constraints(self, cursor, table_name: str): "SELECT indexname, indexdef FROM pg_indexes WHERE tablename = %s", (table_name,), ) - for index, definition in cursor.fetchall(): - if constraints[index].get("definition") is None: - constraints[index]["definition"] = definition + for index_name, definition in cursor.fetchall(): + # PostgreSQL 13 or older won't give a definition if the + # index is actually a primary key. + constraint = constraints.get(index_name) + if not constraint: + continue + + if constraint.get("definition") is None: + constraint["definition"] = definition return constraints From 9b809d158e7285e8c43ca4fc888aa984556d3a10 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Tue, 4 Apr 2023 09:16:57 +0300 Subject: [PATCH 05/43] Don't point Postgres documentation links to a specific version --- psqlextra/expressions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/psqlextra/expressions.py b/psqlextra/expressions.py index 1840283c..75351e68 100644 --- a/psqlextra/expressions.py +++ b/psqlextra/expressions.py @@ -212,7 +212,7 @@ class ExcludedCol(expressions.Expression): """References a column in PostgreSQL's special EXCLUDED column, which is used in upserts to refer to the data about to be inserted/updated. 
-    See: https://www.postgresql.org/docs/9.5/sql-insert.html#SQL-ON-CONFLICT
+    See: https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT
     """
 
     def __init__(self, name: str):
From 5477e7a1412782602b7e77d44741c71f6548925a Mon Sep 17 00:00:00 2001
From: Swen Kooij
Date: Tue, 4 Apr 2023 09:14:31 +0300
Subject: [PATCH 06/43] Add support for explicit table-level locks

---
 docs/source/api_reference.rst               |   3 +
 docs/source/index.rst                       |   5 +
 docs/source/locking.rst                     |  56 ++++++++++
 docs/source/snippets/postgres_doc_links.rst |   1 +
 psqlextra/backend/introspection.py          |  18 +++-
 psqlextra/locking.py                        |  97 ++++++++++++++++++
 tests/test_locking.py                       | 107 ++++++++++++++++++++
 7 files changed, 286 insertions(+), 1 deletion(-)
 create mode 100644 docs/source/locking.rst
 create mode 100644 psqlextra/locking.py
 create mode 100644 tests/test_locking.py

diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst
index 554f50ea..1d64bc8b 100644
--- a/docs/source/api_reference.rst
+++ b/docs/source/api_reference.rst
@@ -37,6 +37,9 @@ API Reference
 .. autoclass:: ConditionalUniqueIndex
 .. autoclass:: CaseInsensitiveUniqueIndex
 
+.. automodule:: psqlextra.locking
+   :members:
+
 .. automodule:: psqlextra.partitioning
    :members:
 
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 28b61560..76600702 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -35,6 +35,10 @@ Explore the documentation to learn about all features:
 
    Support for ``TRUNCATE TABLE`` statements (including cascading).
 
+* :ref:`Locking models & tables <locking_page>`
+
+   Support for explicit table-level locks.
+
 .. toctree::
    :maxdepth: 2
 
@@ -49,6 +53,7 @@ Explore the documentation to learn about all features:
    table_partitioning
    expressions
    annotations
+   locking
    settings
    api_reference
    major_releases
diff --git a/docs/source/locking.rst b/docs/source/locking.rst
new file mode 100644
index 00000000..160d18e2
--- /dev/null
+++ b/docs/source/locking.rst
@@ -0,0 +1,56 @@
+.. include:: ./snippets/postgres_doc_links.rst
+
+.. _locking_page:
+
+Locking
+=======
+
+`Explicit table-level locks`_ are supported through the :meth:`psqlextra.locking.postgres_lock_model` and :meth:`psqlextra.locking.postgres_lock_table` methods. All table-level lock methods are supported.
+
+Locks are always bound to the current transaction and are released when the transaction is comitted or rolled back. There is no support (in PostgreSQL) for explicitly releasing a lock.
+
+.. warning::
+
+   Locks are only released when the *outer* transaction commits or when a nested transaction is rolled back. You can ensure that the transaction you created is the outermost one by passing the ``durable=True`` argument to ``transaction.atomic``.
+
+.. note::
+
+   Use `django-pglocks <https://pypi.org/project/django-pglocks/>`_ if you need an advisory lock.
+
+Locking a model
+---------------
+
+Use :class:`psqlextra.locking.PostgresTableLockMode` to indicate the type of lock to acquire.
+
+.. code-block:: python
+
+   from django.db import transaction
+
+   from psqlextra.locking import PostgresTableLockMode, postgres_lock_model
+
+   with transaction.atomic(durable=True):
+       postgres_lock_model(MyModel, PostgresTableLockMode.EXCLUSIVE)
+
+       # locks are released here, when the transaction comitted
+
+
+Locking a table
+---------------
+
+Use :meth:`psqlextra.locking.postgres_lock_table` to lock arbitrary tables in arbitrary schemas.
+
+..
code-block:: python + + from django.db import transaction + + from psqlextra.locking import PostgresTableLockMode, postgres_lock_table + + with transaction.atomic(durable=True): + postgres_lock_table("mytable", PostgresTableLockMode.EXCLUSIVE) + postgres_lock_table( + "tableinotherschema", + PostgresTableLockMode.EXCLUSIVE, + schema_name="myschema" + ) + + # locks are released here, when the transaction comitted diff --git a/docs/source/snippets/postgres_doc_links.rst b/docs/source/snippets/postgres_doc_links.rst index 90ebb51c..fe0f4d76 100644 --- a/docs/source/snippets/postgres_doc_links.rst +++ b/docs/source/snippets/postgres_doc_links.rst @@ -2,3 +2,4 @@ .. _TRUNCATE TABLE: https://www.postgresql.org/docs/9.1/sql-truncate.html .. _hstore: https://www.postgresql.org/docs/11/hstore.html .. _PostgreSQL Declarative Table Partitioning: https://www.postgresql.org/docs/current/ddl-partitioning.html#DDL-PARTITIONING-DECLARATIVE +.. _Explicit table-level locks: https://www.postgresql.org/docs/current/explicit-locking.html#LOCKING-TABLES diff --git a/psqlextra/backend/introspection.py b/psqlextra/backend/introspection.py index 5d9e8d5b..fe9c3324 100644 --- a/psqlextra/backend/introspection.py +++ b/psqlextra/backend/introspection.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple from psqlextra.types import PostgresPartitioningMethod @@ -198,3 +198,19 @@ def get_constraints(self, cursor, table_name: str): constraint["definition"] = definition return constraints + + def get_table_locks(self, cursor) -> List[Tuple[str, str, str]]: + cursor.execute( + """ + SELECT + n.nspname, + t.relname, + l.mode + FROM pg_locks l + INNER JOIN pg_class t ON t.oid = l.relation + INNER JOIN pg_namespace n ON n.oid = t.relnamespace + WHERE t.relnamespace >= 2200 + ORDER BY n.nspname, t.relname, l.mode""" + ) + + return cursor.fetchall() diff --git a/psqlextra/locking.py b/psqlextra/locking.py new file mode 100644 index 00000000..2a791ce2 --- /dev/null +++ b/psqlextra/locking.py @@ -0,0 +1,97 @@ +from enum import Enum +from typing import Optional, Type + +from django.db import DEFAULT_DB_ALIAS, connections, models + + +class PostgresTableLockMode(Enum): + """List of table locking modes. + + See: https://www.postgresql.org/docs/current/explicit-locking.html + """ + + ACCESS_SHARE = "ACCESS SHARE" + ROW_SHARE = "ROW SHARE" + ROW_EXCLUSIVE = "ROW EXCLUSIVE" + SHARE_UPDATE_EXCLUSIVE = "SHARE UPDATE EXCLUSIVE" + SHARE = "SHARE" + SHARE_ROW_EXCLUSIVE = "SHARE ROW EXCLUSIVE" + EXCLUSIVE = "EXCLUSIVE" + ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE" + + +def postgres_lock_table( + table_name: str, + lock_mode: PostgresTableLockMode, + *, + schema_name: Optional[str] = None, + using: str = DEFAULT_DB_ALIAS, +) -> None: + """Locks the specified table with the specified mode. + + The lock is held until the end of the current transaction. + + Arguments: + table_name: + Unquoted table name to acquire the lock on. + + lock_mode: + Type of lock to acquire. + + schema_name: + Optionally, the unquoted name of the schema + the table to lock is in. If not specified, + the table name is resolved by PostgreSQL + using it's ``search_path``. + + using: + Name of the database alias to use. + """ + + connection = connections[using] + + with connection.cursor() as cursor: + quoted_fqn = connection.ops.quote_name(table_name) + if schema_name: + quoted_fqn = ( + connection.ops.quote_name(schema_name) + "." 
+ quoted_fqn + ) + + cursor.execute(f"LOCK TABLE {quoted_fqn} IN {lock_mode.value} MODE") + + +def postgres_lock_model( + model: Type[models.Model], + lock_mode: PostgresTableLockMode, + *, + using: str = DEFAULT_DB_ALIAS, + schema_name: Optional[str] = None, +) -> None: + """Locks the specified model with the specified mode. + + The lock is held until the end of the current transaction. + + Arguments: + model: + The model of which to lock the table. + + lock_mode: + Type of lock to acquire. + + schema_name: + Optionally, the unquoted name of the schema + the table to lock is in. If not specified, + the table name is resolved by PostgreSQL + using it's ``search_path``. + + Django models always reside in the default + ("public") schema. You should not specify + this unless you're doing something special. + + using: + Name of the database alias to use. + """ + + postgres_lock_table( + model._meta.db_table, lock_mode, schema_name=schema_name, using=using + ) diff --git a/tests/test_locking.py b/tests/test_locking.py new file mode 100644 index 00000000..d5bc5173 --- /dev/null +++ b/tests/test_locking.py @@ -0,0 +1,107 @@ +import uuid + +import pytest + +from django.db import connection, models, transaction + +from psqlextra.locking import ( + PostgresTableLockMode, + postgres_lock_model, + postgres_lock_table, +) + +from .fake_model import get_fake_model + + +@pytest.fixture +def mocked_model(): + return get_fake_model( + { + "name": models.TextField(), + } + ) + + +def get_table_locks(): + with connection.cursor() as cursor: + return connection.introspection.get_table_locks(cursor) + + +@pytest.mark.django_db(transaction=True) +def test_postgres_lock_table(mocked_model): + lock_signature = ( + "public", + mocked_model._meta.db_table, + "AccessExclusiveLock", + ) + with transaction.atomic(): + postgres_lock_table( + mocked_model._meta.db_table, PostgresTableLockMode.ACCESS_EXCLUSIVE + ) + assert lock_signature in get_table_locks() + + assert lock_signature not in get_table_locks() + + +@pytest.mark.django_db(transaction=True) +def test_postgres_lock_table_in_schema(): + schema_name = str(uuid.uuid4())[:8] + table_name = str(uuid.uuid4())[:8] + quoted_schema_name = connection.ops.quote_name(schema_name) + quoted_table_name = connection.ops.quote_name(table_name) + + with connection.cursor() as cursor: + cursor.execute(f"CREATE SCHEMA {quoted_schema_name}") + cursor.execute( + f"CREATE TABLE {quoted_schema_name}.{quoted_table_name} AS SELECT 'hello world'" + ) + + lock_signature = (schema_name, table_name, "ExclusiveLock") + with transaction.atomic(): + postgres_lock_table( + table_name, PostgresTableLockMode.EXCLUSIVE, schema_name=schema_name + ) + assert lock_signature in get_table_locks() + + assert lock_signature not in get_table_locks() + + +@pytest.mark.django_db(transaction=True) +def test_postgres_lock_mode(mocked_model): + lock_signature = ( + "public", + mocked_model._meta.db_table, + "AccessExclusiveLock", + ) + + with transaction.atomic(): + postgres_lock_model( + mocked_model, PostgresTableLockMode.ACCESS_EXCLUSIVE + ) + assert lock_signature in get_table_locks() + + assert lock_signature not in get_table_locks() + + +@pytest.mark.django_db(transaction=True) +def test_postgres_lock_model_in_schema(mocked_model): + schema_name = str(uuid.uuid4())[:8] + quoted_schema_name = connection.ops.quote_name(schema_name) + quoted_table_name = connection.ops.quote_name(mocked_model._meta.db_table) + + with connection.cursor() as cursor: + cursor.execute(f"CREATE SCHEMA {quoted_schema_name}") + 
cursor.execute(
+            f"CREATE TABLE {quoted_schema_name}.{quoted_table_name} (LIKE public.{quoted_table_name} INCLUDING ALL)"
+        )
+
+    lock_signature = (schema_name, mocked_model._meta.db_table, "ExclusiveLock")
+    with transaction.atomic():
+        postgres_lock_model(
+            mocked_model,
+            PostgresTableLockMode.EXCLUSIVE,
+            schema_name=schema_name,
+        )
+        assert lock_signature in get_table_locks()
+
+    assert lock_signature not in get_table_locks()
From 5d55c25240958c8b3fe6d46511257c5fdcf08a63 Mon Sep 17 00:00:00 2001
From: Swen Kooij
Date: Wed, 5 Apr 2023 10:47:32 +0300
Subject: [PATCH 07/43] Fix typo in docs, `comitted` -> `committed`

---
 docs/source/locking.rst | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/source/locking.rst b/docs/source/locking.rst
index 160d18e2..8cf8cf8e 100644
--- a/docs/source/locking.rst
+++ b/docs/source/locking.rst
@@ -7,7 +7,7 @@ Locking
 
 `Explicit table-level locks`_ are supported through the :meth:`psqlextra.locking.postgres_lock_model` and :meth:`psqlextra.locking.postgres_lock_table` methods. All table-level lock methods are supported.
 
-Locks are always bound to the current transaction and are released when the transaction is comitted or rolled back. There is no support (in PostgreSQL) for explicitly releasing a lock.
+Locks are always bound to the current transaction and are released when the transaction is committed or rolled back. There is no support (in PostgreSQL) for explicitly releasing a lock.
 
 .. warning::
 
@@ -31,7 +31,7 @@ Use :class:`psqlextra.locking.PostgresTableLockMode` to indicate the type of loc
 
     with transaction.atomic(durable=True):
         postgres_lock_model(MyModel, PostgresTableLockMode.EXCLUSIVE)
 
-        # locks are released here, when the transaction comitted
+        # locks are released here, when the transaction committed
 
 
 Locking a table
@@ -53,4 +53,4 @@ Use :meth:`psqlextra.locking.postgres_lock_table` to lock arbitrary tables in ar
             schema_name="myschema"
         )
 
-        # locks are released here, when the transaction comitted
+        # locks are released here, when the transaction committed
From 0bb392db265f68c4dd0d0663f8f5bf4629de35b4 Mon Sep 17 00:00:00 2001
From: Swen Kooij
Date: Wed, 5 Apr 2023 10:47:59 +0300
Subject: [PATCH 08/43] Consistently document the `using` parameter

---
 docs/source/table_partitioning.rst           | 2 +-
 psqlextra/locking.py                         | 4 ++--
 psqlextra/management/commands/pgpartition.py | 2 +-
 psqlextra/partitioning/manager.py            | 2 +-
 psqlextra/partitioning/plan.py               | 2 +-
 psqlextra/query.py                           | 2 +-
 6 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/docs/source/table_partitioning.rst b/docs/source/table_partitioning.rst
index 8ad3d115..1bb5ba6f 100644
--- a/docs/source/table_partitioning.rst
+++ b/docs/source/table_partitioning.rst
@@ -101,7 +101,7 @@ Command-line options
 ====================  =============  ================  ====================================================================================================
 Long flag             Short flag     Default           Description
 ====================  =============  ================  ====================================================================================================
 ``--yes``             ``-y``         ``False``         Specifies yes to all questions. You will NOT be asked for confirmation before partition deletion.
-``--using``           ``-u``         ``'default'``     Name of the database connection to use.
+``--using``           ``-u``         ``'default'``     Optional name of the database connection to use.
 ``--skip-create``                    ``False``         Whether to skip creating partitions.
 ``--skip-delete``                    ``False``         Whether to skip deleting partitions.
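The ``using`` parameter documented throughout this patch selects the database connection alias on which the statement runs. A minimal sketch of passing it to the locking API (the ``"replica"`` alias is an assumed example entry in ``settings.DATABASES``, not part of this patch):

    from django.db import transaction

    from psqlextra.locking import PostgresTableLockMode, postgres_lock_table

    with transaction.atomic(using="replica"):
        # The lock is acquired over the "replica" connection and is
        # released when this transaction commits or rolls back.
        postgres_lock_table(
            "mytable", PostgresTableLockMode.EXCLUSIVE, using="replica"
        )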
diff --git a/psqlextra/locking.py b/psqlextra/locking.py index 2a791ce2..42f41b67 100644 --- a/psqlextra/locking.py +++ b/psqlextra/locking.py @@ -45,7 +45,7 @@ def postgres_lock_table( using it's ``search_path``. using: - Name of the database alias to use. + Optional name of the database connection to use. """ connection = connections[using] @@ -89,7 +89,7 @@ def postgres_lock_model( this unless you're doing something special. using: - Name of the database alias to use. + Optional name of the database connection to use. """ postgres_lock_table( diff --git a/psqlextra/management/commands/pgpartition.py b/psqlextra/management/commands/pgpartition.py index 80f2cecc..592b57d7 100644 --- a/psqlextra/management/commands/pgpartition.py +++ b/psqlextra/management/commands/pgpartition.py @@ -37,7 +37,7 @@ def add_arguments(self, parser): parser.add_argument( "--using", "-u", - help="Name of the database connection to use.", + help="Optional name of the database connection to use.", default="default", ) diff --git a/psqlextra/partitioning/manager.py b/psqlextra/partitioning/manager.py index 28aee91e..4dcbb599 100644 --- a/psqlextra/partitioning/manager.py +++ b/psqlextra/partitioning/manager.py @@ -39,7 +39,7 @@ def plan( for deletion, regardless of the configuration. using: - Name of the database connection to use. + Optional name of the database connection to use. Returns: A plan describing what partitions would be created diff --git a/psqlextra/partitioning/plan.py b/psqlextra/partitioning/plan.py index fdb1eee2..31746360 100644 --- a/psqlextra/partitioning/plan.py +++ b/psqlextra/partitioning/plan.py @@ -28,7 +28,7 @@ def apply(self, using: Optional[str]) -> None: Arguments: using: - Name of the database connection to use. + Optional name of the database connection to use. """ connection = connections[using or "default"] diff --git a/psqlextra/query.py b/psqlextra/query.py index 2756fd8c..2f117e3d 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -131,7 +131,7 @@ def bulk_insert( just dicts. using: - Name of the database connection to use for + Optional name of the database connection to use for this query. 
Returns: From c79a8ca25b9be77604967e16a17ae099b66080b3 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Wed, 5 Apr 2023 10:48:12 +0300 Subject: [PATCH 09/43] Test all table lock modes --- psqlextra/locking.py | 7 +++++++ tests/test_locking.py | 9 ++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/psqlextra/locking.py b/psqlextra/locking.py index 42f41b67..da8ff567 100644 --- a/psqlextra/locking.py +++ b/psqlextra/locking.py @@ -19,6 +19,13 @@ class PostgresTableLockMode(Enum): EXCLUSIVE = "EXCLUSIVE" ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE" + @property + def alias(self) -> str: + return ( + "".join([word.title() for word in self.name.lower().split("_")]) + + "Lock" + ) + def postgres_lock_table( table_name: str, diff --git a/tests/test_locking.py b/tests/test_locking.py index d5bc5173..6414689d 100644 --- a/tests/test_locking.py +++ b/tests/test_locking.py @@ -66,18 +66,17 @@ def test_postgres_lock_table_in_schema(): assert lock_signature not in get_table_locks() +@pytest.mark.parametrize("lock_mode", list(PostgresTableLockMode)) @pytest.mark.django_db(transaction=True) -def test_postgres_lock_mode(mocked_model): +def test_postgres_lock_model(mocked_model, lock_mode): lock_signature = ( "public", mocked_model._meta.db_table, - "AccessExclusiveLock", + lock_mode.alias, ) with transaction.atomic(): - postgres_lock_model( - mocked_model, PostgresTableLockMode.ACCESS_EXCLUSIVE - ) + postgres_lock_model(mocked_model, lock_mode) assert lock_signature in get_table_locks() assert lock_signature not in get_table_locks() From b428fdaf1716caf4c20ad80ff1b6bf6ea005e5b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bla=C5=BE=20=C5=A0nuderl?= Date: Thu, 6 Apr 2023 13:51:40 +0200 Subject: [PATCH 10/43] Support Django 4.2 + psycopg3 (#208) --- .circleci/config.yml | 9 ++++----- psqlextra/backend/schema.py | 4 +++- psqlextra/compiler.py | 3 ++- setup.py | 4 ++++ tests/conftest.py | 2 +- tox.ini | 3 ++- 6 files changed, 16 insertions(+), 9 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index e494f4bd..7f245a40 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,7 +7,7 @@ executors: type: string docker: - image: python:<< parameters.version >>-buster - - image: postgres:11.0 + - image: postgres:12.0 environment: POSTGRES_DB: 'psqlextra' POSTGRES_USER: 'psqlextra' @@ -42,7 +42,6 @@ commands: environment: DATABASE_URL: 'postgres://psqlextra:psqlextra@localhost:5432/psqlextra' - jobs: test-python36: executor: @@ -78,7 +77,7 @@ jobs: extra: test - run-tests: pyversion: 38 - djversions: 20,21,22,30,31,32,40 + djversions: 20,21,22,30,31,32,40,41,42 test-python39: executor: @@ -90,7 +89,7 @@ jobs: extra: test - run-tests: pyversion: 39 - djversions: 21,22,30,31,32,40 + djversions: 21,22,30,31,32,40,41,42 test-python310: executor: @@ -102,7 +101,7 @@ jobs: extra: test - run-tests: pyversion: 310 - djversions: 21,22,30,31,32,40 + djversions: 21,22,30,31,32,40,41,42 - store_test_results: path: reports - run: diff --git a/psqlextra/backend/schema.py b/psqlextra/backend/schema.py index b59ed617..413f039d 100644 --- a/psqlextra/backend/schema.py +++ b/psqlextra/backend/schema.py @@ -430,7 +430,9 @@ def _create_view_model(self, sql: str, model: Model) -> None: meta = self._view_properties_for_model(model) with self.connection.cursor() as cursor: - view_sql = cursor.mogrify(*meta.query).decode("utf-8") + view_sql = cursor.mogrify(*meta.query) + if isinstance(view_sql, bytes): + view_sql = view_sql.decode("utf-8") self.execute(sql % (self.quote_name(model._meta.db_table), 
view_sql)) diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py index ee414bfd..be96e50d 100644 --- a/psqlextra/compiler.py +++ b/psqlextra/compiler.py @@ -187,12 +187,13 @@ def execute_sql(self, return_id=False): rows.extend(cursor.fetchall()) except ProgrammingError: pass + description = cursor.description # create a mapping between column names and column value return [ { column.name: row[column_index] - for column_index, column in enumerate(cursor.description) + for column_index, column in enumerate(description) if row } for row in rows diff --git a/setup.py b/setup.py index 281be89d..a7f6b8f0 100644 --- a/setup.py +++ b/setup.py @@ -95,6 +95,10 @@ def run(self): "build==0.7.0", "twine==3.7.1", ], + "psycopg3": [ + "django>=4.2,<5.0", + "psycopg[binary]>=3.0.0", + ], }, cmdclass={ "lint": create_command( diff --git a/tests/conftest.py b/tests/conftest.py index f90692af..387edd3b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,7 +29,7 @@ def fake_app(): def postgres_server_version(db) -> int: """Gets the PostgreSQL server version.""" - return connection.cursor().connection.server_version + return connection.cursor().connection.info.server_version @pytest.fixture(autouse=True) diff --git a/tox.ini b/tox.ini index 26be5d15..7cf24e55 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py36-dj{20,21,22,30,31,32}, py37-dj{20,21,22,30,31,32}, py38-dj{20,21,22,30,31,32,40, 41}, py39-dj{21,22,30,31,32,40,41}, py310-dj{21,22,30,31,32,40,41} +envlist = py36-dj{20,21,22,30,31,32}, py37-dj{20,21,22,30,31,32}, py38-dj{20,21,22,30,31,32,40,41,42}, py39-dj{21,22,30,31,32,40,41,42}, py310-dj{21,22,30,31,32,40,41,42} [testenv] deps = @@ -11,6 +11,7 @@ deps = dj32: Django~=3.2.0 dj40: Django~=4.0.0 dj41: Django~=4.1.0 + dj42: .[psycopg3] .[test] setenv = DJANGO_SETTINGS_MODULE=settings From 309f1664f220bf537d99c6673f449d40908aef61 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 6 Apr 2023 14:47:06 +0300 Subject: [PATCH 11/43] Run tests against pyscopg3 for Django 4.2 --- .circleci/config.yml | 9 +-------- setup.py | 4 ---- tox.ini | 9 +++++++-- 3 files changed, 8 insertions(+), 14 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 7f245a40..b794d7ad 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -32,13 +32,11 @@ commands: parameters: pyversion: type: integer - djversions: - type: string steps: - run: name: Run tests - command: tox -e 'py<< parameters.pyversion >>-dj{<< parameters.djversions >>}' + command: tox --listenvs | grep ^py<< parameters.pyversion >> | circleci tests split | xargs -n 1 tox -e environment: DATABASE_URL: 'postgres://psqlextra:psqlextra@localhost:5432/psqlextra' @@ -53,7 +51,6 @@ jobs: extra: test - run-tests: pyversion: 36 - djversions: 20,21,22,30,31,32 test-python37: executor: @@ -65,7 +62,6 @@ jobs: extra: test - run-tests: pyversion: 37 - djversions: 20,21,22,30,31,32 test-python38: executor: @@ -77,7 +73,6 @@ jobs: extra: test - run-tests: pyversion: 38 - djversions: 20,21,22,30,31,32,40,41,42 test-python39: executor: @@ -89,7 +84,6 @@ jobs: extra: test - run-tests: pyversion: 39 - djversions: 21,22,30,31,32,40,41,42 test-python310: executor: @@ -101,7 +95,6 @@ jobs: extra: test - run-tests: pyversion: 310 - djversions: 21,22,30,31,32,40,41,42 - store_test_results: path: reports - run: diff --git a/setup.py b/setup.py index a7f6b8f0..281be89d 100644 --- a/setup.py +++ b/setup.py @@ -95,10 +95,6 @@ def run(self): "build==0.7.0", "twine==3.7.1", ], - "psycopg3": [ - "django>=4.2,<5.0", - 
"psycopg[binary]>=3.0.0", - ], }, cmdclass={ "lint": create_command( diff --git a/tox.ini b/tox.ini index 7cf24e55..f50ac614 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,8 @@ [tox] -envlist = py36-dj{20,21,22,30,31,32}, py37-dj{20,21,22,30,31,32}, py38-dj{20,21,22,30,31,32,40,41,42}, py39-dj{21,22,30,31,32,40,41,42}, py310-dj{21,22,30,31,32,40,41,42} +envlist = + {py36,py37}-dj{20,21,22,30,31,32}-psycopg{2} + {py38,py39,py310}-dj{20,21,22,30,31,32,40,41}-psycopg{2} + {py38,py39,py310}-dj{42}-psycopg{2,3} [testenv] deps = @@ -11,7 +14,9 @@ deps = dj32: Django~=3.2.0 dj40: Django~=4.0.0 dj41: Django~=4.1.0 - dj42: .[psycopg3] + dj42: Django~=4.2.0 + psycopg2: psycopg2[binary]~=2.9 + psycopg3: psycopg[binary]~=3.1 .[test] setenv = DJANGO_SETTINGS_MODULE=settings From 50bb361905dead400449d102b966046ef2ccc57f Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 6 Apr 2023 14:55:44 +0300 Subject: [PATCH 12/43] Document suppport for Django 4.2 and psycopg3 --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 77eccb9a..8127b254 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,9 @@ | :memo: | **License** | [![License](https://img.shields.io/:license-mit-blue.svg)](http://doge.mit-license.org) | | :package: | **PyPi** | [![PyPi](https://badge.fury.io/py/django-postgres-extra.svg)](https://pypi.python.org/pypi/django-postgres-extra) | | :four_leaf_clover: | **Code coverage** | [![Coverage Status](https://coveralls.io/repos/github/SectorLabs/django-postgres-extra/badge.svg?branch=coveralls)](https://coveralls.io/github/SectorLabs/django-postgres-extra?branch=master) | -| | **Django Versions** | 2.0, 2.1, 2.2, 3.0, 3.1, 3.2, 4.0, 4.1 | -| | **Python Versions** | 3.6, 3.7, 3.8, 3.9, 3.10 | +| | **Django Versions** | 2.0, 2.1, 2.2, 3.0, 3.1, 3.2, 4.0, 4.1, 4.2 | +| | **Python Versions** | 3.6, 3.7, 3.8, 3.9, 3.10, 3.11 | +| | **Psycopg Versions** | 2, 3 | | :book: | **Documentation** | [Read The Docs](https://django-postgres-extra.readthedocs.io/en/master/) | | :warning: | **Upgrade** | [Upgrade from v1.x](https://django-postgres-extra.readthedocs.io/en/master/major_releases.html#new-features) | :checkered_flag: | **Installation** | [Installation Guide](https://django-postgres-extra.readthedocs.io/en/master/installation.html) | From 316ccfe8c1e0f825fa22fe431e31cbf9a0e9df7d Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 6 Apr 2023 14:59:41 +0300 Subject: [PATCH 13/43] Run CI against Python 3.11 --- .circleci/config.yml | 18 ++++++++++++++++++ tox.ini | 12 +++++++----- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index b794d7ad..f85823b7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -95,6 +95,17 @@ jobs: extra: test - run-tests: pyversion: 310 + + test-python311: + executor: + name: python + version: "3.11" + steps: + - checkout + - install-dependencies: + extra: test + - run-tests: + pyversion: 311 - store_test_results: path: reports - run: @@ -171,6 +182,12 @@ workflows: only: /.*/ branches: only: /.*/ + - test-python311: + filters: + tags: + only: /.*/ + branches: + only: /.*/ - analysis: filters: tags: @@ -184,6 +201,7 @@ workflows: - test-python38 - test-python39 - test-python310 + - test-python311 - analysis filters: tags: diff --git a/tox.ini b/tox.ini index f50ac614..3e229d0d 100644 --- a/tox.ini +++ b/tox.ini @@ -1,8 +1,9 @@ [tox] envlist = - {py36,py37}-dj{20,21,22,30,31,32}-psycopg{2} - {py38,py39,py310}-dj{20,21,22,30,31,32,40,41}-psycopg{2} - 
{py38,py39,py310}-dj{42}-psycopg{2,3} + {py36,py37}-dj{20,21,22,30,31,32}-psycopg{28,29} + {py38,py39,py310}-dj{21,22,30,31,32,40}-psycopg{28,29} + {py38,py39,py310,py311}-dj{41}-psycopg{28,29} + {py38,py39,py310,py311}-dj{42}-psycopg{28,29,31} [testenv] deps = @@ -15,8 +16,9 @@ deps = dj40: Django~=4.0.0 dj41: Django~=4.1.0 dj42: Django~=4.2.0 - psycopg2: psycopg2[binary]~=2.9 - psycopg3: psycopg[binary]~=3.1 + psycopg28: psycopg2[binary]~=2.8 + psycopg29: psycopg2[binary]~=2.9 + psycopg31: psycopg[binary]~=3.1 .[test] setenv = DJANGO_SETTINGS_MODULE=settings From 498124a218fa6e57a5c501d195726a08a9a1734e Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Tue, 4 Apr 2023 22:31:10 +0300 Subject: [PATCH 14/43] Add support to schema editor for cloning table into schema --- psqlextra/backend/introspection.py | 120 ++++- psqlextra/backend/operations.py | 3 + psqlextra/backend/schema.py | 435 ++++++++++++++++-- tests/db_introspection.py | 91 +++- tests/fake_model.py | 23 +- ...est_schema_editor_clone_model_to_schema.py | 321 +++++++++++++ tests/test_schema_editor_storage_settings.py | 47 ++ 7 files changed, 989 insertions(+), 51 deletions(-) create mode 100644 tests/test_schema_editor_clone_model_to_schema.py create mode 100644 tests/test_schema_editor_storage_settings.py diff --git a/psqlextra/backend/introspection.py b/psqlextra/backend/introspection.py index fe9c3324..03e1fdef 100644 --- a/psqlextra/backend/introspection.py +++ b/psqlextra/backend/introspection.py @@ -1,5 +1,8 @@ +from contextlib import contextmanager from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple + +from django.db import transaction from psqlextra.types import PostgresPartitioningMethod @@ -48,6 +51,22 @@ def partition_by_name( class PostgresIntrospection(base_impl.introspection()): """Adds introspection features specific to PostgreSQL.""" + # TODO: This class is a mess, both here and in the + # the base. + # + # Some methods return untyped dicts, some named tuples, + # some flat lists of strings. It's horribly inconsistent. + # + # Most methods are poorly named. For example; `get_table_description` + # does not return a complete table description. It merely returns + # the columns. + # + # We do our best in this class to stay consistent with + # the base in Django by respecting its naming scheme + # and commonly used return types. Creating an API that + # matches the look&feel from the Django base class + # is more important than fixing those issues. + def get_partitioned_tables( self, cursor ) -> PostgresIntrospectedPartitonedTable: @@ -172,6 +191,9 @@ def get_partition_key(self, cursor, table_name: str) -> List[str]: cursor.execute(sql, (table_name,)) return [row[0] for row in cursor.fetchall()] + def get_columns(self, cursor, table_name: str): + return self.get_table_description(cursor, table_name) + def get_constraints(self, cursor, table_name: str): """Retrieve any constraints or keys (unique, pk, fk, check, index) across one or more columns. 
@@ -202,15 +224,93 @@ def get_constraints(self, cursor, table_name: str):
     def get_table_locks(self, cursor) -> List[Tuple[str, str, str]]:
         cursor.execute(
             """
-        SELECT
-            n.nspname,
-            t.relname,
-            l.mode
-        FROM pg_locks l
-        INNER JOIN pg_class t ON t.oid = l.relation
-        INNER JOIN pg_namespace n ON n.oid = t.relnamespace
-        WHERE t.relnamespace >= 2200
-        ORDER BY n.nspname, t.relname, l.mode"""
+            SELECT
+                n.nspname,
+                t.relname,
+                l.mode
+            FROM pg_locks l
+            INNER JOIN pg_class t ON t.oid = l.relation
+            INNER JOIN pg_namespace n ON n.oid = t.relnamespace
+            WHERE t.relnamespace >= 2200
+            ORDER BY n.nspname, t.relname, l.mode
+            """
         )
 
         return cursor.fetchall()
+
+    def get_storage_settings(self, cursor, table_name: str) -> Dict[str, str]:
+        sql = """
+            SELECT
+                unnest(c.reloptions || array(select 'toast.' || x from pg_catalog.unnest(tc.reloptions) x))
+            FROM
+                pg_catalog.pg_class c
+            LEFT JOIN
+                pg_catalog.pg_class tc ON (c.reltoastrelid = tc.oid)
+            LEFT JOIN
+                pg_catalog.pg_am am ON (c.relam = am.oid)
+            WHERE
+                c.relname::text = %s
+        """
+
+        cursor.execute(sql, (table_name,))
+
+        storage_settings = {}
+        for row in cursor.fetchall():
+            # It's hard to believe, but storage settings are really
+            # represented as `key=value` strings in Postgres.
+            # See: https://www.postgresql.org/docs/current/catalog-pg-class.html
+            name, value = row[0].split("=")
+            storage_settings[name] = value
+
+        return storage_settings
+
+    def get_relations(self, cursor, table_name: str):
+        """Gets a dictionary {field_name: (field_name_other_table,
+        other_table)} representing all relations in the specified table.
+
+        This is overridden because the query in Django does not handle
+        relations between tables in different schemas properly.
+        """
+
+        cursor.execute(
+            """
+            SELECT a1.attname, c2.relname, a2.attname
+            FROM pg_constraint con
+            LEFT JOIN pg_class c1 ON con.conrelid = c1.oid
+            LEFT JOIN pg_class c2 ON con.confrelid = c2.oid
+            LEFT JOIN pg_attribute a1 ON c1.oid = a1.attrelid AND a1.attnum = con.conkey[1]
+            LEFT JOIN pg_attribute a2 ON c2.oid = a2.attrelid AND a2.attnum = con.confkey[1]
+            WHERE
+                con.conrelid = %s::regclass AND
+                con.contype = 'f' AND
+                pg_catalog.pg_table_is_visible(c1.oid)
+            """,
+            [table_name],
+        )
+        return {row[0]: (row[2], row[1]) for row in cursor.fetchall()}
+
+    @contextmanager
+    def in_search_path(self, search_path: List[str]):
+        """Changes the Postgres `search_path` within the context and switches
+        it back when it exits."""
+
+        # Wrap in a transaction so a savepoint is created. If
+        # something goes wrong, the `SET LOCAL search_path`
+        # statement will be rolled back.
+        with transaction.atomic(using=self.connection.alias):
+            with self.connection.cursor() as cursor:
+                cursor.execute("SHOW search_path")
+                (original_search_path,) = cursor.fetchone()
+
+                # Syntax in Postgres is a bit weird here. It isn't really
+                # a list of names like in `WHERE bla in (val1, val2)`.
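+                # For example, for ["myschema", "public"] this renders
+                # `SET LOCAL search_path = %s, %s`, binding each schema
+                # name as a separate query parameter.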
+ placeholder = ", ".join(["%s" for _ in search_path]) + cursor.execute( + f"SET LOCAL search_path = {placeholder}", search_path + ) + + yield self + + cursor.execute( + f"SET LOCAL search_path = {original_search_path}" + ) diff --git a/psqlextra/backend/operations.py b/psqlextra/backend/operations.py index 52793fac..24adf5d0 100644 --- a/psqlextra/backend/operations.py +++ b/psqlextra/backend/operations.py @@ -21,3 +21,6 @@ class PostgresOperations(base_impl.operations()): SQLUpdateCompiler, SQLInsertCompiler, ] + + def default_schema_name(self) -> str: + return "public" diff --git a/psqlextra/backend/schema.py b/psqlextra/backend/schema.py index 413f039d..b45ea8de 100644 --- a/psqlextra/backend/schema.py +++ b/psqlextra/backend/schema.py @@ -1,12 +1,15 @@ -from typing import Any, List, Optional +from typing import Any, List, Optional, Type from unittest import mock +import django + from django.core.exceptions import ( FieldDoesNotExist, ImproperlyConfigured, SuspiciousOperation, ) from django.db import transaction +from django.db.backends.ddl_references import Statement from django.db.models import Field, Model from psqlextra.type_assertions import is_sql_with_params @@ -19,12 +22,24 @@ HStoreUniqueSchemaEditorSideEffect, ) +SchemaEditor = base_impl.schema_editor() + -class PostgresSchemaEditor(base_impl.schema_editor()): +class PostgresSchemaEditor(SchemaEditor): """Schema editor that adds extra methods for PostgreSQL specific features and hooks into existing implementations to add side effects specific to PostgreSQL.""" + sql_add_pk = "ALTER TABLE %s ADD PRIMARY KEY (%s)" + + sql_create_fk_not_valid = f"{SchemaEditor.sql_create_fk} NOT VALID" + sql_validate_fk = "ALTER TABLE %s VALIDATE CONSTRAINT %s" + + sql_create_sequence_with_owner = "CREATE SEQUENCE %s OWNED BY %s.%s" + + sql_alter_table_storage_setting = "ALTER TABLE %s SET (%s = %s)" + sql_reset_table_storage_setting = "ALTER TABLE %s RESET (%s)" + sql_create_view = "CREATE VIEW %s AS (%s)" sql_replace_view = "CREATE OR REPLACE VIEW %s AS (%s)" sql_drop_view = "DROP VIEW IF EXISTS %s" @@ -63,7 +78,7 @@ def __init__(self, connection, collect_sql=False, atomic=True): self.deferred_sql = [] self.introspection = PostgresIntrospection(self.connection) - def create_model(self, model: Model) -> None: + def create_model(self, model: Type[Model]) -> None: """Creates a new model.""" super().create_model(model) @@ -71,7 +86,7 @@ def create_model(self, model: Model) -> None: for side_effect in self.side_effects: side_effect.create_model(model) - def delete_model(self, model: Model) -> None: + def delete_model(self, model: Type[Model]) -> None: """Drops/deletes an existing model.""" for side_effect in self.side_effects: @@ -79,8 +94,352 @@ def delete_model(self, model: Model) -> None: super().delete_model(model) + def clone_model_structure_to_schema( + self, model: Type[Model], *, schema_name: str + ) -> None: + """Creates a clone of the columns for the specified model in a separate + schema. + + The table will have exactly the same name as the model table + in the default schema. It will have none of the constraints, + foreign keys and indexes. + + Use this to create a temporary clone of a model table to + replace the original model table later on. The lack of + indices and constraints allows for greater write speeds. + + The original model table will be unaffected. + + Arguments: + model: + Model to clone the table of into the + specified schema. + + schema_name: + Name of the schema to create the cloned + table in. 
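+
+        A minimal sketch of the intended three-step flow (the
+        ``editor`` instance and the ``"clone"`` schema name are
+        illustrative)::
+
+            editor.clone_model_structure_to_schema(
+                MyModel, schema_name="clone"
+            )
+            editor.clone_model_constraints_and_indexes_to_schema(
+                MyModel, schema_name="clone"
+            )
+            # In a separate transaction, so validation does not hold
+            # an AccessExclusiveLock:
+            editor.clone_model_foreign_keys_to_schema(MyModel, "clone")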
+ """ + + table_name = model._meta.db_table + quoted_table_name = self.quote_name(model._meta.db_table) + quoted_schema_name = self.quote_name(schema_name) + + quoted_table_fqn = f"{quoted_schema_name}.{quoted_table_name}" + + self.execute( + self.sql_create_table + % { + "table": quoted_table_fqn, + "definition": f"LIKE {quoted_table_name} INCLUDING ALL EXCLUDING CONSTRAINTS EXCLUDING INDEXES", + } + ) + + # Copy sequences + # + # Django 4.0 and older do not use IDENTITY so Postgres does + # not copy the sequences into the new table. We do it manually. + if django.VERSION < (4, 1): + with self.connection.cursor() as cursor: + sequences = self.introspection.get_sequences(cursor, table_name) + + for sequence in sequences: + if sequence["table"] != table_name: + continue + + quoted_sequence_name = self.quote_name(sequence["name"]) + quoted_sequence_fqn = ( + f"{quoted_schema_name}.{quoted_sequence_name}" + ) + quoted_column_name = self.quote_name(sequence["column"]) + + self.execute( + self.sql_create_sequence_with_owner + % ( + quoted_sequence_fqn, + quoted_table_fqn, + quoted_column_name, + ) + ) + + self.execute( + self.sql_alter_column + % { + "table": quoted_table_fqn, + "changes": self.sql_alter_column_default + % { + "column": quoted_column_name, + "default": "nextval('%s')" % quoted_sequence_fqn, + }, + } + ) + + # Copy storage settings + # + # Postgres only copies column-level storage options, not + # the table-level storage options. + with self.connection.cursor() as cursor: + storage_settings = self.introspection.get_storage_settings( + cursor, model._meta.db_table + ) + + for setting_name, setting_value in storage_settings.items(): + self.alter_table_storage_setting( + quoted_table_fqn, setting_name, setting_value + ) + + def clone_model_constraints_and_indexes_to_schema( + self, model: Type[Model], *, schema_name: str + ) -> None: + """Adds the constraints, foreign keys and indexes to a model table that + was cloned into a separate table without them by + `clone_model_structure_to_schema`. + + Arguments: + model: + Model for which the cloned table was created. + + schema_name: + Name of the schema in which the cloned table + resides. + """ + + with self.introspection.in_search_path( + [schema_name, self.connection.ops.default_schema_name()] + ): + for constraint in model._meta.constraints: + self.add_constraint(model, constraint) + + for index in model._meta.indexes: + self.add_index(model, index) + + if model._meta.unique_together: + self.alter_unique_together( + model, tuple(), model._meta.unique_together + ) + + if model._meta.index_together: + self.alter_index_together( + model, tuple(), model._meta.index_together + ) + + for field in model._meta.local_concrete_fields: + # Django creates primary keys later added to the model with + # a custom name. We want the name as it was created originally. + if field.primary_key: + with self.introspection.in_search_path( + [self.connection.ops.default_schema_name()] + ): + [primary_key_name] = self._constraint_names( + model, primary_key=True + ) + + self.execute( + self.sql_create_pk + % { + "table": self.quote_name(model._meta.db_table), + "name": self.quote_name(primary_key_name), + "columns": self.quote_name( + field.db_column or field.attname + ), + } + ) + continue + + # Django creates foreign keys in a single statement which acquires + # a AccessExclusiveLock on the referenced table. We want to avoid + # that and created the FK as NOT VALID. 
We can run VALIDATE in
+                # a separate transaction later to validate the entries without
+                # acquiring an AccessExclusiveLock.
+                if field.remote_field:
+                    with self.introspection.in_search_path(
+                        [self.connection.ops.default_schema_name()]
+                    ):
+                        [fk_name] = self._constraint_names(
+                            model, [field.column], foreign_key=True
+                        )
+
+                    sql = Statement(
+                        self.sql_create_fk_not_valid,
+                        table=self.quote_name(model._meta.db_table),
+                        name=self.quote_name(fk_name),
+                        column=self.quote_name(field.column),
+                        to_table=self.quote_name(
+                            field.target_field.model._meta.db_table
+                        ),
+                        to_column=self.quote_name(field.target_field.column),
+                        deferrable=self.connection.ops.deferrable_sql(),
+                    )
+
+                    self.execute(sql)
+
+                # It's hard to alter a field's check because it is defined
+                # by the field class, not the field instance. Handle this
+                # manually.
+                field_check = field.db_parameters(self.connection).get("check")
+                if field_check:
+                    with self.introspection.in_search_path(
+                        [self.connection.ops.default_schema_name()]
+                    ):
+                        [field_check_name] = self._constraint_names(
+                            model,
+                            [field.column],
+                            check=True,
+                            exclude={
+                                constraint.name
+                                for constraint in model._meta.constraints
+                            },
+                        )
+
+                    self.execute(
+                        self._create_check_sql(
+                            model, field_check_name, field_check
+                        )
+                    )
+
+                # Clone the field and alter its state to match our current
+                # table definition. This will cause Django to see the missing
+                # indices and create them.
+                if field.remote_field:
+                    # We add the foreign key constraint ourselves with NOT VALID,
+                    # hence, we specify `db_constraint=False` on both old/new.
+                    # Django won't touch the foreign key constraint.
+                    old_field = self._clone_model_field(
+                        field, db_index=False, unique=False, db_constraint=False
+                    )
+                    new_field = self._clone_model_field(
+                        field, db_constraint=False
+                    )
+                    self.alter_field(model, old_field, new_field)
+                else:
+                    old_field = self._clone_model_field(
+                        field, db_index=False, unique=False
+                    )
+                    new_field = self._clone_model_field(field)
+                    self.alter_field(model, old_field, new_field)
+
+    def clone_model_foreign_keys_to_schema(
+        self, model: Type[Model], schema_name: str
+    ) -> None:
+        """Validates the foreign keys in the cloned model table created by
+        `clone_model_structure_to_schema` and
+        `clone_model_constraints_and_indexes_to_schema`.
+
+        Do NOT run this in the same transaction in which the
+        foreign keys were added to the table. It WILL
+        acquire a long-lived AccessExclusiveLock.
+
+        Arguments:
+            model:
+                Model for which the cloned table was created.
+
+            schema_name:
+                Name of the schema in which the cloned table
+                resides.
+        """
+
+        with self.introspection.in_search_path(
+            [schema_name, self.connection.ops.default_schema_name()]
+        ):
+            for fk_name in self._constraint_names(model, foreign_key=True):
+                self.execute(
+                    self.sql_validate_fk
+                    % (
+                        self.quote_name(model._meta.db_table),
+                        self.quote_name(fk_name),
+                    )
+                )
+
+    def alter_table_storage_setting(
+        self, table_name: str, name: str, value: str
+    ) -> None:
+        """Alters a storage setting for a table.
+
+        See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS
+
+        Arguments:
+            table_name:
+                Name of the table to alter the setting for.
+
+            name:
+                Name of the setting to alter.
+
+            value:
+                Value to alter the setting to.
+
+                Note that this is always a string, even if it looks
+                like a number or a boolean. That's how Postgres
+                stores storage settings internally.
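+
+        For example (``"mytable"`` is illustrative;
+        ``autovacuum_enabled`` is a standard Postgres storage
+        parameter)::
+
+            editor.alter_table_storage_setting(
+                "mytable", "autovacuum_enabled", "false"
+            )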
+ """ + + self.execute( + self.sql_alter_table_storage_setting + % (self.quote_name(table_name), name, value) + ) + + def alter_model_storage_setting( + self, model: Type[Model], name: str, value: str + ) -> None: + """Alters a storage setting for the model's table. + + See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS + + Arguments: + model: + Model of which to alter the table + setting. + + name: + Name of the setting to alter. + + value: + Value to alter the setting to. + + Note that this is always a string, even if it looks + like a number or a boolean. That's how Postgres + stores storage settings internally. + """ + + self.alter_table_storage_setting(model._meta.db_table, name, value) + + def reset_table_storage_setting(self, table_name: str, name: str) -> None: + """Resets a table's storage setting to the database or server default. + + See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS + + Arguments: + table_name: + Name of the table to reset the setting for. + + name: + Name of the setting to reset. + """ + + self.execute( + self.sql_reset_table_storage_setting + % (self.quote_name(table_name), name) + ) + + def reset_model_storage_setting( + self, model: Type[Model], name: str + ) -> None: + """Resets a model's table storage setting to the database or server + default. + + See: https://www.postgresql.org/docs/current/sql-createtable.html#SQL-CREATETABLE-STORAGE-PARAMETERS + + Arguments: + table_name: + model: + Model for which to reset the table setting for. + + name: + Name of the setting to reset. + """ + + self.reset_table_storage_setting(model._meta.db_table, name) + def refresh_materialized_view_model( - self, model: Model, concurrently: bool = False + self, model: Type[Model], concurrently: bool = False ) -> None: """Refreshes a materialized view.""" @@ -93,12 +452,12 @@ def refresh_materialized_view_model( sql = sql_template % self.quote_name(model._meta.db_table) self.execute(sql) - def create_view_model(self, model: Model) -> None: + def create_view_model(self, model: Type[Model]) -> None: """Creates a new view model.""" self._create_view_model(self.sql_create_view, model) - def replace_view_model(self, model: Model) -> None: + def replace_view_model(self, model: Type[Model]) -> None: """Replaces a view model with a newer version. This is used to alter the backing query of a view. @@ -106,18 +465,18 @@ def replace_view_model(self, model: Model) -> None: self._create_view_model(self.sql_replace_view, model) - def delete_view_model(self, model: Model) -> None: + def delete_view_model(self, model: Type[Model]) -> None: """Deletes a view model.""" sql = self.sql_drop_view % self.quote_name(model._meta.db_table) self.execute(sql) - def create_materialized_view_model(self, model: Model) -> None: + def create_materialized_view_model(self, model: Type[Model]) -> None: """Creates a new materialized view model.""" self._create_view_model(self.sql_create_materialized_view, model) - def replace_materialized_view_model(self, model: Model) -> None: + def replace_materialized_view_model(self, model: Type[Model]) -> None: """Replaces a materialized view with a newer version. This is used to alter the backing query of a materialized view. 
@@ -148,7 +507,7 @@ def replace_materialized_view_model(self, model: Model) -> None: self.execute(constraint_options["definition"]) - def delete_materialized_view_model(self, model: Model) -> None: + def delete_materialized_view_model(self, model: Type[Model]) -> None: """Deletes a materialized view model.""" sql = self.sql_drop_materialized_view % self.quote_name( @@ -156,7 +515,7 @@ def delete_materialized_view_model(self, model: Model) -> None: ) self.execute(sql) - def create_partitioned_model(self, model: Model) -> None: + def create_partitioned_model(self, model: Type[Model]) -> None: """Creates a new partitioned model.""" meta = self._partitioning_properties_for_model(model) @@ -188,14 +547,14 @@ def create_partitioned_model(self, model: Model) -> None: self.execute(sql, params) - def delete_partitioned_model(self, model: Model) -> None: + def delete_partitioned_model(self, model: Type[Model]) -> None: """Drops the specified partitioned model.""" return self.delete_model(model) def add_range_partition( self, - model: Model, + model: Type[Model], name: str, from_values: Any, to_values: Any, @@ -246,7 +605,7 @@ def add_range_partition( def add_list_partition( self, - model: Model, + model: Type[Model], name: str, values: List[Any], comment: Optional[str] = None, @@ -289,7 +648,7 @@ def add_list_partition( def add_hash_partition( self, - model: Model, + model: Type[Model], name: str, modulus: int, remainder: int, @@ -334,7 +693,7 @@ def add_hash_partition( self.set_comment_on_table(table_name, comment) def add_default_partition( - self, model: Model, name: str, comment: Optional[str] = None + self, model: Type[Model], name: str, comment: Optional[str] = None ) -> None: """Creates a new default partition for the specified partitioned model. @@ -370,7 +729,7 @@ def add_default_partition( if comment: self.set_comment_on_table(table_name, comment) - def delete_partition(self, model: Model, name: str) -> None: + def delete_partition(self, model: Type[Model], name: str) -> None: """Deletes the partition with the specified name.""" sql = self.sql_delete_partition % self.quote_name( @@ -379,7 +738,7 @@ def delete_partition(self, model: Model, name: str) -> None: self.execute(sql) def alter_db_table( - self, model: Model, old_db_table: str, new_db_table: str + self, model: Type[Model], old_db_table: str, new_db_table: str ) -> None: """Alters a table/model.""" @@ -388,7 +747,7 @@ def alter_db_table( for side_effect in self.side_effects: side_effect.alter_db_table(model, old_db_table, new_db_table) - def add_field(self, model: Model, field: Field) -> None: + def add_field(self, model: Type[Model], field: Field) -> None: """Adds a new field to an exisiting model.""" super().add_field(model, field) @@ -396,7 +755,7 @@ def add_field(self, model: Model, field: Field) -> None: for side_effect in self.side_effects: side_effect.add_field(model, field) - def remove_field(self, model: Model, field: Field) -> None: + def remove_field(self, model: Type[Model], field: Field) -> None: """Removes a field from an existing model.""" for side_effect in self.side_effects: @@ -406,7 +765,7 @@ def remove_field(self, model: Model, field: Field) -> None: def alter_field( self, - model: Model, + model: Type[Model], old_field: Field, new_field: Field, strict: bool = False, @@ -424,7 +783,7 @@ def set_comment_on_table(self, table_name: str, comment: str) -> None: sql = self.sql_table_comment % (self.quote_name(table_name), "%s") self.execute(sql, (comment,)) - def _create_view_model(self, sql: str, model: Model) -> 
None: + def _create_view_model(self, sql: str, model: Type[Model]) -> None: """Creates a new view model using the specified SQL query.""" meta = self._view_properties_for_model(model) @@ -451,7 +810,7 @@ def _extract_sql(self, method, *args): return tuple(execute.mock_calls[0])[1] @staticmethod - def _view_properties_for_model(model: Model): + def _view_properties_for_model(model: Type[Model]): """Gets the view options for the specified model. Raises: @@ -483,7 +842,7 @@ def _view_properties_for_model(model: Model): return meta @staticmethod - def _partitioning_properties_for_model(model: Model): + def _partitioning_properties_for_model(model: Type[Model]): """Gets the partitioning options for the specified model. Raises: @@ -546,5 +905,29 @@ def _partitioning_properties_for_model(model: Model): return meta - def create_partition_table_name(self, model: Model, name: str) -> str: + def create_partition_table_name(self, model: Type[Model], name: str) -> str: return "%s_%s" % (model._meta.db_table.lower(), name.lower()) + + def _clone_model_field(self, field: Field, **overrides) -> Field: + """Clones the specified model field and overrides its kwargs with the + specified overrides. + + The cloned field will not be contributed to the model. + """ + + _, _, field_args, field_kwargs = field.deconstruct() + + cloned_field_args = field_args[:] + cloned_field_kwargs = {**field_kwargs, **overrides} + + cloned_field = field.__class__( + *cloned_field_args, **cloned_field_kwargs + ) + cloned_field.model = field.model + cloned_field.set_attributes_from_name(field.name) + + if cloned_field.remote_field: + cloned_field.remote_field.model = field.remote_field.model + cloned_field.set_attributes_from_rel() + + return cloned_field diff --git a/tests/db_introspection.py b/tests/db_introspection.py index bdcd4b19..eabc7414 100644 --- a/tests/db_introspection.py +++ b/tests/db_introspection.py @@ -4,38 +4,101 @@ This makes test code less verbose and easier to read/write. 
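For context, the `_clone_model_field` helper introduced above deconstructs a field and rebuilds it with overridden keyword arguments; a minimal sketch, assuming a hypothetical `MyModel` with a `name` text field:

.. code-block:: python

    from django.db import connection

    from psqlextra.backend.schema import PostgresSchemaEditor

    schema_editor = PostgresSchemaEditor(connection)

    # Clone the field, overriding one of its deconstructed kwargs.
    name_field = MyModel._meta.get_field("name")
    nullable_copy = schema_editor._clone_model_field(name_field, null=True)

    assert nullable_copy.null is True
    # The clone references the model but is not contributed to it.
    assert nullable_copy.model is MyModel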
""" +from contextlib import contextmanager +from typing import Optional + from django.db import connection -def table_names(include_views: bool = True): +@contextmanager +def introspect(schema_name: Optional[str] = None): + default_schema_name = connection.ops.default_schema_name() + search_path = [schema_name or default_schema_name] + + with connection.introspection.in_search_path(search_path) as introspection: + with connection.cursor() as cursor: + yield introspection, cursor + + +def table_names( + include_views: bool = True, *, schema_name: Optional[str] = None +): """Gets a flat list of tables in the default database.""" - with connection.cursor() as cursor: - introspection = connection.introspection + with introspect(schema_name) as (introspection, cursor): return introspection.table_names(cursor, include_views) -def get_partitioned_table(table_name: str): +def get_partitioned_table( + table_name: str, + *, + schema_name: Optional[str] = None, +): """Gets the definition of a partitioned table in the default database.""" - with connection.cursor() as cursor: - introspection = connection.introspection + with introspect(schema_name) as (introspection, cursor): return introspection.get_partitioned_table(cursor, table_name) -def get_partitions(table_name: str): +def get_partitions( + table_name: str, + *, + schema_name: Optional[str] = None, +): """Gets a list of partitions for the specified partitioned table in the default database.""" - with connection.cursor() as cursor: - introspection = connection.introspection + with introspect(schema_name) as (introspection, cursor): return introspection.get_partitions(cursor, table_name) -def get_constraints(table_name: str): - """Gets a complete list of constraints and indexes for the specified - table.""" +def get_columns( + table_name: str, + *, + schema_name: Optional[str] = None, +): + """Gets a list of columns for the specified table.""" + + with introspect(schema_name) as (introspection, cursor): + return introspection.get_columns(cursor, table_name) + + +def get_relations( + table_name: str, + *, + schema_name: Optional[str] = None, +): + """Gets a list of relations for the specified table.""" + + with introspect(schema_name) as (introspection, cursor): + return introspection.get_relations(cursor, table_name) - with connection.cursor() as cursor: - introspection = connection.introspection + +def get_constraints( + table_name: str, + *, + schema_name: Optional[str] = None, +): + """Gets a list of constraints and indexes for the specified table.""" + + with introspect(schema_name) as (introspection, cursor): return introspection.get_constraints(cursor, table_name) + + +def get_sequences( + table_name: str, + *, + schema_name: Optional[str] = None, +): + """Gets a list of sequences own by the specified table.""" + + with introspect(schema_name) as (introspection, cursor): + return introspection.get_sequences(cursor, table_name) + + +def get_storage_settings(table_name: str, *, schema_name: Optional[str] = None): + """Gets a list of all storage settings that have been set on the specified + table.""" + + with introspect(schema_name) as (introspection, cursor): + return introspection.get_storage_settings(cursor, table_name) diff --git a/tests/fake_model.py b/tests/fake_model.py index 1254e762..ec626f3a 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -3,9 +3,10 @@ import uuid from contextlib import contextmanager +from typing import Type from django.apps import AppConfig, apps -from django.db import connection +from django.db 
import connection, models from psqlextra.models import ( PostgresMaterializedViewModel, @@ -39,6 +40,17 @@ def define_fake_model( return model +def undefine_fake_model(model: Type[models.Model]) -> None: + """Removes the fake model from the app registry.""" + + app_label = model._meta.app_label or "tests" + app_models = apps.app_configs[app_label].models + + for model_name in [model.__name__, model.__name__.lower()]: + if model_name in app_models: + del app_models[model_name] + + def define_fake_view_model( fields=None, view_options={}, meta_options={}, model_base=PostgresViewModel ): @@ -115,6 +127,15 @@ def get_fake_model(fields=None, model_base=PostgresModel, meta_options={}): return model +def delete_fake_model(model: Type[models.Model]) -> None: + """Deletes a fake model from the database and the internal app registry.""" + + undefine_fake_model(model) + + with connection.schema_editor() as schema_editor: + schema_editor.delete_model(model) + + @contextmanager def define_fake_app(): """Creates and registers a fake Django app.""" diff --git a/tests/test_schema_editor_clone_model_to_schema.py b/tests/test_schema_editor_clone_model_to_schema.py new file mode 100644 index 00000000..712d1433 --- /dev/null +++ b/tests/test_schema_editor_clone_model_to_schema.py @@ -0,0 +1,321 @@ +import os + +from typing import Set, Tuple + +import django +import pytest + +from django.contrib.postgres.fields import ArrayField +from django.contrib.postgres.indexes import GinIndex +from django.db import connection, models, transaction +from django.db.models import Q + +from psqlextra.backend.schema import PostgresSchemaEditor + +from . import db_introspection +from .fake_model import delete_fake_model, get_fake_model + +django_32_skip_reason = "Django < 3.2 can't support cloning models because it has hard coded references to the public schema" + + +def _create_schema() -> str: + name = os.urandom(4).hex() + + with connection.cursor() as cursor: + cursor.execute( + "CREATE SCHEMA %s" % connection.ops.quote_name(name), tuple() + ) + + return name + + +def _assert_cloned_table_is_same( + source_table_fqn: Tuple[str, str], + target_table_fqn: Tuple[str, str], + excluding_constraints_and_indexes: bool = False, +): + source_schema_name, source_table_name = source_table_fqn + target_schema_name, target_table_name = target_table_fqn + + source_columns = db_introspection.get_columns( + source_table_name, schema_name=source_schema_name + ) + source_columns = db_introspection.get_columns( + target_table_name, schema_name=target_schema_name + ) + assert source_columns == source_columns + + source_relations = db_introspection.get_relations( + source_table_name, schema_name=source_schema_name + ) + source_relations = db_introspection.get_relations( + target_table_name, schema_name=target_schema_name + ) + if excluding_constraints_and_indexes: + assert source_relations == {} + else: + assert source_relations == source_relations + + source_constraints = db_introspection.get_constraints( + source_table_name, schema_name=source_schema_name + ) + source_constraints = db_introspection.get_constraints( + target_table_name, schema_name=target_schema_name + ) + if excluding_constraints_and_indexes: + assert source_constraints == {} + else: + assert source_constraints == source_constraints + + source_sequences = db_introspection.get_sequences( + source_table_name, schema_name=source_schema_name + ) + source_sequences = db_introspection.get_sequences( + target_table_name, schema_name=target_schema_name + ) + assert source_sequences 
== source_sequences + + source_storage_settings = db_introspection.get_storage_settings( + source_table_name, + schema_name=source_schema_name, + ) + source_storage_settings = db_introspection.get_storage_settings( + target_table_name, schema_name=target_schema_name + ) + assert source_storage_settings == source_storage_settings + + +def _list_lock_modes_in_schema(schema_name: str) -> Set[str]: + with connection.cursor() as cursor: + cursor.execute( + """ + SELECT + l.mode + FROM pg_locks l + INNER JOIN pg_class t ON t.oid = l.relation + INNER JOIN pg_namespace n ON n.oid = t.relnamespace + WHERE + t.relnamespace >= 2200 + AND n.nspname = %s + ORDER BY n.nspname, t.relname, l.mode + """, + (schema_name,), + ) + + return {lock_mode for lock_mode, in cursor.fetchall()} + + +def _clone_model_into_schema(model): + schema_name = _create_schema() + + schema_editor = PostgresSchemaEditor(connection) + schema_editor.clone_model_structure_to_schema( + model, schema_name=schema_name + ) + schema_editor.clone_model_constraints_and_indexes_to_schema( + model, schema_name=schema_name + ) + schema_editor.clone_model_foreign_keys_to_schema( + model, schema_name=schema_name + ) + + return schema_name + + +@pytest.fixture +def fake_model_fk_target_1(): + model = get_fake_model( + { + "name": models.TextField(), + }, + ) + + yield model + + delete_fake_model(model) + + +@pytest.fixture +def fake_model_fk_target_2(): + model = get_fake_model( + { + "name": models.TextField(), + }, + ) + + yield model + + delete_fake_model(model) + + +@pytest.fixture +def fake_model(fake_model_fk_target_1, fake_model_fk_target_2): + model = get_fake_model( + { + "first_name": models.TextField(null=True), + "last_name": models.TextField(), + "age": models.PositiveIntegerField(), + "height": models.FloatField(), + "nicknames": ArrayField(base_field=models.TextField()), + "blob": models.JSONField(), + "family": models.ForeignKey( + fake_model_fk_target_1, on_delete=models.CASCADE + ), + "alternative_family": models.ForeignKey( + fake_model_fk_target_2, null=True, on_delete=models.SET_NULL + ), + }, + meta_options={ + "indexes": [ + models.Index(fields=["age", "height"]), + models.Index(fields=["age"], name="age_index"), + GinIndex(fields=["nicknames"], name="nickname_index"), + ], + "constraints": [ + models.UniqueConstraint( + fields=["first_name", "last_name"], + name="first_last_name_uniq", + ), + models.CheckConstraint( + check=Q(age__gt=0, height__gt=0), name="age_height_check" + ), + ], + "unique_together": ( + "first_name", + "nicknames", + ), + "index_together": ( + "blob", + "age", + ), + }, + ) + + yield model + + delete_fake_model(model) + + +@pytest.mark.skipif( + django.VERSION < (3, 2), + reason=django_32_skip_reason, +) +@pytest.mark.django_db(transaction=True) +def test_schema_editor_clone_model_to_schema( + fake_model, fake_model_fk_target_1, fake_model_fk_target_2 +): + """Tests that cloning a model into a separate schema without obtaining + AccessExclusiveLock on the source table works as expected.""" + + schema_editor = PostgresSchemaEditor(connection) + schema_editor.alter_table_storage_setting( + fake_model._meta.db_table, "autovacuum_enabled", "false" + ) + + table_name = fake_model._meta.db_table + source_schema_name = connection.ops.default_schema_name() + target_schema_name = _create_schema() + + with transaction.atomic(durable=True): + schema_editor.clone_model_structure_to_schema( + fake_model, schema_name=target_schema_name + ) + + assert _list_lock_modes_in_schema(source_schema_name) == { + 
"AccessShareLock" + } + + _assert_cloned_table_is_same( + (source_schema_name, table_name), + (target_schema_name, table_name), + excluding_constraints_and_indexes=True, + ) + + with transaction.atomic(durable=True): + schema_editor.clone_model_constraints_and_indexes_to_schema( + fake_model, schema_name=target_schema_name + ) + + assert _list_lock_modes_in_schema(source_schema_name) == { + "AccessShareLock", + "ShareRowExclusiveLock", + } + + _assert_cloned_table_is_same( + (source_schema_name, table_name), + (target_schema_name, table_name), + ) + + with transaction.atomic(durable=True): + schema_editor.clone_model_foreign_keys_to_schema( + fake_model, schema_name=target_schema_name + ) + + assert _list_lock_modes_in_schema(source_schema_name) == { + "AccessShareLock", + "RowShareLock", + } + + _assert_cloned_table_is_same( + (source_schema_name, table_name), + (target_schema_name, table_name), + ) + + +@pytest.mark.skipif( + django.VERSION < (3, 2), + reason=django_32_skip_reason, +) +def test_schema_editor_clone_model_to_schema_custom_constraint_names( + fake_model, +): + """Tests that even if constraints were given custom names, the cloned table + has those same custom names.""" + + table_name = fake_model._meta.db_table + source_schema_name = connection.ops.default_schema_name() + + constraints = db_introspection.get_constraints(table_name) + + primary_key_constraint = next( + ( + name + for name, constraint in constraints.items() + if constraint["primary_key"] + ), + None, + ) + foreign_key_constraint = next( + ( + name + for name, constraint in constraints.items() + if constraint["foreign_key"] + ), + None, + ) + check_constraint = next( + ( + name + for name, constraint in constraints.items() + if constraint["check"] + ), + None, + ) + + with connection.cursor() as cursor: + cursor.execute( + f"ALTER TABLE {table_name} RENAME CONSTRAINT {primary_key_constraint} TO custompkname" + ) + cursor.execute( + f"ALTER TABLE {table_name} RENAME CONSTRAINT {foreign_key_constraint} TO customfkname" + ) + cursor.execute( + f"ALTER TABLE {table_name} RENAME CONSTRAINT {check_constraint} TO customcheckname" + ) + + target_schema_name = _clone_model_into_schema(fake_model) + + _assert_cloned_table_is_same( + (source_schema_name, table_name), + (target_schema_name, table_name), + ) diff --git a/tests/test_schema_editor_storage_settings.py b/tests/test_schema_editor_storage_settings.py new file mode 100644 index 00000000..0f45934f --- /dev/null +++ b/tests/test_schema_editor_storage_settings.py @@ -0,0 +1,47 @@ +import pytest + +from django.db import connection, models + +from psqlextra.backend.schema import PostgresSchemaEditor + +from . 
import db_introspection +from .fake_model import get_fake_model + + +@pytest.fixture +def fake_model(): + return get_fake_model( + { + "text": models.TextField(), + } + ) + + +def test_schema_editor_storage_settings_table_alter_and_reset(fake_model): + table_name = fake_model._meta.db_table + schema_editor = PostgresSchemaEditor(connection) + + schema_editor.alter_table_storage_setting( + table_name, "autovacuum_enabled", "false" + ) + assert db_introspection.get_storage_settings(table_name) == { + "autovacuum_enabled": "false" + } + + schema_editor.reset_table_storage_setting(table_name, "autovacuum_enabled") + assert db_introspection.get_storage_settings(table_name) == {} + + +def test_schema_editor_storage_settings_model_alter_and_reset(fake_model): + table_name = fake_model._meta.db_table + schema_editor = PostgresSchemaEditor(connection) + + schema_editor.alter_model_storage_setting( + fake_model, "autovacuum_enabled", "false" + ) + assert db_introspection.get_storage_settings(table_name) == { + "autovacuum_enabled": "false" + } + + schema_editor.reset_model_storage_setting(fake_model, "autovacuum_enabled") + assert db_introspection.get_storage_settings(table_name) == {} From 3f2486c13e6098fded58e1f1fa9162c4f2cd3190 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 6 Apr 2023 14:32:06 +0300 Subject: [PATCH 15/43] Add vacuum methods to schema editor --- psqlextra/backend/schema.py | 87 +++++++++++++++++ tests/test_schema_editor_vacuum.py | 147 +++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+) create mode 100644 tests/test_schema_editor_vacuum.py diff --git a/psqlextra/backend/schema.py b/psqlextra/backend/schema.py index b45ea8de..435b0bdd 100644 --- a/psqlextra/backend/schema.py +++ b/psqlextra/backend/schema.py @@ -777,6 +777,93 @@ def alter_field( for side_effect in self.side_effects: side_effect.alter_field(model, old_field, new_field, strict) + def vacuum_table( + self, + table_name: str, + columns: List[str] = [], + *, + full: bool = False, + freeze: bool = False, + verbose: bool = False, + analyze: bool = False, + disable_page_skipping: bool = False, + skip_locked: bool = False, + index_cleanup: bool = False, + truncate: bool = False, + parallel: Optional[int] = None, + ) -> None: + """Runs the VACUUM statement on the specified table with the specified + options. + + Arguments: + table_name: + Name of the table to run VACUUM on. + + columns: + Optionally, a list of columns to vacuum. If not + specified, all columns are vacuumed. 
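As a rough usage sketch of `vacuum_table` (with a hypothetical table name), note that VACUUM refuses to run inside a transaction block:

.. code-block:: python

    from django.db import connection

    from psqlextra.backend.schema import PostgresSchemaEditor

    # Must run outside any transaction; the method raises
    # SuspiciousOperation when a transaction is active.
    schema_editor = PostgresSchemaEditor(connection)

    # Executes: VACUUM (ANALYZE, PARALLEL 4) "mytable" ("id", "name")
    schema_editor.vacuum_table(
        "mytable", ["id", "name"], analyze=True, parallel=4
    )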
+ """ + + if self.connection.in_atomic_block: + raise SuspiciousOperation("Vacuum cannot be done in a transaction") + + options = [] + if full: + options.append("FULL") + if freeze: + options.append("FREEZE") + if verbose: + options.append("VERBOSE") + if analyze: + options.append("ANALYZE") + if disable_page_skipping: + options.append("DISABLE_PAGE_SKIPPING") + if skip_locked: + options.append("SKIP_LOCKED") + if index_cleanup: + options.append("INDEX_CLEANUP") + if truncate: + options.append("TRUNCATE") + if parallel is not None: + options.append(f"PARALLEL {parallel}") + + sql = "VACUUM" + + if options: + options_sql = ", ".join(options) + sql += f" ({options_sql})" + + sql += f" {self.quote_name(table_name)}" + + if columns: + columns_sql = ", ".join( + [self.quote_name(column) for column in columns] + ) + sql += f" ({columns_sql})" + + self.execute(sql) + + def vacuum_model( + self, model: Type[Model], fields: List[Field] = [], **kwargs + ) -> None: + """Runs the VACUUM statement on the table of the specified model with + the specified options. + + Arguments: + table_name: + model: + Model of which to run VACUUM the table. + + fields: + Optionally, a list of fields to vacuum. If not + specified, all fields are vacuumed. + """ + + columns = [ + field.column for field in fields if field.concrete and field.column + ] + self.vacuum_table(model._meta.db_table, columns, **kwargs) + def set_comment_on_table(self, table_name: str, comment: str) -> None: """Sets the comment on the specified table.""" diff --git a/tests/test_schema_editor_vacuum.py b/tests/test_schema_editor_vacuum.py new file mode 100644 index 00000000..59772e86 --- /dev/null +++ b/tests/test_schema_editor_vacuum.py @@ -0,0 +1,147 @@ +import pytest + +from django.core.exceptions import SuspiciousOperation +from django.db import connection, models +from django.test.utils import CaptureQueriesContext + +from psqlextra.backend.schema import PostgresSchemaEditor + +from .fake_model import delete_fake_model, get_fake_model + + +@pytest.fixture +def fake_model(): + model = get_fake_model( + { + "name": models.TextField(), + } + ) + + yield model + + delete_fake_model(model) + + +@pytest.fixture +def fake_model_non_concrete_field(fake_model): + model = get_fake_model( + { + "fk": models.ForeignKey( + fake_model, on_delete=models.CASCADE, related_name="fakes" + ), + } + ) + + yield model + + delete_fake_model(model) + + +def test_schema_editor_vacuum_not_in_transaction(fake_model): + schema_editor = PostgresSchemaEditor(connection) + + with pytest.raises(SuspiciousOperation): + schema_editor.vacuum_table(fake_model._meta.db_table) + + +@pytest.mark.parametrize( + "kwargs,query", + [ + (dict(), "VACUUM %s"), + (dict(full=True), "VACUUM (FULL) %s"), + (dict(analyze=True), "VACUUM (ANALYZE) %s"), + (dict(parallel=8), "VACUUM (PARALLEL 8) %s"), + (dict(analyze=True, verbose=True), "VACUUM (VERBOSE, ANALYZE) %s"), + ( + dict(analyze=True, parallel=8, verbose=True), + "VACUUM (VERBOSE, ANALYZE, PARALLEL 8) %s", + ), + (dict(freeze=True), "VACUUM (FREEZE) %s"), + (dict(verbose=True), "VACUUM (VERBOSE) %s"), + (dict(disable_page_skipping=True), "VACUUM (DISABLE_PAGE_SKIPPING) %s"), + (dict(skip_locked=True), "VACUUM (SKIP_LOCKED) %s"), + (dict(index_cleanup=True), "VACUUM (INDEX_CLEANUP) %s"), + (dict(truncate=True), "VACUUM (TRUNCATE) %s"), + ], +) +@pytest.mark.django_db(transaction=True) +def test_schema_editor_vacuum_table(fake_model, kwargs, query): + schema_editor = PostgresSchemaEditor(connection) + + with 
CaptureQueriesContext(connection) as ctx: + schema_editor.vacuum_table(fake_model._meta.db_table, **kwargs) + + queries = [query["sql"] for query in ctx.captured_queries] + assert queries == [ + query % connection.ops.quote_name(fake_model._meta.db_table) + ] + + +@pytest.mark.django_db(transaction=True) +def test_schema_editor_vacuum_table_columns(fake_model): + schema_editor = PostgresSchemaEditor(connection) + + with CaptureQueriesContext(connection) as ctx: + schema_editor.vacuum_table( + fake_model._meta.db_table, ["id", "name"], analyze=True + ) + + queries = [query["sql"] for query in ctx.captured_queries] + assert queries == [ + 'VACUUM (ANALYZE) %s ("id", "name")' + % connection.ops.quote_name(fake_model._meta.db_table) + ] + + +@pytest.mark.django_db(transaction=True) +def test_schema_editor_vacuum_model(fake_model): + schema_editor = PostgresSchemaEditor(connection) + + with CaptureQueriesContext(connection) as ctx: + schema_editor.vacuum_model(fake_model, analyze=True, parallel=8) + + queries = [query["sql"] for query in ctx.captured_queries] + assert queries == [ + "VACUUM (ANALYZE, PARALLEL 8) %s" + % connection.ops.quote_name(fake_model._meta.db_table) + ] + + +@pytest.mark.django_db(transaction=True) +def test_schema_editor_vacuum_model_fields(fake_model): + schema_editor = PostgresSchemaEditor(connection) + + with CaptureQueriesContext(connection) as ctx: + schema_editor.vacuum_model( + fake_model, + [fake_model._meta.get_field("name")], + analyze=True, + parallel=8, + ) + + queries = [query["sql"] for query in ctx.captured_queries] + assert queries == [ + 'VACUUM (ANALYZE, PARALLEL 8) %s ("name")' + % connection.ops.quote_name(fake_model._meta.db_table) + ] + + +@pytest.mark.django_db(transaction=True) +def test_schema_editor_vacuum_model_non_concrete_fields( + fake_model, fake_model_non_concrete_field +): + schema_editor = PostgresSchemaEditor(connection) + + with CaptureQueriesContext(connection) as ctx: + schema_editor.vacuum_model( + fake_model, + [fake_model._meta.get_field("fakes")], + analyze=True, + parallel=8, + ) + + queries = [query["sql"] for query in ctx.captured_queries] + assert queries == [ + "VACUUM (ANALYZE, PARALLEL 8) %s" + % connection.ops.quote_name(fake_model._meta.db_table) + ] From e1a43cd72edbae796ed6e3c469a98ef8f271a2bd Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Fri, 7 Apr 2023 11:51:28 +0300 Subject: [PATCH 16/43] Use Postgres 13 on CI --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index f85823b7..bb545bad 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -7,7 +7,7 @@ executors: type: string docker: - image: python:<< parameters.version >>-buster - - image: postgres:12.0 + - image: postgres:13.0 environment: POSTGRES_DB: 'psqlextra' POSTGRES_USER: 'psqlextra' From 6eff3f12a1c21b4b7530b229b45a5b700a8062da Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Fri, 7 Apr 2023 16:19:29 +0300 Subject: [PATCH 17/43] Add support to schema editor for moving tables between schemas --- psqlextra/backend/introspection.py | 30 +---- psqlextra/backend/operations.py | 3 - psqlextra/backend/schema.py | 71 +++++++++-- psqlextra/settings.py | 118 ++++++++++++++++++ tests/db_introspection.py | 9 +- tests/test_schema_editor_alter_schema.py | 44 +++++++ ...est_schema_editor_clone_model_to_schema.py | 73 ++++++----- tests/test_settings.py | 93 ++++++++++++++ 8 files changed, 361 insertions(+), 80 deletions(-) create mode 100644 psqlextra/settings.py create mode 
100644 tests/test_schema_editor_alter_schema.py create mode 100644 tests/test_settings.py diff --git a/psqlextra/backend/introspection.py b/psqlextra/backend/introspection.py index 03e1fdef..90717b6a 100644 --- a/psqlextra/backend/introspection.py +++ b/psqlextra/backend/introspection.py @@ -1,9 +1,6 @@ -from contextlib import contextmanager from dataclasses import dataclass from typing import Dict, List, Optional, Tuple -from django.db import transaction - from psqlextra.types import PostgresPartitioningMethod from . import base_impl @@ -250,6 +247,7 @@ def get_storage_settings(self, cursor, table_name: str) -> Dict[str, str]: pg_catalog.pg_am am ON (c.relam = am.oid) WHERE c.relname::text = %s + AND pg_catalog.pg_table_is_visible(c.oid) """ cursor.execute(sql, (table_name,)) @@ -288,29 +286,3 @@ def get_relations(self, cursor, table_name: str): [table_name], ) return {row[0]: (row[2], row[1]) for row in cursor.fetchall()} - - @contextmanager - def in_search_path(self, search_path: List[str]): - """Changes the Postgres `search_path` within the context and switches - it back when it exits.""" - - # Wrap in a transaction so a savepoint is created. If - # something goes wrong, the `SET LOCAL search_path` - # statement will be rolled back. - with transaction.atomic(using=self.connection.alias): - with self.connection.cursor() as cursor: - cursor.execute("SHOW search_path") - (original_search_path,) = cursor.fetchone() - - # Syntax in Postgres is a bit weird here. It isn't really - # a list of names like in `WHERE bla in (val1, val2)`. - placeholder = ", ".join(["%s" for _ in search_path]) - cursor.execute( - f"SET LOCAL search_path = {placeholder}", search_path - ) - - yield self - - cursor.execute( - f"SET LOCAL search_path = {original_search_path}" - ) diff --git a/psqlextra/backend/operations.py b/psqlextra/backend/operations.py index 24adf5d0..52793fac 100644 --- a/psqlextra/backend/operations.py +++ b/psqlextra/backend/operations.py @@ -21,6 +21,3 @@ class PostgresOperations(base_impl.operations()): SQLUpdateCompiler, SQLInsertCompiler, ] - - def default_schema_name(self) -> str: - return "public" diff --git a/psqlextra/backend/schema.py b/psqlextra/backend/schema.py index 435b0bdd..1e21b366 100644 --- a/psqlextra/backend/schema.py +++ b/psqlextra/backend/schema.py @@ -12,6 +12,10 @@ from django.db.backends.ddl_references import Statement from django.db.models import Field, Model +from psqlextra.settings import ( + postgres_prepend_local_search_path, + postgres_reset_local_search_path, +) from psqlextra.type_assertions import is_sql_with_params from psqlextra.types import PostgresPartitioningMethod @@ -40,6 +44,8 @@ class PostgresSchemaEditor(SchemaEditor): sql_alter_table_storage_setting = "ALTER TABLE %s SET (%s = %s)" sql_reset_table_storage_setting = "ALTER TABLE %s RESET (%s)" + sql_alter_table_schema = "ALTER TABLE %s SET SCHEMA %s" + sql_create_view = "CREATE VIEW %s AS (%s)" sql_replace_view = "CREATE OR REPLACE VIEW %s AS (%s)" sql_drop_view = "DROP VIEW IF EXISTS %s" @@ -203,8 +209,8 @@ def clone_model_constraints_and_indexes_to_schema( resides. """ - with self.introspection.in_search_path( - [schema_name, self.connection.ops.default_schema_name()] + with postgres_prepend_local_search_path( + [schema_name], using=self.connection.alias ): for constraint in model._meta.constraints: self.add_constraint(model, constraint) @@ -226,8 +232,8 @@ def clone_model_constraints_and_indexes_to_schema( # Django creates primary keys later added to the model with # a custom name. 
We want the name as it was created originally. if field.primary_key: - with self.introspection.in_search_path( - [self.connection.ops.default_schema_name()] + with postgres_reset_local_search_path( + using=self.connection.alias ): [primary_key_name] = self._constraint_names( model, primary_key=True @@ -251,8 +257,8 @@ def clone_model_constraints_and_indexes_to_schema( # a separate transaction later to validate the entries without # acquiring a AccessExclusiveLock. if field.remote_field: - with self.introspection.in_search_path( - [self.connection.ops.default_schema_name()] + with postgres_reset_local_search_path( + using=self.connection.alias ): [fk_name] = self._constraint_names( model, [field.column], foreign_key=True @@ -277,8 +283,8 @@ def clone_model_constraints_and_indexes_to_schema( # manually. field_check = field.db_parameters(self.connection).get("check") if field_check: - with self.introspection.in_search_path( - [self.connection.ops.default_schema_name()] + with postgres_reset_local_search_path( + using=self.connection.alias ): [field_check_name] = self._constraint_names( model, @@ -337,10 +343,12 @@ def clone_model_foreign_keys_to_schema( resides. """ - with self.introspection.in_search_path( - [schema_name, self.connection.ops.default_schema_name()] + constraint_names = self._constraint_names(model, foreign_key=True) + + with postgres_prepend_local_search_path( + [schema_name], using=self.connection.alias ): - for fk_name in self._constraint_names(model, foreign_key=True): + for fk_name in constraint_names: self.execute( self.sql_validate_fk % ( @@ -438,6 +446,47 @@ def reset_model_storage_setting( self.reset_table_storage_setting(model._meta.db_table, name) + def alter_table_schema(self, table_name: str, schema_name: str) -> None: + """Moves the specified table into the specified schema. + + WARNING: Moving models into a different schema than the default + will break querying the model. + + Arguments: + table_name: + Name of the table to move into the specified schema. + + schema_name: + Name of the schema to move the table to. + """ + + self.execute( + self.sql_alter_table_schema + % (self.quote_name(table_name), self.quote_name(schema_name)) + ) + + def alter_model_schema(self, model: Type[Model], schema_name: str) -> None: + """Moves the specified model's table into the specified schema. + + WARNING: Moving models into a different schema than the default + will break querying the model. + + Arguments: + model: + Model of which to move the table. + + schema_name: + Name of the schema to move the model's table to. + """ + + self.execute( + self.sql_alter_table_schema + % ( + self.quote_name(model._meta.db_table), + self.quote_name(schema_name), + ) + ) + def refresh_materialized_view_model( self, model: Type[Model], concurrently: bool = False ) -> None: diff --git a/psqlextra/settings.py b/psqlextra/settings.py new file mode 100644 index 00000000..6dd32f37 --- /dev/null +++ b/psqlextra/settings.py @@ -0,0 +1,118 @@ +from contextlib import contextmanager +from typing import Dict, List, Optional, Union + +from django.core.exceptions import SuspiciousOperation +from django.db import DEFAULT_DB_ALIAS, connections + + +@contextmanager +def postgres_set_local( + *, + using: str = DEFAULT_DB_ALIAS, + **options: Dict[str, Optional[Union[str, int, float, List[str]]]], +) -> None: + """Sets the specified PostgreSQL options using SET LOCAL so that they apply + to the current transacton only. + + The effect is undone when the context manager exits. 
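A minimal sketch of how `postgres_set_local` and the related search-path helper are meant to be used (the timeout values are arbitrary examples):

.. code-block:: python

    from django.db import transaction

    from psqlextra.settings import (
        postgres_prepend_local_search_path,
        postgres_set_local,
    )

    # SET LOCAL only has effect inside a transaction; the helper
    # raises SuspiciousOperation when no transaction is active.
    with transaction.atomic():
        with postgres_set_local(statement_timeout="2s", lock_timeout="3s"):
            pass  # queries here run with the local timeouts

        # Prepend a schema to the search path for the inner
        # block only; the original path is restored afterwards.
        with postgres_prepend_local_search_path(["myschema"]):
            pass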
+ + See https://www.postgresql.org/docs/current/runtime-config-client.html + for an overview of all available options. + """ + + connection = connections[using] + qn = connection.ops.quote_name + + if not connection.in_atomic_block: + raise SuspiciousOperation( + "SET LOCAL makes no sense outside a transaction. Start a transaction first." + ) + + sql = [] + params = [] + for name, value in options.items(): + if value is None: + sql.append(f"SET LOCAL {qn(name)} TO DEFAULT") + continue + + # Settings that accept a list of values are actually + # stored as string lists. We cannot just pass a list + # of values. We have to create the comma separated + # string ourselves. + if isinstance(value, list) or isinstance(value, tuple): + placeholder = ", ".join(["%s" for _ in value]) + params.extend(value) + else: + placeholder = "%s" + params.append(value) + + sql.append(f"SET LOCAL {qn(name)} = {placeholder}") + + with connection.cursor() as cursor: + cursor.execute( + "SELECT name, setting FROM pg_settings WHERE name = ANY(%s)", + (list(options.keys()),), + ) + original_values = dict(cursor.fetchall()) + cursor.execute("; ".join(sql), params) + + yield + + # Put everything back to how it was. DEFAULT is + # not good enough as a outer SET LOCAL might + # have set a different value. + with connection.cursor() as cursor: + sql = [] + params = [] + + for name, value in options.items(): + original_value = original_values.get(name) + if original_value: + sql.append(f"SET LOCAL {qn(name)} = {original_value}") + else: + sql.append(f"SET LOCAL {qn(name)} TO DEFAULT") + + cursor.execute("; ".join(sql), params) + + +@contextmanager +def postgres_set_local_search_path( + search_path: List[str], *, using: str = DEFAULT_DB_ALIAS +) -> None: + """Sets the search path to the specified schemas.""" + + with postgres_set_local(search_path=search_path, using=using): + yield + + +@contextmanager +def postgres_prepend_local_search_path( + search_path: List[str], *, using: str = DEFAULT_DB_ALIAS +) -> None: + """Prepends the current local search path with the specified schemas.""" + + connection = connections[using] + + with connection.cursor() as cursor: + cursor.execute("SHOW search_path") + [ + original_search_path, + ] = cursor.fetchone() + + placeholders = ", ".join(["%s" for _ in search_path]) + cursor.execute( + f"SET LOCAL search_path = {placeholders}, {original_search_path}", + tuple(search_path), + ) + + yield + + cursor.execute(f"SET LOCAL search_path = {original_search_path}") + + +@contextmanager +def postgres_reset_local_search_path(*, using: str = DEFAULT_DB_ALIAS) -> None: + """Resets the local search path to the default.""" + + with postgres_set_local(search_path=None, using=using): + yield diff --git a/tests/db_introspection.py b/tests/db_introspection.py index eabc7414..285cd0e4 100644 --- a/tests/db_introspection.py +++ b/tests/db_introspection.py @@ -9,15 +9,14 @@ from django.db import connection +from psqlextra.settings import postgres_set_local + @contextmanager def introspect(schema_name: Optional[str] = None): - default_schema_name = connection.ops.default_schema_name() - search_path = [schema_name or default_schema_name] - - with connection.introspection.in_search_path(search_path) as introspection: + with postgres_set_local(search_path=schema_name or None): with connection.cursor() as cursor: - yield introspection, cursor + yield connection.introspection, cursor def table_names( diff --git a/tests/test_schema_editor_alter_schema.py b/tests/test_schema_editor_alter_schema.py new file mode 
100644 index 00000000..7fda103b --- /dev/null +++ b/tests/test_schema_editor_alter_schema.py @@ -0,0 +1,44 @@ +import pytest + +from django.db import connection, models + +from psqlextra.backend.schema import PostgresSchemaEditor + +from .fake_model import get_fake_model + + +@pytest.fixture +def fake_model(): + return get_fake_model( + { + "text": models.TextField(), + } + ) + + +def test_schema_editor_alter_table_schema(fake_model): + obj = fake_model.objects.create(text="hello") + + with connection.cursor() as cursor: + cursor.execute("CREATE SCHEMA target") + + schema_editor = PostgresSchemaEditor(connection) + schema_editor.alter_table_schema(fake_model._meta.db_table, "target") + + with connection.cursor() as cursor: + cursor.execute(f"SELECT * FROM target.{fake_model._meta.db_table}") + assert cursor.fetchall() == [(obj.id, obj.text)] + + +def test_schema_editor_alter_model_schema(fake_model): + obj = fake_model.objects.create(text="hello") + + with connection.cursor() as cursor: + cursor.execute("CREATE SCHEMA target") + + schema_editor = PostgresSchemaEditor(connection) + schema_editor.alter_model_schema(fake_model, "target") + + with connection.cursor() as cursor: + cursor.execute(f"SELECT * FROM target.{fake_model._meta.db_table}") + assert cursor.fetchall() == [(obj.id, obj.text)] diff --git a/tests/test_schema_editor_clone_model_to_schema.py b/tests/test_schema_editor_clone_model_to_schema.py index 712d1433..c3d41917 100644 --- a/tests/test_schema_editor_clone_model_to_schema.py +++ b/tests/test_schema_editor_clone_model_to_schema.py @@ -22,6 +22,11 @@ def _create_schema() -> str: name = os.urandom(4).hex() with connection.cursor() as cursor: + cursor.execute( + "DROP SCHEMA IF EXISTS %s CASCADE" + % connection.ops.quote_name(name), + tuple(), + ) cursor.execute( "CREATE SCHEMA %s" % connection.ops.quote_name(name), tuple() ) @@ -29,6 +34,7 @@ def _create_schema() -> str: return name +@transaction.atomic def _assert_cloned_table_is_same( source_table_fqn: Tuple[str, str], target_table_fqn: Tuple[str, str], @@ -40,49 +46,49 @@ def _assert_cloned_table_is_same( source_columns = db_introspection.get_columns( source_table_name, schema_name=source_schema_name ) - source_columns = db_introspection.get_columns( + target_columns = db_introspection.get_columns( target_table_name, schema_name=target_schema_name ) - assert source_columns == source_columns + assert source_columns == target_columns source_relations = db_introspection.get_relations( source_table_name, schema_name=source_schema_name ) - source_relations = db_introspection.get_relations( + target_relations = db_introspection.get_relations( target_table_name, schema_name=target_schema_name ) if excluding_constraints_and_indexes: - assert source_relations == {} + assert target_relations == {} else: - assert source_relations == source_relations + assert source_relations == target_relations source_constraints = db_introspection.get_constraints( source_table_name, schema_name=source_schema_name ) - source_constraints = db_introspection.get_constraints( + target_constraints = db_introspection.get_constraints( target_table_name, schema_name=target_schema_name ) if excluding_constraints_and_indexes: - assert source_constraints == {} + assert target_constraints == {} else: - assert source_constraints == source_constraints + assert source_constraints == target_constraints source_sequences = db_introspection.get_sequences( source_table_name, schema_name=source_schema_name ) - source_sequences = db_introspection.get_sequences( + 
target_sequences = db_introspection.get_sequences( target_table_name, schema_name=target_schema_name ) - assert source_sequences == source_sequences + assert source_sequences == target_sequences source_storage_settings = db_introspection.get_storage_settings( source_table_name, schema_name=source_schema_name, ) - source_storage_settings = db_introspection.get_storage_settings( + target_storage_settings = db_introspection.get_storage_settings( target_table_name, schema_name=target_schema_name ) - assert source_storage_settings == source_storage_settings + assert source_storage_settings == target_storage_settings def _list_lock_modes_in_schema(schema_name: str) -> Set[str]: @@ -108,16 +114,16 @@ def _list_lock_modes_in_schema(schema_name: str) -> Set[str]: def _clone_model_into_schema(model): schema_name = _create_schema() - schema_editor = PostgresSchemaEditor(connection) - schema_editor.clone_model_structure_to_schema( - model, schema_name=schema_name - ) - schema_editor.clone_model_constraints_and_indexes_to_schema( - model, schema_name=schema_name - ) - schema_editor.clone_model_foreign_keys_to_schema( - model, schema_name=schema_name - ) + with PostgresSchemaEditor(connection) as schema_editor: + schema_editor.clone_model_structure_to_schema( + model, schema_name=schema_name + ) + schema_editor.clone_model_constraints_and_indexes_to_schema( + model, schema_name=schema_name + ) + schema_editor.clone_model_foreign_keys_to_schema( + model, schema_name=schema_name + ) return schema_name @@ -208,15 +214,17 @@ def test_schema_editor_clone_model_to_schema( AccessExclusiveLock on the source table works as expected.""" schema_editor = PostgresSchemaEditor(connection) - schema_editor.alter_table_storage_setting( - fake_model._meta.db_table, "autovacuum_enabled", "false" - ) + + with schema_editor: + schema_editor.alter_table_storage_setting( + fake_model._meta.db_table, "autovacuum_enabled", "false" + ) table_name = fake_model._meta.db_table - source_schema_name = connection.ops.default_schema_name() + source_schema_name = "public" target_schema_name = _create_schema() - with transaction.atomic(durable=True): + with schema_editor: schema_editor.clone_model_structure_to_schema( fake_model, schema_name=target_schema_name ) @@ -231,7 +239,7 @@ def test_schema_editor_clone_model_to_schema( excluding_constraints_and_indexes=True, ) - with transaction.atomic(durable=True): + with schema_editor: schema_editor.clone_model_constraints_and_indexes_to_schema( fake_model, schema_name=target_schema_name ) @@ -246,7 +254,7 @@ def test_schema_editor_clone_model_to_schema( (target_schema_name, table_name), ) - with transaction.atomic(durable=True): + with schema_editor: schema_editor.clone_model_foreign_keys_to_schema( fake_model, schema_name=target_schema_name ) @@ -267,13 +275,13 @@ def test_schema_editor_clone_model_to_schema( reason=django_32_skip_reason, ) def test_schema_editor_clone_model_to_schema_custom_constraint_names( - fake_model, + fake_model, fake_model_fk_target_1 ): """Tests that even if constraints were given custom names, the cloned table has those same custom names.""" table_name = fake_model._meta.db_table - source_schema_name = connection.ops.default_schema_name() + source_schema_name = "public" constraints = db_introspection.get_constraints(table_name) @@ -290,6 +298,7 @@ def test_schema_editor_clone_model_to_schema_custom_constraint_names( name for name, constraint in constraints.items() if constraint["foreign_key"] + == (fake_model_fk_target_1._meta.db_table, "id") ), None, ) @@ -297,7 
+306,7 @@ def test_schema_editor_clone_model_to_schema_custom_constraint_names( ( name for name, constraint in constraints.items() - if constraint["check"] + if constraint["check"] and constraint["columns"] == ["age"] ), None, ) diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 00000000..44519714 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,93 @@ +import pytest + +from django.core.exceptions import SuspiciousOperation +from django.db import connection + +from psqlextra.settings import ( + postgres_prepend_local_search_path, + postgres_reset_local_search_path, + postgres_set_local, + postgres_set_local_search_path, +) + + +def _get_current_setting(name: str) -> None: + with connection.cursor() as cursor: + cursor.execute(f"SHOW {name}") + return cursor.fetchone()[0] + + +@postgres_set_local(statement_timeout="2s", lock_timeout="3s") +def test_postgres_set_local_function_decorator(): + assert _get_current_setting("statement_timeout") == "2s" + assert _get_current_setting("lock_timeout") == "3s" + + +def test_postgres_set_local_context_manager(): + with postgres_set_local(statement_timeout="2s"): + assert _get_current_setting("statement_timeout") == "2s" + + assert _get_current_setting("statement_timeout") == "0" + + +def test_postgres_set_local_iterable(): + with postgres_set_local(search_path=["a", "public"]): + assert _get_current_setting("search_path") == "a, public" + + assert _get_current_setting("search_path") == '"$user", public' + + +def test_postgres_set_local_nested(): + with postgres_set_local(statement_timeout="2s"): + assert _get_current_setting("statement_timeout") == "2s" + + with postgres_set_local(statement_timeout="3s"): + assert _get_current_setting("statement_timeout") == "3s" + + assert _get_current_setting("statement_timeout") == "2s" + + assert _get_current_setting("statement_timeout") == "0" + + +@pytest.mark.django_db(transaction=True) +def test_postgres_set_local_no_transaction(): + with pytest.raises(SuspiciousOperation): + with postgres_set_local(statement_timeout="2s"): + pass + + +def test_postgres_set_local_search_path(): + with postgres_set_local_search_path(["a", "public"]): + assert _get_current_setting("search_path") == "a, public" + + assert _get_current_setting("search_path") == '"$user", public' + + +def test_postgres_reset_local_search_path(): + with postgres_set_local_search_path(["a", "public"]): + with postgres_reset_local_search_path(): + assert _get_current_setting("search_path") == '"$user", public' + + assert _get_current_setting("search_path") == "a, public" + + assert _get_current_setting("search_path") == '"$user", public' + + +def test_postgres_prepend_local_search_path(): + with postgres_prepend_local_search_path(["a", "b"]): + assert _get_current_setting("search_path") == 'a, b, "$user", public' + + assert _get_current_setting("search_path") == '"$user", public' + + +def test_postgres_prepend_local_search_path_nested(): + with postgres_prepend_local_search_path(["a", "b"]): + with postgres_prepend_local_search_path(["c"]): + assert ( + _get_current_setting("search_path") + == 'c, a, b, "$user", public' + ) + + assert _get_current_setting("search_path") == 'a, b, "$user", public' + + assert _get_current_setting("search_path") == '"$user", public' From 91b873c663e916a70be3583342ac77787b0a13a4 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Tue, 4 Apr 2023 10:46:03 +0300 Subject: [PATCH 18/43] Add `PostgresSchema` to manage Postgres schemas with --- docs/source/api_reference.rst | 5 + 
docs/source/index.rst              |   6 +
 docs/source/schemas.rst            | 169 +++++++++++++++++++
 psqlextra/backend/introspection.py |  19 ++-
 psqlextra/backend/schema.py        |  18 ++
 psqlextra/error.py                 |  20 +++
 psqlextra/schema.py                | 212 ++++++++++++++++++++++++
 setup.py                           |   1 +
 tests/test_schema.py               | 253 +++++++++++++++++++++++++++++
 9 files changed, 700 insertions(+), 3 deletions(-)
 create mode 100644 docs/source/schemas.rst
 create mode 100644 psqlextra/error.py
 create mode 100644 psqlextra/schema.py
 create mode 100644 tests/test_schema.py

diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst
index 1d64bc8b..7f175fe9 100644
--- a/docs/source/api_reference.rst
+++ b/docs/source/api_reference.rst
@@ -34,12 +34,17 @@ API Reference
 .. automodule:: psqlextra.indexes
 
     .. autoclass:: UniqueIndex
+
     .. autoclass:: ConditionalUniqueIndex
+
     .. autoclass:: CaseInsensitiveUniqueIndex
 
 .. automodule:: psqlextra.locking
    :members:
 
+.. automodule:: psqlextra.schema
+   :members:
+
 .. automodule:: psqlextra.partitioning
    :members:

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 76600702..1959016e 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -40,6 +40,11 @@ Explore the documentation to learn about all features:
 
    Support for explicit table-level locks.
 
+* :ref:`Creating/dropping schemas <schemas_page>`
+
+   Support for managing Postgres schemas.
+
+
 .. toctree::
    :maxdepth: 2
    :caption: Overview
@@ -54,6 +59,7 @@ Explore the documentation to learn about all features:
    expressions
    annotations
    locking
+   schemas
    settings
    api_reference
    major_releases

diff --git a/docs/source/schemas.rst b/docs/source/schemas.rst
new file mode 100644
index 00000000..23f0f91a
--- /dev/null
+++ b/docs/source/schemas.rst
@@ -0,0 +1,169 @@
+.. include:: ./snippets/postgres_doc_links.rst
+
+.. _schemas_page:
+
+Schema
+======
+
+The :class:`~psqlextra.schema.PostgresSchema` class provides basic schema management functionality.
+
+Django does **NOT** support custom schemas. This module does not attempt to solve that problem.
+
+This module merely allows you to create/drop schemas and execute raw SQL in a schema. It is not an attempt at bringing multi-schema support to Django.
+
+
+Reference an existing schema
+----------------------------
+
+.. code-block:: python
+
+   from psqlextra.schema import PostgresSchema
+
+   schema = PostgresSchema("myschema")
+
+   with schema.connection.cursor() as cursor:
+       cursor.execute("SELECT * FROM tablethatexistsinmyschema")
+
+
+Checking if a schema exists
+---------------------------
+
+.. code-block:: python
+
+   from psqlextra.schema import PostgresSchema
+
+   if PostgresSchema.exists("myschema"):
+       print("exists!")
+   else:
+       print("does not exist!")
+
+
+Creating a new schema
+---------------------
+
+With a custom name
+******************
+
+.. code-block:: python
+
+   from psqlextra.schema import PostgresSchema
+
+   # will raise an error if the schema already exists
+   schema = PostgresSchema.create("myschema")
+
+
+Re-create if necessary with a custom name
+*****************************************
+
+.. warning::
+
+   If the schema already exists and it is non-empty or something is referencing it, it will **NOT** be dropped. Specify ``cascade=True`` to drop all of the schema's contents and **anything referencing it**.
+
+..
code-block:: python
+
+   from psqlextra.schema import PostgresSchema
+
+   # will drop an existing schema named `myschema` if it
+   # exists and re-create it
+   schema = PostgresSchema.delete_and_create("myschema")
+
+   # will drop the schema and cascade the delete to its contents
+   # and anything referencing the schema
+   schema = PostgresSchema.delete_and_create("otherschema", cascade=True)
+
+
+With a random name
+******************
+
+.. code-block:: python
+
+   from psqlextra.schema import PostgresSchema
+
+   # schema name will be "myprefix_" followed by a
+   # random (time-based) suffix
+   schema = PostgresSchema.create_random("myprefix")
+   print(schema.name)
+
+
+Temporary schema with random name
+*********************************
+
+Use the :func:`~psqlextra.schema.postgres_temporary_schema` context manager to create a schema with a random name. The schema will only exist within the context manager.
+
+By default, the schema is not dropped if an exception occurs in the context manager. This prevents unexpected data loss. Specify ``delete_on_throw=True`` to drop the schema if an exception occurs.
+
+Without an outer transaction, the temporary schema might not be dropped when your program exits unexpectedly (for example, if it is killed with SIGKILL). Wrap the creation of the schema in a transaction to make sure the schema is cleaned up when an error occurs or your program exits suddenly.
+
+.. warning::
+
+   By default, the drop will fail if the schema is not empty or there is anything referencing the schema. Specify ``cascade=True`` to drop all of the schema's contents and **anything referencing it**.
+
+.. code-block:: python
+
+   from psqlextra.schema import postgres_temporary_schema
+
+   with postgres_temporary_schema("myprefix") as schema:
+       pass
+
+   with postgres_temporary_schema("otherprefix", delete_on_throw=True) as schema:
+       raise ValueError("drop it like it's hot")
+
+   with postgres_temporary_schema("greatprefix", cascade=True) as schema:
+       with schema.connection.cursor() as cursor:
+           cursor.execute(f"CREATE TABLE {schema.name} AS SELECT 'hello'")
+
+   with postgres_temporary_schema("amazingprefix", delete_on_throw=True, cascade=True) as schema:
+       with schema.connection.cursor() as cursor:
+           cursor.execute(f"CREATE TABLE {schema.name} AS SELECT 'hello'")
+
+       raise ValueError("oops")
+
+Deleting a schema
+-----------------
+
+Any schema can be dropped, including ones not created by :class:`~psqlextra.schema.PostgresSchema`.
+
+The ``public`` schema cannot be dropped. This is a Postgres built-in and it is almost always a mistake to drop it. A :class:`~django.core.exceptions.SuspiciousOperation` error will be raised if you attempt to drop the ``public`` schema.
+
+.. warning::
+
+   By default, the drop will fail if the schema is not empty or there is anything referencing the schema. Specify ``cascade=True`` to drop all of the schema's contents and **anything referencing it**.
+
+.. code-block:: python
+
+   from psqlextra.schema import PostgresSchema
+
+   schema = PostgresSchema("myschema")
+   schema.delete()
+   schema.delete(cascade=True)
+
+
+Executing queries within a schema
+---------------------------------
+
+By default, a connection operates in the ``public`` schema. Each schema object offers a connection scoped to that schema, which sets the Postgres ``search_path`` so that only that schema is searched.
+
+.. warning::
+
+   This can be abused to manage Django models in a custom schema. This is not a supported workflow and there might be unexpected issues from attempting to do so.
+
+..
warning:: + + Do not pass the connection to a different thread. It is **NOT** thread safe. + +.. code-block:: python + + from psqlextra.schema import PostgresSchema + + schema = PostgresSchema.create("myschema") + + with schema.connection.cursor() as cursor: + # table gets created within the `myschema` schema, without + # explicitly specifying the schema name + cursor.execute("CREATE TABLE mytable AS SELECT 'hello'") + + with schema.connection.schema_editor() as schema_editor: + # creates a table for the model within the schema + schema_editor.create_model(MyModel) diff --git a/psqlextra/backend/introspection.py b/psqlextra/backend/introspection.py index 90717b6a..16b3a8ba 100644 --- a/psqlextra/backend/introspection.py +++ b/psqlextra/backend/introspection.py @@ -69,7 +69,8 @@ def get_partitioned_tables( ) -> PostgresIntrospectedPartitonedTable: """Gets a list of partitioned tables.""" - sql = """ + cursor.execute( + """ SELECT pg_class.relname, pg_partitioned_table.partstrat @@ -80,8 +81,7 @@ def get_partitioned_tables( ON pg_class.oid = pg_partitioned_table.partrelid """ - - cursor.execute(sql) + ) return [ PostgresIntrospectedPartitonedTable( @@ -191,6 +191,19 @@ def get_partition_key(self, cursor, table_name: str) -> List[str]: def get_columns(self, cursor, table_name: str): return self.get_table_description(cursor, table_name) + def get_schema_list(self, cursor) -> List[str]: + """A flat list of available schemas.""" + + sql = """ + SELECT + schema_name + FROM + information_schema.schemata + """ + + cursor.execute(sql, tuple()) + return [name for name, in cursor.fetchall()] + def get_constraints(self, cursor, table_name: str): """Retrieve any constraints or keys (unique, pk, fk, check, index) across one or more columns. diff --git a/psqlextra/backend/schema.py b/psqlextra/backend/schema.py index 1e21b366..85978f05 100644 --- a/psqlextra/backend/schema.py +++ b/psqlextra/backend/schema.py @@ -45,6 +45,9 @@ class PostgresSchemaEditor(SchemaEditor): sql_reset_table_storage_setting = "ALTER TABLE %s RESET (%s)" sql_alter_table_schema = "ALTER TABLE %s SET SCHEMA %s" + sql_create_schema = "CREATE SCHEMA %s" + sql_delete_schema = "DROP SCHEMA %s" + sql_delete_schema_cascade = "DROP SCHEMA %s CASCADE" sql_create_view = "CREATE VIEW %s AS (%s)" sql_replace_view = "CREATE OR REPLACE VIEW %s AS (%s)" @@ -84,6 +87,21 @@ def __init__(self, connection, collect_sql=False, atomic=True): self.deferred_sql = [] self.introspection = PostgresIntrospection(self.connection) + def create_schema(self, name: str) -> None: + """Creates a Postgres schema.""" + + self.execute(self.sql_create_schema % self.quote_name(name)) + + def delete_schema(self, name: str, cascade: bool) -> None: + """Drops a Postgres schema.""" + + sql = ( + self.sql_delete_schema + if not cascade + else self.sql_delete_schema_cascade + ) + self.execute(sql % self.quote_name(name)) + def create_model(self, model: Type[Model]) -> None: """Creates a new model.""" diff --git a/psqlextra/error.py b/psqlextra/error.py new file mode 100644 index 00000000..082a1cfd --- /dev/null +++ b/psqlextra/error.py @@ -0,0 +1,20 @@ +from typing import Optional + +import psycopg2 + +from django import db + + +def extract_postgres_error(error: db.Error) -> Optional[psycopg2.Error]: + """Extracts the underlying :see:psycopg2.Error from the specified Django + database error. + + As per PEP-249, Django wraps all database errors in its own + exception. We can extract the underlying database error by examaning + the cause of the error. 
+ """ + + if not isinstance(error.__cause__, psycopg2.Error): + return None + + return error.__cause__ diff --git a/psqlextra/schema.py b/psqlextra/schema.py new file mode 100644 index 00000000..d2cdabfe --- /dev/null +++ b/psqlextra/schema.py @@ -0,0 +1,212 @@ +from contextlib import contextmanager + +import wrapt + +from django.core.exceptions import SuspiciousOperation, ValidationError +from django.db import DEFAULT_DB_ALIAS, connections, transaction +from django.db.backends.base.base import BaseDatabaseWrapper +from django.db.backends.utils import CursorWrapper +from django.utils import timezone + + +class PostgresSchemaConnectionWrapper(wrapt.ObjectProxy): + """Wraps a Django database connection and ensures that each cursor operates + within the specified schema.""" + + def __init__(self, connection, schema) -> None: + super().__init__(connection) + + self._self_schema = schema + + @contextmanager + def schema_editor(self): + with self.__wrapped__.schema_editor() as schema_editor: + schema_editor.connection = self + yield schema_editor + + @contextmanager + def cursor(self) -> CursorWrapper: + schema = self._self_schema + + with self.__wrapped__.cursor() as cursor: + quoted_name = self.ops.quote_name(schema.name) + cursor.execute(f"SET search_path = {quoted_name}") + try: + yield cursor + finally: + cursor.execute("SET search_path TO DEFAULT") + + +class PostgresSchema: + """Represents a Postgres schema. + + See: https://www.postgresql.org/docs/current/ddl-schemas.html + """ + + NAME_MAX_LENGTH = 63 + + name: str + using: str + + default: "PostgresSchema" + + def __init__(self, name: str, *, using: str = DEFAULT_DB_ALIAS) -> None: + self.name = name + self.using = using + + @classmethod + def create( + cls, name: str, *, using: str = DEFAULT_DB_ALIAS + ) -> "PostgresSchema": + """Creates a new schema with the specified name. + + This throws if the schema already exists as that is most likely + a problem that requires careful handling. Pretending everything + is ok might cause the caller to overwrite data, thinking it got + a empty schema. + + Arguments: + name: + The name to give to the new schema (max 63 characters). + + using: + Name of the database connection to use. + """ + + if len(name) > cls.NAME_MAX_LENGTH: + raise ValidationError( + f"Schema name '{name}' is longer than Postgres's limit of {cls.NAME_MAX_LENGTH} characters" + ) + + with connections[using].schema_editor() as schema_editor: + schema_editor.create_schema(name) + + return cls(name, using=using) + + @classmethod + def create_random( + cls, prefix: str, *, using: str = DEFAULT_DB_ALIAS + ) -> "PostgresSchema": + """Creates a new schema with a random (time-based) suffix. + + Arguments: + prefix: + Name to prefix the final name with. The name plus + prefix cannot be longer than 63 characters. + + using: + Name of the database connection to use. + """ + + name_suffix = timezone.now().strftime("%Y%m%d%H%m%s") + return cls.create(f"{prefix}_{name_suffix}", using=using) + + @classmethod + def delete_and_create( + cls, name: str, *, cascade: bool = False, using: str = DEFAULT_DB_ALIAS + ) -> "PostgresSchema": + """Deletes the schema if it exists before re-creating it. + + Arguments: + name: + Name of the schema to delete+create (max 63 characters). + + cascade: + Whether to delete the contents of the schema + and anything that references it if it exists. + + using: + Name of the database connection to use. 
+        """
+
+        with transaction.atomic(using=using):
+            cls(name, using=using).delete(cascade=cascade)
+            return cls.create(name, using=using)
+
+    @classmethod
+    def exists(cls, name: str, *, using: str = DEFAULT_DB_ALIAS) -> bool:
+        """Gets whether a schema with the specified name exists.
+
+        Arguments:
+            name:
+                Name of the schema to check whether it
+                exists.
+
+            using:
+                Name of the database connection to use.
+        """
+
+        connection = connections[using]
+
+        with connection.cursor() as cursor:
+            return name in connection.introspection.get_schema_list(cursor)
+
+    def delete(self, *, cascade: bool = False) -> None:
+        """Deletes the schema and optionally deletes the contents of the schema
+        and anything that references it.
+
+        Arguments:
+            cascade:
+                Cascade the delete to the contents of the schema
+                and anything that references it.
+
+                If not set, the schema will refuse to be deleted
+                unless it is empty and there are no remaining
+                references.
+        """
+
+        if self.name == "public":
+            raise SuspiciousOperation(
+                "Pretty sure you are about to make a mistake by trying to drop the 'public' schema. I have stopped you. Thank me later."
+            )
+
+        with connections[self.using].schema_editor() as schema_editor:
+            schema_editor.delete_schema(self.name, cascade=cascade)
+
+    @property
+    def connection(self) -> BaseDatabaseWrapper:
+        """Obtains a database connection scoped to this schema."""
+
+        return PostgresSchemaConnectionWrapper(connections[self.using], self)
+
+
+PostgresSchema.default = PostgresSchema("public")
+
+
+@contextmanager
+def postgres_temporary_schema(
+    prefix: str,
+    *,
+    cascade: bool = False,
+    delete_on_throw: bool = False,
+    using: str = DEFAULT_DB_ALIAS,
+) -> PostgresSchema:
+    """Creates a temporary schema that only lives in the context of this
+    context manager.
+
+    Arguments:
+        prefix:
+            Name to prefix the final name with.
+
+        cascade:
+            Whether to cascade the delete when dropping the
+            schema. If enabled, the contents of the schema
+            are deleted as well as anything that references
+            the schema.
+
+        delete_on_throw:
+            Whether to automatically drop the schema if
+            any error occurs within the context manager.
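+
+    Example (an illustrative sketch; the prefix is arbitrary, and the
+    shown suffix follows the time-based format used at this point):
+
+        with postgres_temporary_schema("scratch") as schema:
+            print(schema.name)  # e.g. "scratch_2023040713041680892620"
+
+        # the schema has been dropped again at this point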
+ """ + + schema = PostgresSchema.create_random(prefix, using=using) + + try: + yield schema + except Exception as e: + if delete_on_throw: + schema.delete(cascade=cascade) + + raise e + + schema.delete(cascade=cascade) diff --git a/setup.py b/setup.py index 281be89d..4b1dbf44 100644 --- a/setup.py +++ b/setup.py @@ -67,6 +67,7 @@ def run(self): install_requires=[ "Django>=2.0,<5.0", "python-dateutil>=2.8.0,<=3.0.0", + "wrapt>=1.0,<2.0", ], extras_require={ ':python_version <= "3.6"': ["dataclasses"], diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 00000000..fd966372 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,253 @@ +import uuid + +import freezegun +import pytest + +from django.core.exceptions import SuspiciousOperation, ValidationError +from django.db import ( + DEFAULT_DB_ALIAS, + InternalError, + ProgrammingError, + connection, +) +from psycopg2 import errorcodes + +from psqlextra.error import extract_postgres_error +from psqlextra.schema import PostgresSchema, postgres_temporary_schema + + +def _does_schema_exist(name: str) -> bool: + with connection.cursor() as cursor: + return name in connection.introspection.get_schema_list(cursor) + + +def test_postgres_schema_create(): + schema = PostgresSchema.create("myschema") + assert schema.name == "myschema" + assert schema.using == DEFAULT_DB_ALIAS + + assert _does_schema_exist(schema.name) + + +def test_postgres_schema_does_not_overwrite(): + schema = PostgresSchema.create("myschema") + + with pytest.raises(ProgrammingError): + PostgresSchema.create(schema.name) + + +def test_postgres_schema_create_max_name_length(): + with pytest.raises(ValidationError): + PostgresSchema.create( + "stringthatislongerhtan63charactersforsureabsolutelysurethisislongerthanthat" + ) + + +def test_postgres_schema_create_name_that_requires_escaping(): + # 'table' needs escaping because it conflicts with + # the SQL keyword TABLE + schema = PostgresSchema.create("table") + assert schema.name == "table" + + assert _does_schema_exist("table") + + +def test_postgres_schema_create_random(): + with freezegun.freeze_time("2023-04-07 13:37:00.0"): + schema = PostgresSchema.create_random("myprefix") + + assert schema.name == "myprefix_2023040713041680892620" + assert _does_schema_exist(schema.name) + + +def test_postgres_schema_delete_and_create(): + schema = PostgresSchema.create("test") + + with connection.cursor() as cursor: + cursor.execute("CREATE TABLE test.bla AS SELECT 'hello'") + cursor.execute("SELECT * FROM test.bla") + + assert cursor.fetchone() == ("hello",) + + # Should refuse to delete since we added a table to the schema + with pytest.raises(InternalError) as exc_info: + schema = PostgresSchema.delete_and_create(schema.name) + + pg_error = extract_postgres_error(exc_info.value) + assert pg_error.pgcode == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST + + # Dropping the schema should work with cascade=True + schema = PostgresSchema.delete_and_create(schema.name, cascade=True) + assert _does_schema_exist(schema.name) + + # Since the schema was deleteped and re-created, the `bla` + # table should not exist anymore. 
+ with pytest.raises(ProgrammingError) as exc_info: + with connection.cursor() as cursor: + cursor.execute("SELECT * FROM test.bla") + assert cursor.fetchone() == ("hello",) + + pg_error = extract_postgres_error(exc_info.value) + assert pg_error.pgcode == errorcodes.UNDEFINED_TABLE + + +def test_postgres_schema_delete(): + schema = PostgresSchema.create("test") + assert _does_schema_exist(schema.name) + + schema.delete() + assert not _does_schema_exist(schema.name) + + +def test_postgres_schema_delete_not_empty(): + schema = PostgresSchema.create("test") + assert _does_schema_exist(schema.name) + + with schema.connection.cursor() as cursor: + cursor.execute("CREATE TABLE test.bla AS SELECT 'hello'") + + with pytest.raises(InternalError) as exc_info: + schema.delete() + + pg_error = extract_postgres_error(exc_info.value) + assert pg_error.pgcode == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST + + +def test_postgres_schema_delete_cascade_not_empty(): + schema = PostgresSchema.create("test") + assert _does_schema_exist(schema.name) + + with schema.connection.cursor() as cursor: + cursor.execute("CREATE TABLE test.bla AS SELECT 'hello'") + + schema.delete(cascade=True) + assert not _does_schema_exist(schema.name) + + +def test_postgres_schema_connection(): + schema = PostgresSchema.create("test") + + with schema.connection.cursor() as cursor: + # Creating a table without specifying the schema should create + # it in our schema and we should be able to select from it without + # specifying the schema. + cursor.execute("CREATE TABLE myschematable AS SELECT 'myschema'") + cursor.execute("SELECT * FROM myschematable") + assert cursor.fetchone() == ("myschema",) + + # Proof that the table was created in our schema even though we + # never explicitly told it to do so. + cursor.execute( + "SELECT table_schema FROM information_schema.tables WHERE table_name = %s", + ("myschematable",), + ) + assert cursor.fetchone() == (schema.name,) + + # Creating a table in another schema, we should not be able + # to select it without specifying the schema since our + # schema scoped connection only looks at our schema by default. 
+ cursor.execute( + "CREATE TABLE public.otherschematable AS SELECT 'otherschema'" + ) + with pytest.raises(ProgrammingError) as exc_info: + cursor.execute("SELECT * FROM otherschematable") + + cursor.execute("ROLLBACK") + + pg_error = extract_postgres_error(exc_info.value) + assert pg_error.pgcode == errorcodes.UNDEFINED_TABLE + + +def test_postgres_schema_connection_does_not_affect_default(): + schema = PostgresSchema.create("test") + + with schema.connection.cursor() as cursor: + cursor.execute("SHOW search_path") + assert cursor.fetchone() == ("test",) + + with connection.cursor() as cursor: + cursor.execute("SHOW search_path") + assert cursor.fetchone() == ('"$user", public',) + + +@pytest.mark.django_db(transaction=True) +def test_postgres_schema_connection_does_not_affect_default_after_throw(): + schema = PostgresSchema.create(str(uuid.uuid4())) + + with pytest.raises(ProgrammingError): + with schema.connection.cursor() as cursor: + cursor.execute("COMMIT") + cursor.execute("SELECT frombadtable") + + with connection.cursor() as cursor: + cursor.execute("ROLLBACK") + cursor.execute("SHOW search_path") + assert cursor.fetchone() == ('"$user", public',) + + +def test_postgres_schema_connection_schema_editor(): + schema = PostgresSchema.create("test") + + with schema.connection.schema_editor() as schema_editor: + with schema_editor.connection.cursor() as cursor: + cursor.execute("SHOW search_path") + assert cursor.fetchone() == ("test",) + + with connection.cursor() as cursor: + cursor.execute("SHOW search_path") + assert cursor.fetchone() == ('"$user", public',) + + +def test_postgres_schema_connection_does_not_catch(): + schema = PostgresSchema.create("test") + + with pytest.raises(ValueError): + with schema.connection.cursor(): + raise ValueError("test") + + +def test_postgres_schema_connection_no_delete_default(): + with pytest.raises(SuspiciousOperation): + PostgresSchema.default.delete() + + with pytest.raises(SuspiciousOperation): + PostgresSchema("public").delete() + + +def test_postgres_temporary_schema(): + with freezegun.freeze_time("2023-04-07 13:37:00.0"): + with postgres_temporary_schema("temp") as schema: + assert schema.name == "temp_2023040713041680892620" + + assert _does_schema_exist(schema.name) + + assert not _does_schema_exist(schema.name) + + +def test_postgres_temporary_schema_not_empty(): + with pytest.raises(InternalError) as exc_info: + with postgres_temporary_schema("temp") as schema: + with schema.connection.cursor() as cursor: + cursor.execute("CREATE TABLE mytable AS SELECT 'hello world'") + + pg_error = extract_postgres_error(exc_info.value) + assert pg_error.pgcode == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST + + +def test_postgres_temporary_schema_not_empty_cascade(): + with postgres_temporary_schema("temp", cascade=True) as schema: + with schema.connection.cursor() as cursor: + cursor.execute("CREATE TABLE mytable AS SELECT 'hello world'") + + assert not _does_schema_exist(schema.name) + + +@pytest.mark.parametrize("delete_on_throw", [True, False]) +def test_postgres_temporary_schema_no_delete_on_throw(delete_on_throw): + with pytest.raises(ValueError): + with postgres_temporary_schema( + "temp", delete_on_throw=delete_on_throw + ) as schema: + raise ValueError("test") + + assert _does_schema_exist(schema.name) != delete_on_throw From a79270648c95222c7136a5c0c5de0c4f74a26768 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 6 Apr 2023 10:32:24 +0300 Subject: [PATCH 19/43] Document `using` parameter consistently --- psqlextra/schema.py | 9 ++++++--- 1 
file changed, 6 insertions(+), 3 deletions(-) diff --git a/psqlextra/schema.py b/psqlextra/schema.py index d2cdabfe..a0c32ca9 100644 --- a/psqlextra/schema.py +++ b/psqlextra/schema.py @@ -70,7 +70,7 @@ def create( The name to give to the new schema (max 63 characters). using: - Name of the database connection to use. + Optional name of the database connection to use. """ if len(name) > cls.NAME_MAX_LENGTH: @@ -116,7 +116,7 @@ def delete_and_create( and anything that references it if it exists. using: - Name of the database connection to use. + Optional name of the database connection to use. """ with transaction.atomic(using=using): @@ -133,7 +133,7 @@ def exists(cls, name: str, *, using: str = DEFAULT_DB_ALIAS) -> bool: exists. using: - Name of the database connection to use. + Optional name of the database connection to use. """ connection = connections[using] @@ -197,6 +197,9 @@ def postgres_temporary_schema( delete_on_throw: Whether to automatically drop the schema if any error occurs within the context manager. + + using: + Optional name of the database connection to use. """ schema = PostgresSchema.create_random(prefix, using=using) From 164fbedc05980f57ae94f4224ee7f1c4a32bab2d Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 6 Apr 2023 10:33:29 +0300 Subject: [PATCH 20/43] Warn about using schema-scoped connections with transaction pooler --- docs/source/schemas.rst | 9 +++++---- psqlextra/schema.py | 10 +++++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/docs/source/schemas.rst b/docs/source/schemas.rst index 23f0f91a..5e143f49 100644 --- a/docs/source/schemas.rst +++ b/docs/source/schemas.rst @@ -98,9 +98,6 @@ Without an outer transaction, the temporary schema might not be dropped when you By default, the drop will fail if the schema is not empty or there is anything referencing the schema. Specify ``cascade=True`` to drop all of the schema's contents and **anything referencing it**. -.. note:: - - .. code-block:: python for psqlextra.schema import postgres_temporary_schema @@ -151,7 +148,11 @@ By default, a connection operates in the ``public`` schema. The schema offers a .. warning:: - Do not pass the connection to a different thread. It is **NOT** thread safe. + Do not use this in the following scenarios: + + 1. You access the connection from multiple threads. Scoped connections are **NOT** thread safe. + + 2. The underlying database connection is passed through a connection pooler in transaction pooling mode. .. code-block:: python diff --git a/psqlextra/schema.py b/psqlextra/schema.py index a0c32ca9..b680af56 100644 --- a/psqlextra/schema.py +++ b/psqlextra/schema.py @@ -165,7 +165,15 @@ def delete(self, *, cascade: bool = False) -> None: @property def connection(self) -> BaseDatabaseWrapper: - """Obtains a database connection scoped to this schema.""" + """Obtains a database connection scoped to this schema. + + Do not use this in the following scenarios: + + 1. You access the connection from multiple threads. Scoped + connections are NOT thread safe. + 2. The underlying database connection is passed through a + connection pooler in transaction pooling mode. 
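+
+        Example (an illustrative sketch; mirrors the snippet in the
+        documentation above):
+
+            schema = PostgresSchema.create("myschema")
+            with schema.connection.cursor() as cursor:
+                # lands in "myschema" without naming the schema explicitly
+                cursor.execute("CREATE TABLE mytable AS SELECT 'hello'")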
+        """

         return PostgresSchemaConnectionWrapper(connections[self.using], self)

From 43765fec7ca87a49119be2a33e111d6e4b2c3fcd Mon Sep 17 00:00:00 2001
From: Swen Kooij
Date: Thu, 6 Apr 2023 10:48:12 +0300
Subject: [PATCH 21/43] Allow for both a random schema name and a time-based
 one

---
 docs/source/schemas.rst | 20 +++++++++++++++++++-
 psqlextra/schema.py     | 28 ++++++++++++++++++++++++++--
 tests/test_schema.py    | 23 +++++++++++++++++------
 3 files changed, 62 insertions(+), 9 deletions(-)

diff --git a/docs/source/schemas.rst b/docs/source/schemas.rst
index 5e143f49..edf4d929 100644
--- a/docs/source/schemas.rst
+++ b/docs/source/schemas.rst
@@ -73,14 +73,32 @@ Re-create if necessary with a custom name

     schema = PostgresSchema.drop_and_create("otherschema", cascade=True)

+With a time-based name
+**********************
+
+.. warning::
+
+   The time-based suffix is precise up to the second. If two threads or
+   processes both try to create a time-based schema name in the same
+   second, the names will conflict.
+
+.. code-block:: python
+
+   from psqlextra.schema import PostgresSchema
+
+   # schema name will be "myprefix_<timestamp>"
+   schema = PostgresSchema.create_time_based("myprefix")
+   print(schema.name)
+
+
 With a random name
 ******************

+An 8-character suffix is appended. Entropy is dependent on your system. See :func:`~os.urandom` for more information.
+
 .. code-block:: python

    from psqlextra.schema import PostgresSchema

-   # schema name will be "myprefix_<timestamp>"
+   # schema name will be "myprefix_<8 random characters>"
    schema = PostgresSchema.create_random("myprefix")
    print(schema.name)

diff --git a/psqlextra/schema.py b/psqlextra/schema.py
index b680af56..eeea14a3 100644
--- a/psqlextra/schema.py
+++ b/psqlextra/schema.py
@@ -1,3 +1,5 @@
+import os
+
 from contextlib import contextmanager

 import wrapt
@@ -84,10 +86,14 @@ def create(
         return cls(name, using=using)

     @classmethod
-    def create_random(
+    def create_time_based(
         cls, prefix: str, *, using: str = DEFAULT_DB_ALIAS
     ) -> "PostgresSchema":
-        """Creates a new schema with a random (time-based) suffix.
+        """Creates a new schema with a time-based suffix.
+
+        The time is precise up to the second. Creating
+        multiple time-based schemas in the same second
+        WILL lead to conflicts.

         Arguments:
             prefix:
                 Name to prefix the final name with. The name plus
                 prefix cannot be longer than 63 characters.
@@ -101,6 +107,24 @@ def create_random(
         name_suffix = timezone.now().strftime("%Y%m%d%H%m%s")
         return cls.create(f"{prefix}_{name_suffix}", using=using)

+    @classmethod
+    def create_random(
+        cls, prefix: str, *, using: str = DEFAULT_DB_ALIAS
+    ) -> "PostgresSchema":
+        """Creates a new schema with a random suffix.
+
+        Arguments:
+            prefix:
+                Name to prefix the final name with. The name plus
+                prefix cannot be longer than 63 characters.
+
+            using:
+                Name of the database connection to use.
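+
+        Example (an illustrative sketch; "scratch" is an arbitrary
+        prefix and the suffix shown is a made-up sample value):
+
+            schema = PostgresSchema.create_random("scratch")
+            print(schema.name)  # e.g. "scratch_7f3a9b2c" (os.urandom(4).hex())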
+ """ + + name_suffix = os.urandom(4).hex() + return cls.create(f"{prefix}_{name_suffix}", using=using) + @classmethod def delete_and_create( cls, name: str, *, cascade: bool = False, using: str = DEFAULT_DB_ALIAS diff --git a/tests/test_schema.py b/tests/test_schema.py index fd966372..e187f967 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -52,14 +52,24 @@ def test_postgres_schema_create_name_that_requires_escaping(): assert _does_schema_exist("table") -def test_postgres_schema_create_random(): +def test_postgres_schema_create_time_based(): with freezegun.freeze_time("2023-04-07 13:37:00.0"): - schema = PostgresSchema.create_random("myprefix") + schema = PostgresSchema.create_time_based("myprefix") assert schema.name == "myprefix_2023040713041680892620" assert _does_schema_exist(schema.name) +def test_postgres_schema_create_random(): + schema = PostgresSchema.create_random("myprefix") + + prefix, suffix = schema.name.split("_") + assert prefix == "myprefix" + assert len(suffix) == 8 + + assert _does_schema_exist(schema.name) + + def test_postgres_schema_delete_and_create(): schema = PostgresSchema.create("test") @@ -215,11 +225,12 @@ def test_postgres_schema_connection_no_delete_default(): def test_postgres_temporary_schema(): - with freezegun.freeze_time("2023-04-07 13:37:00.0"): - with postgres_temporary_schema("temp") as schema: - assert schema.name == "temp_2023040713041680892620" + with postgres_temporary_schema("temp") as schema: + name_prefix, name_suffix = schema.name.split("_") + assert name_prefix == "temp" + assert len(name_suffix) == 8 - assert _does_schema_exist(schema.name) + assert _does_schema_exist(schema.name) assert not _does_schema_exist(schema.name) From b368877873ca67ffe8169b45222b894f9e545d67 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 6 Apr 2023 10:56:12 +0300 Subject: [PATCH 22/43] Improve `test_postgres_schema_delete_and_create` Verify that the table and schema still exists after Postgres refused to drop them. --- tests/test_schema.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_schema.py b/tests/test_schema.py index e187f967..840dfc38 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -86,11 +86,17 @@ def test_postgres_schema_delete_and_create(): pg_error = extract_postgres_error(exc_info.value) assert pg_error.pgcode == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST + # Verify that the schema and table still exist + assert _does_schema_exist(schema.name) + with connection.cursor() as cursor: + cursor.execute("SELECT * FROM test.bla") + assert cursor.fetchone() == ("hello",) + # Dropping the schema should work with cascade=True schema = PostgresSchema.delete_and_create(schema.name, cascade=True) assert _does_schema_exist(schema.name) - # Since the schema was deleteped and re-created, the `bla` + # Since the schema was deleted and re-created, the `bla` # table should not exist anymore. 
with pytest.raises(ProgrammingError) as exc_info: with connection.cursor() as cursor: From 6f8fdb48fee8da5ec0fe979be9eb9fb3b70bc42e Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 6 Apr 2023 11:11:20 +0300 Subject: [PATCH 23/43] Raise specific validation error for schema name prefix + suffix exceeding limit --- psqlextra/schema.py | 21 +++++++++++++++++---- tests/test_schema.py | 18 +++++++++++++++++- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/psqlextra/schema.py b/psqlextra/schema.py index eeea14a3..ec062455 100644 --- a/psqlextra/schema.py +++ b/psqlextra/schema.py @@ -104,8 +104,10 @@ def create_time_based( Name of the database connection to use. """ - name_suffix = timezone.now().strftime("%Y%m%d%H%m%s") - return cls.create(f"{prefix}_{name_suffix}", using=using) + suffix = timezone.now().strftime("%Y%m%d%H%m%s") + cls._verify_generated_name_length(prefix, suffix) + + return cls.create(f"{prefix}_{suffix}", using=using) @classmethod def create_random( @@ -122,8 +124,10 @@ def create_random( Name of the database connection to use. """ - name_suffix = os.urandom(4).hex() - return cls.create(f"{prefix}_{name_suffix}", using=using) + suffix = os.urandom(4).hex() + cls._verify_generated_name_length(prefix, suffix) + + return cls.create(f"{prefix}_{suffix}", using=using) @classmethod def delete_and_create( @@ -201,6 +205,15 @@ def connection(self) -> BaseDatabaseWrapper: return PostgresSchemaConnectionWrapper(connections[self.using], self) + @classmethod + def _verify_generated_name_length(cls, prefix: str, suffix: str) -> None: + max_prefix_length = cls.NAME_MAX_LENGTH - len(suffix) + + if len(prefix) > max_prefix_length: + raise ValidationError( + f"Schema prefix '{prefix}' is longer than {max_prefix_length} characters. Together with the generated suffix of {len(suffix)} characters, the name would exceed Postgres's limit of {cls.NAME_MAX_LENGTH} characters." 
+ ) + PostgresSchema.default = PostgresSchema("public") diff --git a/tests/test_schema.py b/tests/test_schema.py index 840dfc38..9c00f90c 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -37,11 +37,13 @@ def test_postgres_schema_does_not_overwrite(): def test_postgres_schema_create_max_name_length(): - with pytest.raises(ValidationError): + with pytest.raises(ValidationError) as exc_info: PostgresSchema.create( "stringthatislongerhtan63charactersforsureabsolutelysurethisislongerthanthat" ) + assert "is longer than Postgres's limit" in str(exc_info.value) + def test_postgres_schema_create_name_that_requires_escaping(): # 'table' needs escaping because it conflicts with @@ -60,6 +62,13 @@ def test_postgres_schema_create_time_based(): assert _does_schema_exist(schema.name) +def test_postgres_schema_create_time_based_long_prefix(): + with pytest.raises(ValidationError) as exc_info: + PostgresSchema.create_time_based("a" * 100) + + assert "is longer than 55 characters" in str(exc_info.value) + + def test_postgres_schema_create_random(): schema = PostgresSchema.create_random("myprefix") @@ -70,6 +79,13 @@ def test_postgres_schema_create_random(): assert _does_schema_exist(schema.name) +def test_postgres_schema_create_random_long_prefix(): + with pytest.raises(ValidationError) as exc_info: + PostgresSchema.create_random("a" * 100) + + assert "is longer than 55 characters" in str(exc_info.value) + + def test_postgres_schema_delete_and_create(): schema = PostgresSchema.create("test") From 4b9cf2a630a5f9aff207c17ab040a6072dd840c2 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 6 Apr 2023 11:26:27 +0300 Subject: [PATCH 24/43] Make time-based schema names stable in length --- psqlextra/schema.py | 2 +- tests/test_schema.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/psqlextra/schema.py b/psqlextra/schema.py index ec062455..3dfb535a 100644 --- a/psqlextra/schema.py +++ b/psqlextra/schema.py @@ -104,7 +104,7 @@ def create_time_based( Name of the database connection to use. 
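+
+        Example (an illustrative sketch; the name shown assumes the
+        corrected %S format introduced below):
+
+            schema = PostgresSchema.create_time_based("myprefix")
+            print(schema.name)  # e.g. "myprefix_20230407130423"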
""" - suffix = timezone.now().strftime("%Y%m%d%H%m%s") + suffix = timezone.now().strftime("%Y%m%d%H%m%S") cls._verify_generated_name_length(prefix, suffix) return cls.create(f"{prefix}_{suffix}", using=using) diff --git a/tests/test_schema.py b/tests/test_schema.py index 9c00f90c..3c078827 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -55,18 +55,19 @@ def test_postgres_schema_create_name_that_requires_escaping(): def test_postgres_schema_create_time_based(): - with freezegun.freeze_time("2023-04-07 13:37:00.0"): + with freezegun.freeze_time("2023-04-07 13:37:23.4"): schema = PostgresSchema.create_time_based("myprefix") - assert schema.name == "myprefix_2023040713041680892620" + assert schema.name == "myprefix_20230407130423" assert _does_schema_exist(schema.name) def test_postgres_schema_create_time_based_long_prefix(): with pytest.raises(ValidationError) as exc_info: - PostgresSchema.create_time_based("a" * 100) + with freezegun.freeze_time("2023-04-07 13:37:23.4"): + PostgresSchema.create_time_based("a" * 100) - assert "is longer than 55 characters" in str(exc_info.value) + assert "is longer than 49 characters" in str(exc_info.value) def test_postgres_schema_create_random(): From a87eb153965a3d2cc7ae680ac88bdb6b277560c2 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 6 Apr 2023 16:16:47 +0300 Subject: [PATCH 25/43] Pass `get_schema_list` SQL directly to cursor --- psqlextra/backend/introspection.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/psqlextra/backend/introspection.py b/psqlextra/backend/introspection.py index 16b3a8ba..0f7daf1a 100644 --- a/psqlextra/backend/introspection.py +++ b/psqlextra/backend/introspection.py @@ -194,14 +194,16 @@ def get_columns(self, cursor, table_name: str): def get_schema_list(self, cursor) -> List[str]: """A flat list of available schemas.""" - sql = """ + cursor.execute( + """ SELECT schema_name FROM information_schema.schemata - """ + """, + tuple(), + ) - cursor.execute(sql, tuple()) return [name for name, in cursor.fetchall()] def get_constraints(self, cursor, table_name: str): From 0094c553e6d96bc5bc48e41663e02e13a10a564d Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Mon, 10 Apr 2023 10:25:33 +0300 Subject: [PATCH 26/43] Remove connections scoped to schema It's too tricky to get this right. It'll lead to surprises because: 1. It breaks with transaction pooling. 2. Interaction with `SET LOCAL` is strange. A `SET` command after a `SET LOCAL` overrides it. I already shot myself in the foot twice since implementing this. --- docs/source/schemas.rst | 33 ------------- psqlextra/schema.py | 46 ------------------ tests/test_schema.py | 101 +++++----------------------------------- 3 files changed, 11 insertions(+), 169 deletions(-) diff --git a/docs/source/schemas.rst b/docs/source/schemas.rst index edf4d929..01fdd345 100644 --- a/docs/source/schemas.rst +++ b/docs/source/schemas.rst @@ -153,36 +153,3 @@ The ``public`` schema cannot be dropped. This is a Postgres built-in and it is a schema = PostgresSchema.drop("myprefix") schema = PostgresSchema.drop("myprefix", cascade=True) - - -Executing queries within a schema ---------------------------------- - -By default, a connection operates in the ``public`` schema. The schema offers a connection scoped to that schema that sets the Postgres ``search_path`` to only search within that schema. - -.. warning:: - - This can be abused to manage Django models in a custom schema. 
This is not a supported workflow and there might be unexpected issues from attempting to do so. - -.. warning:: - - Do not use this in the following scenarios: - - 1. You access the connection from multiple threads. Scoped connections are **NOT** thread safe. - - 2. The underlying database connection is passed through a connection pooler in transaction pooling mode. - -.. code-block:: python - - from psqlextra.schema import PostgresSchema - - schema = PostgresSchema.create("myschema") - - with schema.connection.cursor() as cursor: - # table gets created within the `myschema` schema, without - # explicitly specifying the schema name - cursor.execute("CREATE TABLE mytable AS SELECT 'hello'") - - with schema.connection.schema_editor() as schema_editor: - # creates a table for the model within the schema - schema_editor.create_model(MyModel) diff --git a/psqlextra/schema.py b/psqlextra/schema.py index 3dfb535a..92686531 100644 --- a/psqlextra/schema.py +++ b/psqlextra/schema.py @@ -2,43 +2,11 @@ from contextlib import contextmanager -import wrapt - from django.core.exceptions import SuspiciousOperation, ValidationError from django.db import DEFAULT_DB_ALIAS, connections, transaction -from django.db.backends.base.base import BaseDatabaseWrapper -from django.db.backends.utils import CursorWrapper from django.utils import timezone -class PostgresSchemaConnectionWrapper(wrapt.ObjectProxy): - """Wraps a Django database connection and ensures that each cursor operates - within the specified schema.""" - - def __init__(self, connection, schema) -> None: - super().__init__(connection) - - self._self_schema = schema - - @contextmanager - def schema_editor(self): - with self.__wrapped__.schema_editor() as schema_editor: - schema_editor.connection = self - yield schema_editor - - @contextmanager - def cursor(self) -> CursorWrapper: - schema = self._self_schema - - with self.__wrapped__.cursor() as cursor: - quoted_name = self.ops.quote_name(schema.name) - cursor.execute(f"SET search_path = {quoted_name}") - try: - yield cursor - finally: - cursor.execute("SET search_path TO DEFAULT") - - class PostgresSchema: """Represents a Postgres schema. @@ -191,20 +159,6 @@ def delete(self, *, cascade: bool = False) -> None: with connections[self.using].schema_editor() as schema_editor: schema_editor.delete_schema(self.name, cascade=cascade) - @property - def connection(self) -> BaseDatabaseWrapper: - """Obtains a database connection scoped to this schema. - - Do not use this in the following scenarios: - - 1. You access the connection from multiple threads. Scoped - connections are NOT thread safe. - 2. The underlying database connection is passed through a - connection pooler in transaction pooling mode. 
- """ - - return PostgresSchemaConnectionWrapper(connections[self.using], self) - @classmethod def _verify_generated_name_length(cls, prefix: str, suffix: str) -> None: max_prefix_length = cls.NAME_MAX_LENGTH - len(suffix) diff --git a/tests/test_schema.py b/tests/test_schema.py index 3c078827..1c4a24c4 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,4 +1,3 @@ -import uuid import freezegun import pytest @@ -136,7 +135,7 @@ def test_postgres_schema_delete_not_empty(): schema = PostgresSchema.create("test") assert _does_schema_exist(schema.name) - with schema.connection.cursor() as cursor: + with connection.cursor() as cursor: cursor.execute("CREATE TABLE test.bla AS SELECT 'hello'") with pytest.raises(InternalError) as exc_info: @@ -150,96 +149,14 @@ def test_postgres_schema_delete_cascade_not_empty(): schema = PostgresSchema.create("test") assert _does_schema_exist(schema.name) - with schema.connection.cursor() as cursor: + with connection.cursor() as cursor: cursor.execute("CREATE TABLE test.bla AS SELECT 'hello'") schema.delete(cascade=True) assert not _does_schema_exist(schema.name) -def test_postgres_schema_connection(): - schema = PostgresSchema.create("test") - - with schema.connection.cursor() as cursor: - # Creating a table without specifying the schema should create - # it in our schema and we should be able to select from it without - # specifying the schema. - cursor.execute("CREATE TABLE myschematable AS SELECT 'myschema'") - cursor.execute("SELECT * FROM myschematable") - assert cursor.fetchone() == ("myschema",) - - # Proof that the table was created in our schema even though we - # never explicitly told it to do so. - cursor.execute( - "SELECT table_schema FROM information_schema.tables WHERE table_name = %s", - ("myschematable",), - ) - assert cursor.fetchone() == (schema.name,) - - # Creating a table in another schema, we should not be able - # to select it without specifying the schema since our - # schema scoped connection only looks at our schema by default. 
- cursor.execute( - "CREATE TABLE public.otherschematable AS SELECT 'otherschema'" - ) - with pytest.raises(ProgrammingError) as exc_info: - cursor.execute("SELECT * FROM otherschematable") - - cursor.execute("ROLLBACK") - - pg_error = extract_postgres_error(exc_info.value) - assert pg_error.pgcode == errorcodes.UNDEFINED_TABLE - - -def test_postgres_schema_connection_does_not_affect_default(): - schema = PostgresSchema.create("test") - - with schema.connection.cursor() as cursor: - cursor.execute("SHOW search_path") - assert cursor.fetchone() == ("test",) - - with connection.cursor() as cursor: - cursor.execute("SHOW search_path") - assert cursor.fetchone() == ('"$user", public',) - - -@pytest.mark.django_db(transaction=True) -def test_postgres_schema_connection_does_not_affect_default_after_throw(): - schema = PostgresSchema.create(str(uuid.uuid4())) - - with pytest.raises(ProgrammingError): - with schema.connection.cursor() as cursor: - cursor.execute("COMMIT") - cursor.execute("SELECT frombadtable") - - with connection.cursor() as cursor: - cursor.execute("ROLLBACK") - cursor.execute("SHOW search_path") - assert cursor.fetchone() == ('"$user", public',) - - -def test_postgres_schema_connection_schema_editor(): - schema = PostgresSchema.create("test") - - with schema.connection.schema_editor() as schema_editor: - with schema_editor.connection.cursor() as cursor: - cursor.execute("SHOW search_path") - assert cursor.fetchone() == ("test",) - - with connection.cursor() as cursor: - cursor.execute("SHOW search_path") - assert cursor.fetchone() == ('"$user", public',) - - -def test_postgres_schema_connection_does_not_catch(): - schema = PostgresSchema.create("test") - - with pytest.raises(ValueError): - with schema.connection.cursor(): - raise ValueError("test") - - -def test_postgres_schema_connection_no_delete_default(): +def test_postgres_schema_no_delete_default(): with pytest.raises(SuspiciousOperation): PostgresSchema.default.delete() @@ -261,8 +178,10 @@ def test_postgres_temporary_schema(): def test_postgres_temporary_schema_not_empty(): with pytest.raises(InternalError) as exc_info: with postgres_temporary_schema("temp") as schema: - with schema.connection.cursor() as cursor: - cursor.execute("CREATE TABLE mytable AS SELECT 'hello world'") + with connection.cursor() as cursor: + cursor.execute( + f"CREATE TABLE {schema.name}.mytable AS SELECT 'hello world'" + ) pg_error = extract_postgres_error(exc_info.value) assert pg_error.pgcode == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST @@ -270,8 +189,10 @@ def test_postgres_temporary_schema_not_empty(): def test_postgres_temporary_schema_not_empty_cascade(): with postgres_temporary_schema("temp", cascade=True) as schema: - with schema.connection.cursor() as cursor: - cursor.execute("CREATE TABLE mytable AS SELECT 'hello world'") + with connection.cursor() as cursor: + cursor.execute( + f"CREATE TABLE {schema.name}.mytable AS SELECT 'hello world'" + ) assert not _does_schema_exist(schema.name) From 28b4c9a91f9b2675db45d845b38993971e9ed32a Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Mon, 10 Apr 2023 13:30:41 +0300 Subject: [PATCH 27/43] Move `using` parameter for `PostgresSchema` to individual methods --- psqlextra/schema.py | 18 +++++++++--------- tests/test_schema.py | 9 +-------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/psqlextra/schema.py b/psqlextra/schema.py index 92686531..5c749362 100644 --- a/psqlextra/schema.py +++ b/psqlextra/schema.py @@ -16,13 +16,11 @@ class PostgresSchema: NAME_MAX_LENGTH = 63 name: str 
- using: str default: "PostgresSchema" - def __init__(self, name: str, *, using: str = DEFAULT_DB_ALIAS) -> None: + def __init__(self, name: str) -> None: self.name = name - self.using = using @classmethod def create( @@ -51,7 +49,7 @@ def create( with connections[using].schema_editor() as schema_editor: schema_editor.create_schema(name) - return cls(name, using=using) + return cls(name) @classmethod def create_time_based( @@ -116,7 +114,7 @@ def delete_and_create( """ with transaction.atomic(using=using): - cls(name, using=using).delete(cascade=cascade) + cls(name).delete(cascade=cascade, using=using) return cls.create(name, using=using) @classmethod @@ -137,7 +135,9 @@ def exists(cls, name: str, *, using: str = DEFAULT_DB_ALIAS) -> bool: with connection.cursor() as cursor: return name in connection.introspection.get_schema_list(cursor) - def delete(self, *, cascade: bool = False) -> None: + def delete( + self, *, cascade: bool = False, using: str = DEFAULT_DB_ALIAS + ) -> None: """Deletes the schema and optionally deletes the contents of the schema and anything that references it. @@ -156,7 +156,7 @@ def delete(self, *, cascade: bool = False) -> None: "Pretty sure you are about to make a mistake by trying to drop the 'public' schema. I have stopped you. Thank me later." ) - with connections[self.using].schema_editor() as schema_editor: + with connections[using].schema_editor() as schema_editor: schema_editor.delete_schema(self.name, cascade=cascade) @classmethod @@ -207,8 +207,8 @@ def postgres_temporary_schema( yield schema except Exception as e: if delete_on_throw: - schema.delete(cascade=cascade) + schema.delete(cascade=cascade, using=using) raise e - schema.delete(cascade=cascade) + schema.delete(cascade=cascade, using=using) diff --git a/tests/test_schema.py b/tests/test_schema.py index 1c4a24c4..2938749c 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -1,14 +1,8 @@ - import freezegun import pytest from django.core.exceptions import SuspiciousOperation, ValidationError -from django.db import ( - DEFAULT_DB_ALIAS, - InternalError, - ProgrammingError, - connection, -) +from django.db import InternalError, ProgrammingError, connection from psycopg2 import errorcodes from psqlextra.error import extract_postgres_error @@ -23,7 +17,6 @@ def _does_schema_exist(name: str) -> bool: def test_postgres_schema_create(): schema = PostgresSchema.create("myschema") assert schema.name == "myschema" - assert schema.using == DEFAULT_DB_ALIAS assert _does_schema_exist(schema.name) From e6d0b8b72f9bc38183ffedc039508e3a34ec44e8 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Tue, 11 Apr 2023 09:50:20 +0300 Subject: [PATCH 28/43] Remove wrapt dependency, it's unused --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 4b1dbf44..281be89d 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,6 @@ def run(self): install_requires=[ "Django>=2.0,<5.0", "python-dateutil>=2.8.0,<=3.0.0", - "wrapt>=1.0,<2.0", ], extras_require={ ':python_version <= "3.6"': ["dataclasses"], From 4daceb67e5a61a4821093346f809b1a606ca464c Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Tue, 11 Apr 2023 11:14:52 +0300 Subject: [PATCH 29/43] Make `extract_postgres_error` work with Psycopg 3.1 --- psqlextra/_version.py | 2 +- psqlextra/error.py | 43 ++++++++++++++++++++++++++++++++++++++----- tests/test_schema.py | 18 +++++++++--------- 3 files changed, 48 insertions(+), 15 deletions(-) diff --git a/psqlextra/_version.py b/psqlextra/_version.py index f6bb6f4d..5b7739fb 100644 --- 
a/psqlextra/_version.py
+++ b/psqlextra/_version.py
@@ -1 +1 @@
-__version__ = "2.0.4"
+__version__ = "2.0.9rc3+swen.4"
diff --git a/psqlextra/error.py b/psqlextra/error.py
index 082a1cfd..66438a5b 100644
--- a/psqlextra/error.py
+++ b/psqlextra/error.py
@@ -1,11 +1,21 @@
-from typing import Optional
-
-import psycopg2
+from typing import Optional, Union

 from django import db

+try:
+    from psycopg2 import Error as Psycopg2Error
+except ImportError:
+    Psycopg2Error = None
+
+try:
+    from psycopg import Error as Psycopg3Error
+except ImportError:
+    Psycopg3Error = None

-def extract_postgres_error(error: db.Error) -> Optional[psycopg2.Error]:
+
+def extract_postgres_error(
+    error: db.Error,
+) -> Optional[Union["Psycopg2Error", "Psycopg3Error"]]:
     """Extracts the underlying :see:psycopg2.Error from the specified Django
     database error.

@@ -14,7 +24,30 @@ def extract_postgres_error(error: db.Error) -> Optional[psycopg2.Error]:
     the cause of the error.
     """

-    if not isinstance(error.__cause__, psycopg2.Error):
+    if (Psycopg2Error and not isinstance(error.__cause__, Psycopg2Error)) and (
+        Psycopg3Error and not isinstance(error.__cause__, Psycopg3Error)
+    ):
         return None

     return error.__cause__
+
+
+def extract_postgres_error_code(error: db.Error) -> Optional[str]:
+    """Extracts the underlying Postgres error code.
+
+    As per PEP-249, Django wraps all database errors in its own
+    exception. We can extract the underlying database error by examining
+    the cause of the error.
+    """
+
+    cause = error.__cause__
+    if not cause:
+        return None
+
+    if Psycopg2Error and isinstance(cause, Psycopg2Error):
+        return cause.pgcode
+
+    if Psycopg3Error and isinstance(cause, Psycopg3Error):
+        return cause.sqlstate
+
+    return None
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 2938749c..2d470e60 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -5,7 +5,7 @@
 from django.db import InternalError, ProgrammingError, connection
 from psycopg2 import errorcodes

-from psqlextra.error import extract_postgres_error
+from psqlextra.error import extract_postgres_error_code
 from psqlextra.schema import PostgresSchema, postgres_temporary_schema


@@ -92,8 +92,8 @@ def test_postgres_schema_delete_and_create():
     with pytest.raises(InternalError) as exc_info:
         schema = PostgresSchema.delete_and_create(schema.name)

-    pg_error = extract_postgres_error(exc_info.value)
-    assert pg_error.pgcode == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST
+    pg_error = extract_postgres_error_code(exc_info.value)
+    assert pg_error == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST

     # Verify that the schema and table still exist
     assert _does_schema_exist(schema.name)
@@ -112,8 +112,8 @@ def test_postgres_schema_delete_and_create():
         cursor.execute("SELECT * FROM test.bla")
         assert cursor.fetchone() == ("hello",)

-    pg_error = extract_postgres_error(exc_info.value)
-    assert pg_error.pgcode == errorcodes.UNDEFINED_TABLE
+    pg_error = extract_postgres_error_code(exc_info.value)
+    assert pg_error == errorcodes.UNDEFINED_TABLE


 def test_postgres_schema_delete():
@@ -134,8 +134,8 @@ def test_postgres_schema_delete_not_empty():
     with pytest.raises(InternalError) as exc_info:
         schema.delete()

-    pg_error = extract_postgres_error(exc_info.value)
-    assert pg_error.pgcode == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST
+    pg_error = extract_postgres_error_code(exc_info.value)
+    assert pg_error == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST


 def test_postgres_schema_delete_cascade_not_empty():
@@ -176,8 +176,8 @@ def test_postgres_temporary_schema_not_empty():
            f"CREATE TABLE
{schema.name}.mytable AS SELECT 'hello world'" ) - pg_error = extract_postgres_error(exc_info.value) - assert pg_error.pgcode == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST + pg_error = extract_postgres_error_code(exc_info.value) + assert pg_error == errorcodes.DEPENDENT_OBJECTS_STILL_EXIST def test_postgres_temporary_schema_not_empty_cascade(): From 784e7a1451ddf4df51f4a371f35bbd35d46a1544 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Wed, 12 Apr 2023 09:27:04 +0300 Subject: [PATCH 30/43] Take separator into account when computing maximum schema name prefix length --- psqlextra/schema.py | 20 ++++++++++++-------- tests/test_schema.py | 8 ++++---- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/psqlextra/schema.py b/psqlextra/schema.py index 5c749362..4ee81cd8 100644 --- a/psqlextra/schema.py +++ b/psqlextra/schema.py @@ -71,9 +71,9 @@ def create_time_based( """ suffix = timezone.now().strftime("%Y%m%d%H%m%S") - cls._verify_generated_name_length(prefix, suffix) + name = cls._create_generated_name(prefix, suffix) - return cls.create(f"{prefix}_{suffix}", using=using) + return cls.create(name, using=using) @classmethod def create_random( @@ -91,9 +91,9 @@ def create_random( """ suffix = os.urandom(4).hex() - cls._verify_generated_name_length(prefix, suffix) + name = cls._create_generated_name(prefix, suffix) - return cls.create(f"{prefix}_{suffix}", using=using) + return cls.create(name, using=using) @classmethod def delete_and_create( @@ -160,14 +160,18 @@ def delete( schema_editor.delete_schema(self.name, cascade=cascade) @classmethod - def _verify_generated_name_length(cls, prefix: str, suffix: str) -> None: - max_prefix_length = cls.NAME_MAX_LENGTH - len(suffix) + def _create_generated_name(cls, prefix: str, suffix: str) -> str: + separator = "_" + generated_name = f"{prefix}{separator}{suffix}" + max_prefix_length = cls.NAME_MAX_LENGTH - len(suffix) - len(separator) - if len(prefix) > max_prefix_length: + if len(generated_name) > cls.NAME_MAX_LENGTH: raise ValidationError( - f"Schema prefix '{prefix}' is longer than {max_prefix_length} characters. Together with the generated suffix of {len(suffix)} characters, the name would exceed Postgres's limit of {cls.NAME_MAX_LENGTH} characters." + f"Schema prefix '{prefix}' is longer than {max_prefix_length} characters. Together with the separator and generated suffix of {len(suffix)} characters, the name would exceed Postgres's limit of {cls.NAME_MAX_LENGTH} characters." 
) + return generated_name + PostgresSchema.default = PostgresSchema("public") diff --git a/tests/test_schema.py b/tests/test_schema.py index 2d470e60..7ae4a3f2 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -57,9 +57,9 @@ def test_postgres_schema_create_time_based(): def test_postgres_schema_create_time_based_long_prefix(): with pytest.raises(ValidationError) as exc_info: with freezegun.freeze_time("2023-04-07 13:37:23.4"): - PostgresSchema.create_time_based("a" * 100) + PostgresSchema.create_time_based("a" * 49) - assert "is longer than 49 characters" in str(exc_info.value) + assert "is longer than 48 characters" in str(exc_info.value) def test_postgres_schema_create_random(): @@ -74,9 +74,9 @@ def test_postgres_schema_create_random(): def test_postgres_schema_create_random_long_prefix(): with pytest.raises(ValidationError) as exc_info: - PostgresSchema.create_random("a" * 100) + PostgresSchema.create_random("a" * 55) - assert "is longer than 55 characters" in str(exc_info.value) + assert "is longer than 54 characters" in str(exc_info.value) def test_postgres_schema_delete_and_create(): From 462dc227d0bd8246ca73de60c92d7ad5ccc17fca Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 20 Apr 2023 11:48:29 +0300 Subject: [PATCH 31/43] Type-check entire code base with mypy Lots of `# type: ignore` in the internals, but the types should be decent enough for consumption. --- .circleci/config.yml | 2 +- psqlextra/backend/base.py | 20 +++++++- psqlextra/backend/base_impl.py | 16 ++++-- psqlextra/backend/introspection.py | 19 +++++-- .../migrations/patched_autodetector.py | 25 ++++++---- psqlextra/backend/migrations/state/model.py | 24 +++++---- .../backend/migrations/state/partitioning.py | 4 +- psqlextra/backend/migrations/state/view.py | 8 +-- psqlextra/backend/operations.py | 2 +- psqlextra/backend/schema.py | 45 ++++++++++------- psqlextra/compiler.py | 25 +++++----- psqlextra/error.py | 21 +++++--- psqlextra/expressions.py | 2 +- .../management/commands/pgmakemigrations.py | 4 +- psqlextra/management/commands/pgpartition.py | 2 +- psqlextra/manager/manager.py | 2 +- psqlextra/models/base.py | 5 +- psqlextra/models/partitioned.py | 6 ++- psqlextra/models/view.py | 20 +++++--- psqlextra/partitioning/config.py | 4 +- psqlextra/partitioning/manager.py | 6 ++- psqlextra/partitioning/partition.py | 6 +-- psqlextra/partitioning/plan.py | 12 +++-- psqlextra/partitioning/range_partition.py | 6 +-- psqlextra/partitioning/shorthands.py | 4 +- psqlextra/query.py | 49 ++++++++++++++----- psqlextra/schema.py | 17 +++++-- psqlextra/settings.py | 16 +++--- psqlextra/sql.py | 22 ++++++--- psqlextra/type_assertions.py | 2 +- psqlextra/util.py | 7 ++- pyproject.toml | 15 ++++++ setup.py | 23 +++++++++ 33 files changed, 315 insertions(+), 126 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bb545bad..f5ee6a31 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -119,7 +119,7 @@ jobs: steps: - checkout - install-dependencies: - extra: analysis + extra: analysis, test - run: name: Verify command: python setup.py verify diff --git a/psqlextra/backend/base.py b/psqlextra/backend/base.py index 40086da8..5c788a05 100644 --- a/psqlextra/backend/base.py +++ b/psqlextra/backend/base.py @@ -1,5 +1,7 @@ import logging +from typing import TYPE_CHECKING + from django.conf import settings from django.db import ProgrammingError @@ -8,17 +10,31 @@ from .operations import PostgresOperations from .schema import PostgresSchemaEditor +from django.db.backends.postgresql.base 
import ( # isort:skip + DatabaseWrapper as PostgresDatabaseWrapper, +) + + logger = logging.getLogger(__name__) -class DatabaseWrapper(base_impl.backend()): +if TYPE_CHECKING: + + class Wrapper(PostgresDatabaseWrapper): + pass + +else: + Wrapper = base_impl.backend() + + +class DatabaseWrapper(Wrapper): """Wraps the standard PostgreSQL database back-end. Overrides the schema editor with our custom schema editor and makes sure the `hstore` extension is enabled. """ - SchemaEditorClass = PostgresSchemaEditor + SchemaEditorClass = PostgresSchemaEditor # type: ignore[assignment] introspection_class = PostgresIntrospection ops_class = PostgresOperations diff --git a/psqlextra/backend/base_impl.py b/psqlextra/backend/base_impl.py index 4e2af04c..88bf9278 100644 --- a/psqlextra/backend/base_impl.py +++ b/psqlextra/backend/base_impl.py @@ -3,6 +3,14 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.db import DEFAULT_DB_ALIAS, connections +from django.db.backends.postgresql.base import DatabaseWrapper +from django.db.backends.postgresql.introspection import ( # type: ignore[import] + DatabaseIntrospection, +) +from django.db.backends.postgresql.operations import DatabaseOperations +from django.db.backends.postgresql.schema import ( # type: ignore[import] + DatabaseSchemaEditor, +) from django.db.backends.postgresql.base import ( # isort:skip DatabaseWrapper as Psycopg2DatabaseWrapper, @@ -68,13 +76,13 @@ def base_backend_instance(): return base_instance -def backend(): +def backend() -> DatabaseWrapper: """Gets the base class for the database back-end.""" return base_backend_instance().__class__ -def schema_editor(): +def schema_editor() -> DatabaseSchemaEditor: """Gets the base class for the schema editor. We have to use the configured base back-end's schema editor for @@ -84,7 +92,7 @@ def schema_editor(): return base_backend_instance().SchemaEditorClass -def introspection(): +def introspection() -> DatabaseIntrospection: """Gets the base class for the introspection class. We have to use the configured base back-end's introspection class @@ -94,7 +102,7 @@ def introspection(): return base_backend_instance().introspection.__class__ -def operations(): +def operations() -> DatabaseOperations: """Gets the base class for the operations class. 
We have to use the configured base back-end's operations class for diff --git a/psqlextra/backend/introspection.py b/psqlextra/backend/introspection.py index 0f7daf1a..bd775779 100644 --- a/psqlextra/backend/introspection.py +++ b/psqlextra/backend/introspection.py @@ -1,5 +1,9 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +from django.db.backends.postgresql.introspection import ( # type: ignore[import] + DatabaseIntrospection, +) from psqlextra.types import PostgresPartitioningMethod @@ -45,7 +49,16 @@ def partition_by_name( ) -class PostgresIntrospection(base_impl.introspection()): +if TYPE_CHECKING: + + class Introspection(DatabaseIntrospection): + pass + +else: + Introspection = base_impl.introspection() + + +class PostgresIntrospection(Introspection): """Adds introspection features specific to PostgreSQL.""" # TODO: This class is a mess, both here and in the @@ -66,7 +79,7 @@ class PostgresIntrospection(base_impl.introspection()): def get_partitioned_tables( self, cursor - ) -> PostgresIntrospectedPartitonedTable: + ) -> List[PostgresIntrospectedPartitonedTable]: """Gets a list of partitioned tables.""" cursor.execute( diff --git a/psqlextra/backend/migrations/patched_autodetector.py b/psqlextra/backend/migrations/patched_autodetector.py index cd647fb0..e5ba8938 100644 --- a/psqlextra/backend/migrations/patched_autodetector.py +++ b/psqlextra/backend/migrations/patched_autodetector.py @@ -12,7 +12,7 @@ RenameField, ) from django.db.migrations.autodetector import MigrationAutodetector -from django.db.migrations.operations.base import Operation +from django.db.migrations.operations.fields import FieldOperation from psqlextra.models import ( PostgresMaterializedViewModel, @@ -83,7 +83,7 @@ def rename_field(self, operation: RenameField): return self._transform_view_field_operations(operation) - def _transform_view_field_operations(self, operation: Operation): + def _transform_view_field_operations(self, operation: FieldOperation): """Transforms operations on fields on a (materialized) view into state only operations. 
@@ -199,9 +199,15 @@ def add_create_partitioned_model(self, operation: CreateModel): ) ) + partitioned_kwargs = { + **kwargs, + "partitioning_options": partitioning_options, + } + self.add( operations.PostgresCreatePartitionedModel( - *args, **kwargs, partitioning_options=partitioning_options + *args, + **partitioned_kwargs, ) ) @@ -231,11 +237,9 @@ def add_create_view_model(self, operation: CreateModel): _, args, kwargs = operation.deconstruct() - self.add( - operations.PostgresCreateViewModel( - *args, **kwargs, view_options=view_options - ) - ) + view_kwargs = {**kwargs, "view_options": view_options} + + self.add(operations.PostgresCreateViewModel(*args, **view_kwargs)) def add_delete_view_model(self, operation: DeleteModel): """Adds a :see:PostgresDeleteViewModel operation to the list of @@ -261,9 +265,12 @@ def add_create_materialized_view_model(self, operation: CreateModel): _, args, kwargs = operation.deconstruct() + view_kwargs = {**kwargs, "view_options": view_options} + self.add( operations.PostgresCreateMaterializedViewModel( - *args, **kwargs, view_options=view_options + *args, + **view_kwargs, ) ) diff --git a/psqlextra/backend/migrations/state/model.py b/psqlextra/backend/migrations/state/model.py index 465b6152..797147f4 100644 --- a/psqlextra/backend/migrations/state/model.py +++ b/psqlextra/backend/migrations/state/model.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Type +from typing import Tuple, Type, cast from django.db.migrations.state import ModelState from django.db.models import Model @@ -17,8 +17,8 @@ class PostgresModelState(ModelState): """ @classmethod - def from_model( - cls, model: PostgresModel, *args, **kwargs + def from_model( # type: ignore[override] + cls, model: Type[PostgresModel], *args, **kwargs ) -> "PostgresModelState": """Creates a new :see:PostgresModelState object from the specified model. @@ -29,28 +29,32 @@ def from_model( We also need to patch up the base class for the model. """ - model_state = super().from_model(model, *args, **kwargs) - model_state = cls._pre_new(model, model_state) + model_state = super().from_model( + cast(Type[Model], model), *args, **kwargs + ) + model_state = cls._pre_new( + model, cast("PostgresModelState", model_state) + ) # django does not add abstract bases as a base in migrations # because it assumes the base does not add anything important # in a migration.. but it does, so we replace the Model # base with the actual base - bases = tuple() + bases: Tuple[Type[Model], ...] 
= tuple() for base in model_state.bases: if issubclass(base, Model): bases += (cls._get_base_model_class(),) else: bases += (base,) - model_state.bases = bases + model_state.bases = cast(Tuple[Type[Model]], bases) return model_state def clone(self) -> "PostgresModelState": """Gets an exact copy of this :see:PostgresModelState.""" model_state = super().clone() - return self._pre_clone(model_state) + return self._pre_clone(cast(PostgresModelState, model_state)) def render(self, apps): """Renders this state into an actual model.""" @@ -95,7 +99,9 @@ def render(self, apps): @classmethod def _pre_new( - cls, model: PostgresModel, model_state: "PostgresModelState" + cls, + model: Type[PostgresModel], + model_state: "PostgresModelState", ) -> "PostgresModelState": """Called when a new model state is created from the specified model.""" diff --git a/psqlextra/backend/migrations/state/partitioning.py b/psqlextra/backend/migrations/state/partitioning.py index aef7a5e3..e8b9a5eb 100644 --- a/psqlextra/backend/migrations/state/partitioning.py +++ b/psqlextra/backend/migrations/state/partitioning.py @@ -94,7 +94,7 @@ def delete_partition(self, name: str): del self.partitions[name] @classmethod - def _pre_new( + def _pre_new( # type: ignore[override] cls, model: PostgresPartitionedModel, model_state: "PostgresPartitionedModelState", @@ -108,7 +108,7 @@ def _pre_new( ) return model_state - def _pre_clone( + def _pre_clone( # type: ignore[override] self, model_state: "PostgresPartitionedModelState" ) -> "PostgresPartitionedModelState": """Called when this model state is cloned.""" diff --git a/psqlextra/backend/migrations/state/view.py b/psqlextra/backend/migrations/state/view.py index d59b3120..0f5b52eb 100644 --- a/psqlextra/backend/migrations/state/view.py +++ b/psqlextra/backend/migrations/state/view.py @@ -22,8 +22,10 @@ def __init__(self, *args, view_options={}, **kwargs): self.view_options = dict(view_options) @classmethod - def _pre_new( - cls, model: PostgresViewModel, model_state: "PostgresViewModelState" + def _pre_new( # type: ignore[override] + cls, + model: Type[PostgresViewModel], + model_state: "PostgresViewModelState", ) -> "PostgresViewModelState": """Called when a new model state is created from the specified model.""" @@ -31,7 +33,7 @@ def _pre_new( model_state.view_options = dict(model._view_meta.original_attrs) return model_state - def _pre_clone( + def _pre_clone( # type: ignore[override] self, model_state: "PostgresViewModelState" ) -> "PostgresViewModelState": """Called when this model state is cloned.""" diff --git a/psqlextra/backend/operations.py b/psqlextra/backend/operations.py index 52793fac..3bcf1897 100644 --- a/psqlextra/backend/operations.py +++ b/psqlextra/backend/operations.py @@ -9,7 +9,7 @@ from . 
import base_impl -class PostgresOperations(base_impl.operations()): +class PostgresOperations(base_impl.operations()): # type: ignore[misc] """Simple operations specific to PostgreSQL.""" compiler_module = "psqlextra.compiler" diff --git a/psqlextra/backend/schema.py b/psqlextra/backend/schema.py index 85978f05..28e9211a 100644 --- a/psqlextra/backend/schema.py +++ b/psqlextra/backend/schema.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Type +from typing import TYPE_CHECKING, Any, List, Optional, Type, cast from unittest import mock import django @@ -10,6 +10,9 @@ ) from django.db import transaction from django.db.backends.ddl_references import Statement +from django.db.backends.postgresql.schema import ( # type: ignore[import] + DatabaseSchemaEditor, +) from django.db.models import Field, Model from psqlextra.settings import ( @@ -26,7 +29,13 @@ HStoreUniqueSchemaEditorSideEffect, ) -SchemaEditor = base_impl.schema_editor() +if TYPE_CHECKING: + + class SchemaEditor(DatabaseSchemaEditor): + pass + +else: + SchemaEditor = base_impl.schema_editor() class PostgresSchemaEditor(SchemaEditor): @@ -72,9 +81,9 @@ class PostgresSchemaEditor(SchemaEditor): sql_delete_partition = "DROP TABLE %s" sql_table_comment = "COMMENT ON TABLE %s IS %s" - side_effects = [ - HStoreUniqueSchemaEditorSideEffect(), - HStoreRequiredSchemaEditorSideEffect(), + side_effects: List[DatabaseSchemaEditor] = [ + cast(DatabaseSchemaEditor, HStoreUniqueSchemaEditorSideEffect()), + cast(DatabaseSchemaEditor, HStoreRequiredSchemaEditorSideEffect()), ] def __init__(self, connection, collect_sql=False, atomic=True): @@ -231,7 +240,7 @@ def clone_model_constraints_and_indexes_to_schema( [schema_name], using=self.connection.alias ): for constraint in model._meta.constraints: - self.add_constraint(model, constraint) + self.add_constraint(model, constraint) # type: ignore[attr-defined] for index in model._meta.indexes: self.add_index(model, index) @@ -246,14 +255,14 @@ def clone_model_constraints_and_indexes_to_schema( model, tuple(), model._meta.index_together ) - for field in model._meta.local_concrete_fields: + for field in model._meta.local_concrete_fields: # type: ignore[attr-defined] # Django creates primary keys later added to the model with # a custom name. We want the name as it was created originally. if field.primary_key: with postgres_reset_local_search_path( using=self.connection.alias ): - [primary_key_name] = self._constraint_names( + [primary_key_name] = self._constraint_names( # type: ignore[attr-defined] model, primary_key=True ) @@ -278,7 +287,7 @@ def clone_model_constraints_and_indexes_to_schema( with postgres_reset_local_search_path( using=self.connection.alias ): - [fk_name] = self._constraint_names( + [fk_name] = self._constraint_names( # type: ignore[attr-defined] model, [field.column], foreign_key=True ) @@ -304,7 +313,7 @@ def clone_model_constraints_and_indexes_to_schema( with postgres_reset_local_search_path( using=self.connection.alias ): - [field_check_name] = self._constraint_names( + [field_check_name] = self._constraint_names( # type: ignore[attr-defined] model, [field.column], check=True, @@ -315,7 +324,7 @@ def clone_model_constraints_and_indexes_to_schema( ) self.execute( - self._create_check_sql( + self._create_check_sql( # type: ignore[attr-defined] model, field_check_name, field_check ) ) @@ -361,7 +370,7 @@ def clone_model_foreign_keys_to_schema( resides. 
""" - constraint_names = self._constraint_names(model, foreign_key=True) + constraint_names = self._constraint_names(model, foreign_key=True) # type: ignore[attr-defined] with postgres_prepend_local_search_path( [schema_name], using=self.connection.alias @@ -569,7 +578,7 @@ def replace_materialized_view_model(self, model: Type[Model]) -> None: if not constraint_options["definition"]: raise SuspiciousOperation( "Table %s has a constraint '%s' that no definition could be generated for", - (model._meta.db_tabel, constraint_name), + (model._meta.db_table, constraint_name), ) self.execute(constraint_options["definition"]) @@ -597,7 +606,7 @@ def create_partitioned_model(self, model: Type[Model]) -> None: # create a composite key that includes the partitioning key sql = sql.replace(" PRIMARY KEY", "") - if model._meta.pk.name not in meta.key: + if model._meta.pk and model._meta.pk.name not in meta.key: sql = sql[:-1] + ", PRIMARY KEY (%s, %s))" % ( self.quote_name(model._meta.pk.name), partitioning_key_sql, @@ -927,7 +936,9 @@ def vacuum_model( """ columns = [ - field.column for field in fields if field.concrete and field.column + field.column + for field in fields + if getattr(field, "concrete", False) and field.column ] self.vacuum_table(model._meta.db_table, columns, **kwargs) @@ -1080,8 +1091,8 @@ def _clone_model_field(self, field: Field, **overrides) -> Field: cloned_field.model = field.model cloned_field.set_attributes_from_name(field.name) - if cloned_field.remote_field: + if cloned_field.remote_field and field.remote_field: cloned_field.remote_field.model = field.remote_field.model - cloned_field.set_attributes_from_rel() + cloned_field.set_attributes_from_rel() # type: ignore[attr-defined] return cloned_field diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py index be96e50d..12fff3fa 100644 --- a/psqlextra/compiler.py +++ b/psqlextra/compiler.py @@ -71,25 +71,25 @@ def append_caller_to_sql(sql): return sql -class SQLCompiler(django_compiler.SQLCompiler): +class SQLCompiler(django_compiler.SQLCompiler): # type: ignore [attr-defined] def as_sql(self, *args, **kwargs): sql, params = super().as_sql(*args, **kwargs) return append_caller_to_sql(sql), params -class SQLDeleteCompiler(django_compiler.SQLDeleteCompiler): +class SQLDeleteCompiler(django_compiler.SQLDeleteCompiler): # type: ignore [name-defined] def as_sql(self, *args, **kwargs): sql, params = super().as_sql(*args, **kwargs) return append_caller_to_sql(sql), params -class SQLAggregateCompiler(django_compiler.SQLAggregateCompiler): +class SQLAggregateCompiler(django_compiler.SQLAggregateCompiler): # type: ignore [name-defined] def as_sql(self, *args, **kwargs): sql, params = super().as_sql(*args, **kwargs) return append_caller_to_sql(sql), params -class SQLUpdateCompiler(django_compiler.SQLUpdateCompiler): +class SQLUpdateCompiler(django_compiler.SQLUpdateCompiler): # type: ignore [name-defined] """Compiler for SQL UPDATE statements that allows us to use expressions inside HStore values. 
@@ -146,7 +146,7 @@ def _does_dict_contain_expression(data: dict) -> bool: return False -class SQLInsertCompiler(django_compiler.SQLInsertCompiler): +class SQLInsertCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined] """Compiler for SQL INSERT statements.""" def as_sql(self, *args, **kwargs): @@ -159,7 +159,7 @@ def as_sql(self, *args, **kwargs): return queries -class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): +class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined] """Compiler for SQL INSERT statements.""" def __init__(self, *args, **kwargs): @@ -237,15 +237,15 @@ def _rewrite_insert_on_conflict( update_columns = ", ".join( [ "{0} = EXCLUDED.{0}".format(self.qn(field.column)) - for field in self.query.update_fields + for field in self.query.update_fields # type: ignore[attr-defined] ] ) # build the conflict target, the columns to watch # for conflicts conflict_target = self._build_conflict_target() - index_predicate = self.query.index_predicate - update_condition = self.query.conflict_update_condition + index_predicate = self.query.index_predicate # type: ignore[attr-defined] + update_condition = self.query.conflict_update_condition # type: ignore[attr-defined] rewritten_sql = f"{sql} ON CONFLICT {conflict_target}" @@ -355,12 +355,15 @@ def _get_model_field(self, name: str): field_name = self._normalize_field_name(name) + if not self.query.model: + return None + # 'pk' has special meaning and always refers to the primary # key of a model, we have to respect this de-facto standard behaviour if field_name == "pk" and self.query.model._meta.pk: return self.query.model._meta.pk - for field in self.query.model._meta.local_concrete_fields: + for field in self.query.model._meta.local_concrete_fields: # type: ignore[attr-defined] if field.name == field_name or field.column == field_name: return field @@ -402,7 +405,7 @@ def _format_field_value(self, field_name) -> str: if isinstance(field, RelatedField) and isinstance(value, Model): value = value.pk - return django_compiler.SQLInsertCompiler.prepare_value( + return django_compiler.SQLInsertCompiler.prepare_value( # type: ignore[attr-defined] self, field, # Note: this deliberately doesn't use `pre_save_val` as we don't diff --git a/psqlextra/error.py b/psqlextra/error.py index 66438a5b..b3a5cf83 100644 --- a/psqlextra/error.py +++ b/psqlextra/error.py @@ -1,21 +1,30 @@ -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Type, Union from django import db +if TYPE_CHECKING: + from psycopg2 import Error as _Psycopg2Error + + Psycopg2Error: Optional[Type[_Psycopg2Error]] + + from psycopg import Error as _Psycopg3Error + + Psycopg3Error: Optional[Type[_Psycopg3Error]] + try: - from psycopg2 import Error as Psycopg2Error + from psycopg2 import Error as Psycopg2Error # type: ignore[no-redef] except ImportError: - Psycopg2Error = None + Psycopg2Error = None # type: ignore[misc] try: - from psycopg import Error as Psycopg3Error + from psycopg import Error as Psycopg3Error # type: ignore[no-redef] except ImportError: - Psycopg3Error = None + Psycopg3Error = None # type: ignore[misc] def extract_postgres_error( error: db.Error, -) -> Optional[Union["Psycopg2Error", "Psycopg3Error"]]: +) -> Optional[Union["_Psycopg2Error", "_Psycopg3Error"]]: """Extracts the underlying :see:psycopg2.Error from the specified Django database error. 
diff --git a/psqlextra/expressions.py b/psqlextra/expressions.py index 75351e68..d9c6bb54 100644 --- a/psqlextra/expressions.py +++ b/psqlextra/expressions.py @@ -140,7 +140,7 @@ def __init__(self, name: str, key: str): def resolve_expression(self, *args, **kwargs): """Resolves the expression into a :see:HStoreColumn expression.""" - original_expression: expressions.Col = super().resolve_expression( + original_expression: expressions.Col = super().resolve_expression( # type: ignore[annotation-unchecked] *args, **kwargs ) expression = HStoreColumn( diff --git a/psqlextra/management/commands/pgmakemigrations.py b/psqlextra/management/commands/pgmakemigrations.py index cdb7131b..7b678855 100644 --- a/psqlextra/management/commands/pgmakemigrations.py +++ b/psqlextra/management/commands/pgmakemigrations.py @@ -1,4 +1,6 @@ -from django.core.management.commands import makemigrations +from django.core.management.commands import ( # type: ignore[attr-defined] + makemigrations, +) from psqlextra.backend.migrations import postgres_patched_migrations diff --git a/psqlextra/management/commands/pgpartition.py b/psqlextra/management/commands/pgpartition.py index 592b57d7..8a6fa636 100644 --- a/psqlextra/management/commands/pgpartition.py +++ b/psqlextra/management/commands/pgpartition.py @@ -57,7 +57,7 @@ def add_arguments(self, parser): default=False, ) - def handle( + def handle( # type: ignore[override] self, dry: bool, yes: bool, diff --git a/psqlextra/manager/manager.py b/psqlextra/manager/manager.py index 4b96e34f..0931b38a 100644 --- a/psqlextra/manager/manager.py +++ b/psqlextra/manager/manager.py @@ -8,7 +8,7 @@ from psqlextra.query import PostgresQuerySet -class PostgresManager(Manager.from_queryset(PostgresQuerySet)): +class PostgresManager(Manager.from_queryset(PostgresQuerySet)): # type: ignore[misc] """Adds support for PostgreSQL specifics.""" use_in_migrations = True diff --git a/psqlextra/models/base.py b/psqlextra/models/base.py index 21caad36..d240237a 100644 --- a/psqlextra/models/base.py +++ b/psqlextra/models/base.py @@ -1,4 +1,7 @@ +from typing import Any + from django.db import models +from django.db.models import Manager from psqlextra.manager import PostgresManager @@ -10,4 +13,4 @@ class Meta: abstract = True base_manager_name = "objects" - objects = PostgresManager() + objects: "Manager[Any]" = PostgresManager() diff --git a/psqlextra/models/partitioned.py b/psqlextra/models/partitioned.py index c03f3e93..f0115367 100644 --- a/psqlextra/models/partitioned.py +++ b/psqlextra/models/partitioned.py @@ -1,3 +1,5 @@ +from typing import Iterable + from django.db.models.base import ModelBase from psqlextra.types import PostgresPartitioningMethod @@ -15,7 +17,7 @@ class PostgresPartitionedModelMeta(ModelBase): """ default_method = PostgresPartitioningMethod.RANGE - default_key = [] + default_key: Iterable[str] = [] def __new__(cls, name, bases, attrs, **kwargs): new_class = super().__new__(cls, name, bases, attrs, **kwargs) @@ -38,6 +40,8 @@ class PostgresPartitionedModel( """Base class for taking advantage of PostgreSQL's 11.x native support for table partitioning.""" + _partitioning_meta: PostgresPartitionedModelOptions + class Meta: abstract = True base_manager_name = "objects" diff --git a/psqlextra/models/view.py b/psqlextra/models/view.py index a9497057..b19f88c8 100644 --- a/psqlextra/models/view.py +++ b/psqlextra/models/view.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast from 
django.core.exceptions import ImproperlyConfigured from django.db import connections @@ -12,6 +12,9 @@ from .base import PostgresModel from .options import PostgresViewOptions +if TYPE_CHECKING: + from psqlextra.backend.schema import PostgresSchemaEditor + ViewQueryValue = Union[QuerySet, SQLWithParams, SQL] ViewQuery = Optional[Union[ViewQueryValue, Callable[[], ViewQueryValue]]] @@ -77,23 +80,26 @@ def _view_query_as_sql_with_params( " to be a valid `django.db.models.query.QuerySet`" " SQL string, or tuple of SQL string and params." ) - % (model.__name__) + % (model.__class__.__name__) ) # querysets can easily be converted into sql, params if is_query_set(view_query): - return view_query.query.sql_with_params() + return cast("QuerySet[Any]", view_query).query.sql_with_params() # query was already specified in the target format if is_sql_with_params(view_query): - return view_query + return cast(SQLWithParams, view_query) - return view_query, tuple() + view_query_sql = cast(str, view_query) + return view_query_sql, tuple() class PostgresViewModel(PostgresModel, metaclass=PostgresViewModelMeta): """Base class for creating a model that is a view.""" + _view_meta: PostgresViewOptions + class Meta: abstract = True base_manager_name = "objects" @@ -127,4 +133,6 @@ def refresh( conn_name = using or "default" with connections[conn_name].schema_editor() as schema_editor: - schema_editor.refresh_materialized_view_model(cls, concurrently) + cast( + "PostgresSchemaEditor", schema_editor + ).refresh_materialized_view_model(cls, concurrently) diff --git a/psqlextra/partitioning/config.py b/psqlextra/partitioning/config.py index df21c057..976bf1ae 100644 --- a/psqlextra/partitioning/config.py +++ b/psqlextra/partitioning/config.py @@ -1,3 +1,5 @@ +from typing import Type + from psqlextra.models import PostgresPartitionedModel from .strategy import PostgresPartitioningStrategy @@ -9,7 +11,7 @@ class PostgresPartitioningConfig: def __init__( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], strategy: PostgresPartitioningStrategy, ) -> None: self.model = model diff --git a/psqlextra/partitioning/manager.py b/psqlextra/partitioning/manager.py index 4dcbb599..074cc1c6 100644 --- a/psqlextra/partitioning/manager.py +++ b/psqlextra/partitioning/manager.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Type from django.db import connections @@ -111,7 +111,9 @@ def _plan_for_config( return model_plan @staticmethod - def _get_partitioned_table(connection, model: PostgresPartitionedModel): + def _get_partitioned_table( + connection, model: Type[PostgresPartitionedModel] + ): with connection.cursor() as cursor: table = connection.introspection.get_partitioned_table( cursor, model._meta.db_table diff --git a/psqlextra/partitioning/partition.py b/psqlextra/partitioning/partition.py index ca64bbdc..4c13fda0 100644 --- a/psqlextra/partitioning/partition.py +++ b/psqlextra/partitioning/partition.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Optional +from typing import Optional, Type from psqlextra.backend.schema import PostgresSchemaEditor from psqlextra.models import PostgresPartitionedModel @@ -15,7 +15,7 @@ def name(self) -> str: @abstractmethod def create( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, comment: Optional[str] = None, ) -> None: @@ -24,7 +24,7 @@ def create( @abstractmethod def delete( self, - model: 
PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, ) -> None: """Deletes this partition from the database.""" diff --git a/psqlextra/partitioning/plan.py b/psqlextra/partitioning/plan.py index 31746360..3fcac44d 100644 --- a/psqlextra/partitioning/plan.py +++ b/psqlextra/partitioning/plan.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional, cast from django.db import connections, transaction @@ -7,6 +7,9 @@ from .constants import AUTO_PARTITIONED_COMMENT from .partition import PostgresPartition +if TYPE_CHECKING: + from psqlextra.backend.schema import PostgresSchemaEditor + @dataclass class PostgresModelPartitioningPlan: @@ -38,12 +41,15 @@ def apply(self, using: Optional[str]) -> None: for partition in self.creations: partition.create( self.config.model, - schema_editor, + cast("PostgresSchemaEditor", schema_editor), comment=AUTO_PARTITIONED_COMMENT, ) for partition in self.deletions: - partition.delete(self.config.model, schema_editor) + partition.delete( + self.config.model, + cast("PostgresSchemaEditor", schema_editor), + ) def print(self) -> None: """Prints this model plan to the terminal in a readable format.""" diff --git a/psqlextra/partitioning/range_partition.py b/psqlextra/partitioning/range_partition.py index b49fe784..a2f3e82f 100644 --- a/psqlextra/partitioning/range_partition.py +++ b/psqlextra/partitioning/range_partition.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Type from psqlextra.backend.schema import PostgresSchemaEditor from psqlextra.models import PostgresPartitionedModel @@ -23,7 +23,7 @@ def deconstruct(self) -> dict: def create( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, comment: Optional[str] = None, ) -> None: @@ -37,7 +37,7 @@ def create( def delete( self, - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], schema_editor: PostgresSchemaEditor, ) -> None: schema_editor.delete_partition(model, self.name()) diff --git a/psqlextra/partitioning/shorthands.py b/psqlextra/partitioning/shorthands.py index dab65e4f..30175273 100644 --- a/psqlextra/partitioning/shorthands.py +++ b/psqlextra/partitioning/shorthands.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Type from dateutil.relativedelta import relativedelta @@ -10,7 +10,7 @@ def partition_by_current_time( - model: PostgresPartitionedModel, + model: Type[PostgresPartitionedModel], count: int, years: Optional[int] = None, months: Optional[int] = None, diff --git a/psqlextra/query.py b/psqlextra/query.py index 2f117e3d..5c5e6f47 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -1,10 +1,20 @@ from collections import OrderedDict from itertools import chain -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Dict, + Generic, + Iterable, + List, + Optional, + Tuple, + TypeVar, + Union, +) from django.core.exceptions import SuspiciousOperation from django.db import connections, models, router -from django.db.models import Expression, Q +from django.db.models import Expression, Q, QuerySet from django.db.models.fields import NOT_PROVIDED from .sql import PostgresInsertQuery, PostgresQuery @@ -13,7 +23,17 @@ ConflictTarget = List[Union[str, Tuple[str]]] -class PostgresQuerySet(models.QuerySet): +TModel = TypeVar("TModel", bound=models.Model, 
covariant=True) + +if TYPE_CHECKING: + from typing_extensions import Self + + QuerySetBase = QuerySet[TModel] +else: + QuerySetBase = QuerySet + + +class PostgresQuerySet(QuerySetBase, Generic[TModel]): """Adds support for PostgreSQL specifics.""" def __init__(self, model=None, query=None, using=None, hints=None): @@ -28,7 +48,7 @@ def __init__(self, model=None, query=None, using=None, hints=None): self.conflict_update_condition = None self.index_predicate = None - def annotate(self, **annotations): + def annotate(self, **annotations) -> "Self": # type: ignore[valid-type, override] """Custom version of the standard annotate function that allows using field names as annotated fields. @@ -112,7 +132,7 @@ def on_conflict( def bulk_insert( self, - rows: List[dict], + rows: Iterable[dict], return_model: bool = False, using: Optional[str] = None, ): @@ -202,7 +222,10 @@ def insert(self, using: Optional[str] = None, **fields): compiler = self._build_insert_compiler([fields], using=using) rows = compiler.execute_sql(return_id=True) - _, pk_db_column = self.model._meta.pk.get_attname_column() + if not self.model or not self.model.pk: + return None + + _, pk_db_column = self.model._meta.pk.get_attname_column() # type: ignore[union-attr] if not rows or len(rows) == 0: return None @@ -245,7 +268,7 @@ def insert_and_get(self, using: Optional[str] = None, **fields): # preserve the fact that the attribute name # might be different than the database column name model_columns = {} - for field in self.model._meta.local_concrete_fields: + for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] model_columns[field.column] = field.attname # strip out any columns/fields returned by the db that @@ -298,7 +321,9 @@ def upsert( index_predicate=index_predicate, update_condition=update_condition, ) - return self.insert(**fields, using=using) + + kwargs = {**fields, "using": using} + return self.insert(**kwargs) def upsert_and_get( self, @@ -340,7 +365,9 @@ def upsert_and_get( index_predicate=index_predicate, update_condition=update_condition, ) - return self.insert_and_get(**fields, using=using) + + kwargs = {**fields, "using": using} + return self.insert_and_get(**kwargs) def bulk_upsert( self, @@ -403,7 +430,7 @@ def _create_model_instance( if apply_converters: connection = connections[using] - for field in self.model._meta.local_concrete_fields: + for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] if field.attname not in converted_field_values: continue @@ -447,7 +474,7 @@ def _build_insert_compiler( # ask the db router which connection to use using = ( - using or self._db or router.db_for_write(self.model, **self._hints) + using or self._db or router.db_for_write(self.model, **self._hints) # type: ignore[attr-defined] ) # create model objects, we also have to detect cases diff --git a/psqlextra/schema.py b/psqlextra/schema.py index 4ee81cd8..9edb83bd 100644 --- a/psqlextra/schema.py +++ b/psqlextra/schema.py @@ -1,11 +1,16 @@ import os from contextlib import contextmanager +from typing import TYPE_CHECKING, Generator, cast from django.core.exceptions import SuspiciousOperation, ValidationError from django.db import DEFAULT_DB_ALIAS, connections, transaction from django.utils import timezone +if TYPE_CHECKING: + from psqlextra.backend.introspection import PostgresIntrospection + from psqlextra.backend.schema import PostgresSchemaEditor + class PostgresSchema: """Represents a Postgres schema. 
@@ -47,7 +52,7 @@ def create(
         )
 
         with connections[using].schema_editor() as schema_editor:
-            schema_editor.create_schema(name)
+            cast("PostgresSchemaEditor", schema_editor).create_schema(name)
 
         return cls(name)
 
@@ -133,7 +138,9 @@ def exists(cls, name: str, *, using: str = DEFAULT_DB_ALIAS) -> bool:
         connection = connections[using]
 
         with connection.cursor() as cursor:
-            return name in connection.introspection.get_schema_list(cursor)
+            return name in cast(
+                "PostgresIntrospection", connection.introspection
+            ).get_schema_list(cursor)
 
     def delete(
         self, *, cascade: bool = False, using: str = DEFAULT_DB_ALIAS
@@ -157,7 +164,9 @@ def delete(
             )
 
         with connections[using].schema_editor() as schema_editor:
-            schema_editor.delete_schema(self.name, cascade=cascade)
+            cast("PostgresSchemaEditor", schema_editor).delete_schema(
+                self.name, cascade=cascade
+            )
 
     @classmethod
     def _create_generated_name(cls, prefix: str, suffix: str) -> str:
@@ -183,7 +192,7 @@ def postgres_temporary_schema(
     cascade: bool = False,
     delete_on_throw: bool = False,
     using: str = DEFAULT_DB_ALIAS,
-) -> PostgresSchema:
+) -> Generator[PostgresSchema, None, None]:
     """Creates a temporary schema that only lives in the context of this
     context manager.
 
diff --git a/psqlextra/settings.py b/psqlextra/settings.py
index 6dd32f37..6f75c779 100644
--- a/psqlextra/settings.py
+++ b/psqlextra/settings.py
@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import Dict, List, Optional, Union
+from typing import Generator, List, Optional, Union
 
 from django.core.exceptions import SuspiciousOperation
 from django.db import DEFAULT_DB_ALIAS, connections
@@ -9,8 +9,8 @@ def postgres_set_local(
     *,
     using: str = DEFAULT_DB_ALIAS,
-    **options: Dict[str, Optional[Union[str, int, float, List[str]]]],
-) -> None:
+    **options: Optional[Union[str, int, float, List[str]]],
+) -> Generator[None, None, None]:
     """Sets the specified PostgreSQL options using SET LOCAL so that they
     apply to the current transaction only.
@@ -29,7 +29,7 @@ def postgres_set_local( ) sql = [] - params = [] + params: List[Union[str, int, float, List[str]]] = [] for name, value in options.items(): if value is None: sql.append(f"SET LOCAL {qn(name)} TO DEFAULT") @@ -78,7 +78,7 @@ def postgres_set_local( @contextmanager def postgres_set_local_search_path( search_path: List[str], *, using: str = DEFAULT_DB_ALIAS -) -> None: +) -> Generator[None, None, None]: """Sets the search path to the specified schemas.""" with postgres_set_local(search_path=search_path, using=using): @@ -88,7 +88,7 @@ def postgres_set_local_search_path( @contextmanager def postgres_prepend_local_search_path( search_path: List[str], *, using: str = DEFAULT_DB_ALIAS -) -> None: +) -> Generator[None, None, None]: """Prepends the current local search path with the specified schemas.""" connection = connections[using] @@ -111,7 +111,9 @@ def postgres_prepend_local_search_path( @contextmanager -def postgres_reset_local_search_path(*, using: str = DEFAULT_DB_ALIAS) -> None: +def postgres_reset_local_search_path( + *, using: str = DEFAULT_DB_ALIAS +) -> Generator[None, None, None]: """Resets the local search path to the default.""" with postgres_set_local(search_path=None, using=using): diff --git a/psqlextra/sql.py b/psqlextra/sql.py index 25c8314e..2a5b418e 100644 --- a/psqlextra/sql.py +++ b/psqlextra/sql.py @@ -1,11 +1,11 @@ from collections import OrderedDict -from typing import List, Optional, Tuple +from typing import Optional, Tuple import django from django.core.exceptions import SuspiciousOperation from django.db import connections, models -from django.db.models import sql +from django.db.models import Expression, sql from django.db.models.constants import LOOKUP_SEP from .compiler import PostgresInsertOnConflictCompiler @@ -16,6 +16,8 @@ class PostgresQuery(sql.Query): + select: Tuple[Expression, ...] + def chain(self, klass=None): """Chains this query to another. @@ -68,7 +70,7 @@ def rename_annotations(self, annotations) -> None: self.annotations.clear() self.annotations.update(new_annotations) - def add_fields(self, field_names: List[str], *args, **kwargs) -> None: + def add_fields(self, field_names, *args, **kwargs) -> None: """Adds the given (model) fields to the select set. The field names are added in the order specified. This overrides @@ -100,10 +102,11 @@ def add_fields(self, field_names: List[str], *args, **kwargs) -> None: if len(parts) > 1: column_name, hstore_key = parts[:2] is_hstore, field = self._is_hstore_field(column_name) - if is_hstore: + if self.model and is_hstore: select.append( HStoreColumn( - self.model._meta.db_table or self.model.name, + self.model._meta.db_table + or self.model.__class__.__name__, field, hstore_key, ) @@ -115,7 +118,7 @@ def add_fields(self, field_names: List[str], *args, **kwargs) -> None: super().add_fields(field_names_without_hstore, *args, **kwargs) if len(select) > 0: - self.set_select(self.select + tuple(select)) + self.set_select(list(self.select + tuple(select))) def _is_hstore_field( self, field_name: str @@ -127,8 +130,11 @@ def _is_hstore_field( instance. 
""" + if not self.model: + return (False, None) + field_instance = None - for field in self.model._meta.local_concrete_fields: + for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] if field.name == field_name or field.column == field_name: field_instance = field break @@ -151,7 +157,7 @@ def __init__(self, *args, **kwargs): self.update_fields = [] - def values(self, objs: List, insert_fields: List, update_fields: List = []): + def values(self, objs, insert_fields, update_fields=[]): """Sets the values to be used in this query. Insert fields are fields that are definitely diff --git a/psqlextra/type_assertions.py b/psqlextra/type_assertions.py index 0a7e8608..e18d13be 100644 --- a/psqlextra/type_assertions.py +++ b/psqlextra/type_assertions.py @@ -7,7 +7,7 @@ def is_query_set(value: Any) -> bool: """Gets whether the specified value is a :see:QuerySet.""" - return isinstance(value, QuerySet) + return isinstance(value, QuerySet) # type: ignore[misc] def is_sql(value: Any) -> bool: diff --git a/psqlextra/util.py b/psqlextra/util.py index edc4e955..d0bca000 100644 --- a/psqlextra/util.py +++ b/psqlextra/util.py @@ -1,10 +1,15 @@ from contextlib import contextmanager +from typing import Generator, Type + +from django.db import models from .manager import PostgresManager @contextmanager -def postgres_manager(model): +def postgres_manager( + model: Type[models.Model], +) -> Generator[PostgresManager, None, None]: """Allows you to use the :see:PostgresManager with the specified model instance on the fly. diff --git a/pyproject.toml b/pyproject.toml index 126ae9a3..fb35b3b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,3 +10,18 @@ exclude = ''' )/ ) ''' + +[tool.mypy] +python_version = "3.8" +plugins = ["mypy_django_plugin.main"] +mypy_path = ["stubs", "."] +exclude = "(env|build|dist|migrations)" + +[[tool.mypy.overrides]] +module = [ + "psycopg.*" +] +ignore_missing_imports = true + +[tool.django-stubs] +django_settings_module = "settings" diff --git a/setup.py b/setup.py index 281be89d..311acf11 100644 --- a/setup.py +++ b/setup.py @@ -90,6 +90,15 @@ def run(self): "autopep8==1.6.0", "isort==5.10.0", "docformatter==1.4", + "mypy==1.2.0; python_version > '3.6'", + "mypy==0.971; python_version <= '3.6'", + "django-stubs==1.16.0; python_version > '3.6'", + "django-stubs==1.9.0; python_version <= '3.6'", + "typing-extensions==4.5.0; python_version > '3.6'", + "typing-extensions==4.1.0; python_version <= '3.6'", + "types-dj-database-url==1.3.0.0", + "types-psycopg2==2.9.21.9", + "types-python-dateutil==2.8.19.12", ], "publish": [ "build==0.7.0", @@ -124,6 +133,18 @@ def run(self): ["autopep8", "-i", "-r", "setup.py", "psqlextra", "tests"], ], ), + "lint_types": create_command( + "Type-checks the code", + [ + [ + "mypy", + "--package", + "psqlextra", + "--pretty", + "--show-error-codes", + ], + ], + ), "format": create_command( "Formats the code", [["black", "setup.py", "psqlextra", "tests"]] ), @@ -162,6 +183,7 @@ def run(self): ["python", "setup.py", "sort_imports"], ["python", "setup.py", "lint_fix"], ["python", "setup.py", "lint"], + ["python", "setup.py", "lint_types"], ], ), "verify": create_command( @@ -171,6 +193,7 @@ def run(self): ["python", "setup.py", "format_docstrings_verify"], ["python", "setup.py", "sort_imports_verify"], ["python", "setup.py", "lint"], + ["python", "setup.py", "lint_types"], ], ), "test": create_command( From 446f6883b465ee3e662d2fbba581d1f4658cc254 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Thu, 20 Apr 2023 16:41:24 +0300 
Subject: [PATCH 32/43] Add `py.typed` marker to indicate package as typed

---
 psqlextra/_version.py | 2 +-
 psqlextra/py.typed    | 0
 setup.py              | 1 +
 3 files changed, 2 insertions(+), 1 deletion(-)
 create mode 100644 psqlextra/py.typed

diff --git a/psqlextra/_version.py b/psqlextra/_version.py
index 5b7739fb..e8733fa0 100644
--- a/psqlextra/_version.py
+++ b/psqlextra/_version.py
@@ -1 +1 @@
-__version__ = "2.0.9rc3+swen.4"
+__version__ = "2.0.9rc4"
diff --git a/psqlextra/py.typed b/psqlextra/py.typed
new file mode 100644
index 00000000..e69de29b
diff --git a/setup.py b/setup.py
index 311acf11..3f1429ad 100644
--- a/setup.py
+++ b/setup.py
@@ -40,6 +40,7 @@ def run(self):
     name="django-postgres-extra",
     version=__version__,
     packages=find_packages(exclude=["tests"]),
+    package_data={"psqlextra": ["py.typed"]},
     include_package_data=True,
     license="MIT License",
     description="Bringing all of PostgreSQL's awesomeness to Django.",

From c451255c5dbebdbeb0cff86cd2b103b29235d404 Mon Sep 17 00:00:00 2001
From: Swen Kooij
Date: Thu, 20 Apr 2023 16:45:11 +0300
Subject: [PATCH 33/43] Do not double run the tests to publish

---
 .circleci/config.yml | 8 --------
 1 file changed, 8 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index f5ee6a31..64e60c30 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -195,14 +195,6 @@ workflows:
           branches:
             only: /.*/
       - publish:
-          requires:
-            - test-python36
-            - test-python37
-            - test-python38
-            - test-python39
-            - test-python310
-            - test-python311
-            - analysis
           filters:
             tags:
               only: /^v.*/

From eb4b7ba1d85a7ce43fd4b369af504938bf1852d0 Mon Sep 17 00:00:00 2001
From: Swen Kooij
Date: Wed, 10 May 2023 16:05:15 +0300
Subject: [PATCH 34/43] Allow specifying a specific constraint to use in `ON CONFLICT`

---
 docs/source/conflict_handling.rst | 35 +++++++++++++++++++++++++++++++
 psqlextra/compiler.py             | 19 +++++++++++++++--
 psqlextra/query.py                |  6 +++++-
 tests/test_on_conflict_update.py  | 30 ++++++++++++++++++++++++++
 4 files changed, 87 insertions(+), 3 deletions(-)

diff --git a/docs/source/conflict_handling.rst b/docs/source/conflict_handling.rst
index 9edc71b1..89d1a0cc 100644
--- a/docs/source/conflict_handling.rst
+++ b/docs/source/conflict_handling.rst
@@ -87,6 +87,41 @@ Specifying multiple columns is necessary in case of a constraint that spans mult
     )
 
 
+Specific constraint
+*******************
+
+Alternatively, instead of specifying the columns that the constraint you're targeting applies to, you can also specify the exact constraint to use:
+
+.. code-block:: python
+
+    from django.db import models
+
+    from psqlextra.models import PostgresModel
+    from psqlextra.types import ConflictAction
+
+    class MyModel(PostgresModel):
+        class Meta:
+            constraints = [
+                models.UniqueConstraint(
+                    name="myconstraint",
+                    fields=["first_name", "last_name"]
+                ),
+            ]
+
+        first_name = models.CharField(max_length=255)
+        last_name = models.CharField(max_length=255)
+
+    constraint = next(
+        (
+            constraint
+            for constraint in MyModel._meta.constraints
+            if constraint.name == "myconstraint"
+        ),
+        None,
+    )
+
+    obj = (
+        MyModel.objects
+        .on_conflict(constraint, ConflictAction.UPDATE)
+        .insert_and_get(first_name='Henk', last_name='Jansen')
+    )
+
+
 HStore keys
 ***********
 
 Catching conflicts in columns with a ``UNIQUE`` constraint on a :class:`~psqlextra.fields.HStoreField` key is also supported:
diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py
index 12fff3fa..88a65e9a 100644
--- a/psqlextra/compiler.py
+++ b/psqlextra/compiler.py
@@ -243,11 +243,11 @@ def _rewrite_insert_on_conflict(
 
         # build the conflict target, the columns to watch
         # for conflicts
-        conflict_target = self._build_conflict_target()
+        on_conflict_clause = self._build_on_conflict_clause()
 
         index_predicate = self.query.index_predicate  # type: ignore[attr-defined]
         update_condition = self.query.conflict_update_condition  # type: ignore[attr-defined]
 
-        rewritten_sql = f"{sql} ON CONFLICT {conflict_target}"
+        rewritten_sql = f"{sql} {on_conflict_clause}"
 
         if index_predicate:
             expr_sql, expr_params = self._compile_expression(index_predicate)
@@ -270,6 +270,21 @@ def _rewrite_insert_on_conflict(
 
         return (rewritten_sql, params)
 
+    def _build_on_conflict_clause(self):
+        if django.VERSION >= (2, 2):
+            from django.db.models.constraints import BaseConstraint
+            from django.db.models.indexes import Index
+
+            if isinstance(
+                self.query.conflict_target, BaseConstraint
+            ) or isinstance(self.query.conflict_target, Index):
+                return "ON CONFLICT ON CONSTRAINT %s" % self.qn(
+                    self.query.conflict_target.name
+                )
+
+        conflict_target = self._build_conflict_target()
+        return f"ON CONFLICT {conflict_target}"
+
     def _build_conflict_target(self):
         """Builds the `conflict_target` for the ON CONFLICT clause."""
 
diff --git a/psqlextra/query.py b/psqlextra/query.py
index 5c5e6f47..b3feec1d 100644
--- a/psqlextra/query.py
+++ b/psqlextra/query.py
@@ -20,7 +20,11 @@
 from .sql import PostgresInsertQuery, PostgresQuery
 from .types import ConflictAction
 
-ConflictTarget = List[Union[str, Tuple[str]]]
+if TYPE_CHECKING:
+    from django.db.models.constraints import BaseConstraint
+    from django.db.models.indexes import Index
+
+ConflictTarget = Union[List[Union[str, Tuple[str]]], "BaseConstraint", "Index"]
 
 TModel = TypeVar("TModel", bound=models.Model, covariant=True)
 
diff --git a/tests/test_on_conflict_update.py b/tests/test_on_conflict_update.py
index 8425e3d3..b93e5781 100644
--- a/tests/test_on_conflict_update.py
+++ b/tests/test_on_conflict_update.py
@@ -1,3 +1,4 @@
+import django
 import pytest
 
 from django.db import models
@@ -41,6 +42,35 @@ def test_on_conflict_update():
     assert obj2.cookies == "choco"
 
 
+@pytest.mark.skipif(
+    django.VERSION < (2, 2),
+    reason="Django < 2.2 doesn't implement constraints",
+)
+def test_on_conflict_update_by_unique_constraint():
+    model = get_fake_model(
+        {
+            "title": models.CharField(max_length=255, null=True),
+        },
+        meta_options={
+            "constraints": [
+                models.UniqueConstraint(name="test_uniq", fields=["title"]),
+            ],
+        },
+    )
+
+    constraint = next(
+        (
+            constraint
+            for constraint in model._meta.constraints
+            if constraint.name == "test_uniq"
+        )
+    )
+
+    model.objects.on_conflict(constraint, ConflictAction.UPDATE).insert_and_get(
+        title="title"
+    )
+
+
 def test_on_conflict_update_foreign_key_by_object():
     """Tests whether simple upsert works correctly when the conflicting
     field is a foreign key specified as an object."""

From cbf93a314d7a6757eceb25b4c5c186a461b54ee6 Mon Sep 17 00:00:00 2001
From: loicgasser
Date: Wed, 31 May 2023 23:18:27 -0400
Subject: [PATCH 35/43] chore(): Add missing python versions in setup.py

---
 setup.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/setup.py b/setup.py
index 3f1429ad..dd2ef9b1 100644
--- a/setup.py
+++ b/setup.py
@@ -61,6 +61,8 @@ def run(self):
         "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
         "Topic :: Internet :: WWW/HTTP",
         "Topic :: Internet :: WWW/HTTP :: Dynamic Content",
     ],

From 5fcb63f72bcd30ec3fc070b6b983511a4fa6336e Mon Sep 17 00:00:00 2001
From: Swen Kooij
Date: Wed, 17 Aug 2022 08:56:23 +0300
Subject: [PATCH 36/43] Add ability to set update values, fixes #56

---
 docs/source/conflict_handling.rst | 36 +++++++++++++++
 psqlextra/compiler.py             | 42 ++++++++++++-----
 psqlextra/query.py                | 50 +++++++++++++++++---
 psqlextra/sql.py                  | 21 +++++----
 psqlextra/types.py                |  3 ++
 tests/test_upsert.py              | 77 ++++++++++++++++++++++++++++++-
 6 files changed, 201 insertions(+), 28 deletions(-)

diff --git a/docs/source/conflict_handling.rst b/docs/source/conflict_handling.rst
index 89d1a0cc..cb9423a9 100644
--- a/docs/source/conflict_handling.rst
+++ b/docs/source/conflict_handling.rst
@@ -232,6 +232,42 @@ Alternatively, with Django 3.1 or newer, :class:`~django:django.db.models.Q` obj
         Q(name__gt=ExcludedCol('priority'))
 
 
+Update values
+"""""""""""""
+
+Optionally, the fields to update can be overridden. The default is to update the same fields that were specified in the rows to insert.
+
+Refer to the insert values using the :class:`psqlextra.expressions.ExcludedCol` expression, which translates to PostgreSQL's ``EXCLUDED.`` expression. All expressions and features that can be used with Django's :meth:`~django:django.db.models.query.QuerySet.update` can be used here.
+
+.. warning::
+
+    Specifying an empty ``update_values`` (``{}``) will transform the query into :attr:`~psqlextra.types.ConflictAction.NOTHING`. Only ``None`` triggers the default behaviour of updating all the fields that were specified.
+
+.. code-block:: python
+
+    from django.db.models import F
+
+    from psqlextra.expressions import ExcludedCol
+
+    (
+        MyModel
+        .objects
+        .on_conflict(
+            ['name'],
+            ConflictAction.UPDATE,
+            update_values=dict(
+                name=ExcludedCol('name'),
+                count=F('count') + 1,
+            ),
+        )
+        .insert(
+            name='henk',
+            count=0,
+        )
+    )
+
+
 ConflictAction.NOTHING
 **********************
 
diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py
index 88a65e9a..f9760f45 100644
--- a/psqlextra/compiler.py
+++ b/psqlextra/compiler.py
@@ -104,8 +104,7 @@ def as_sql(self, *args, **kwargs):
         return append_caller_to_sql(sql), params
 
     def _prepare_query_values(self):
-        """Extra prep on query values by converting dictionaries into.
-
+        """Extra prep on query values by converting dictionaries into :see:HStoreValue expressions.
 
         This allows putting expressions in a dictionary.
The @@ -234,13 +233,6 @@ def _rewrite_insert_on_conflict( """Rewrites a normal SQL INSERT query to add the 'ON CONFLICT' clause.""" - update_columns = ", ".join( - [ - "{0} = EXCLUDED.{0}".format(self.qn(field.column)) - for field in self.query.update_fields # type: ignore[attr-defined] - ] - ) - # build the conflict target, the columns to watch # for conflicts on_conflict_clause = self._build_on_conflict_clause() @@ -254,10 +246,21 @@ def _rewrite_insert_on_conflict( rewritten_sql += f" WHERE {expr_sql}" params += tuple(expr_params) + # Fallback in case the user didn't specify any update values. We can still + # make the query work if we switch to ConflictAction.NOTHING + if ( + conflict_action == ConflictAction.UPDATE.value + and not self.query.update_values + ): + conflict_action = ConflictAction.NOTHING + rewritten_sql += f" DO {conflict_action}" - if conflict_action == "UPDATE": - rewritten_sql += f" SET {update_columns}" + if conflict_action == ConflictAction.UPDATE.value: + set_sql, sql_params = self._build_set_statement() + + rewritten_sql += f" SET {set_sql}" + params += sql_params if update_condition: expr_sql, expr_params = self._compile_expression( @@ -270,6 +273,23 @@ def _rewrite_insert_on_conflict( return (rewritten_sql, params) + def _build_set_statement(self) -> Tuple[str, tuple]: + """Builds the SET statement for the ON CONFLICT DO UPDATE clause. + + This uses the update compiler to provide full compatibility with + the standard Django's `update(...)`. + """ + + # Local import to work around the circular dependency between + # the compiler and the queries. + from .sql import PostgresUpdateQuery + + query = self.query.chain(PostgresUpdateQuery) + query.add_update_values(self.query.update_values) + + sql, params = query.get_compiler(self.connection.alias).as_sql() + return sql.split("SET")[1].split(" WHERE")[0], tuple(params) + def _build_on_conflict_clause(self): if django.VERSION >= (2, 2): from django.db.models.constraints import BaseConstraint diff --git a/psqlextra/query.py b/psqlextra/query.py index b3feec1d..c68b8ee9 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -2,6 +2,7 @@ from itertools import chain from typing import ( TYPE_CHECKING, + Any, Dict, Generic, Iterable, @@ -17,6 +18,7 @@ from django.db.models import Expression, Q, QuerySet from django.db.models.fields import NOT_PROVIDED +from .expressions import ExcludedCol from .sql import PostgresInsertQuery, PostgresQuery from .types import ConflictAction @@ -51,6 +53,7 @@ def __init__(self, model=None, query=None, using=None, hints=None): self.conflict_action = None self.conflict_update_condition = None self.index_predicate = None + self.update_values = None def annotate(self, **annotations) -> "Self": # type: ignore[valid-type, override] """Custom version of the standard annotate function that allows using @@ -108,6 +111,7 @@ def on_conflict( action: ConflictAction, index_predicate: Optional[Union[Expression, Q, str]] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Sets the action to take when conflicts arise when attempting to insert/create a new row. @@ -125,12 +129,18 @@ def on_conflict( update_condition: Only update if this SQL expression evaluates to true. + + update_values: + Optionally, values/expressions to use when rows + conflict. If not specified, all columns specified + in the rows are updated with the values you specified. 
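+
+        Example (an illustrative sketch only; it assumes a model with a
+        unique ``name`` column and an integer ``count`` column, and that
+        ``ExcludedCol`` and ``F`` have been imported):
+
+            qs.on_conflict(
+                ["name"],
+                ConflictAction.UPDATE,
+                update_values=dict(
+                    name=ExcludedCol("name"),
+                    count=F("count") + 1,
+                ),
+            )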
""" self.conflict_target = fields self.conflict_action = action self.conflict_update_condition = update_condition self.index_predicate = index_predicate + self.update_values = update_values return self @@ -293,6 +303,7 @@ def upsert( index_predicate: Optional[Union[Expression, Q, str]] = None, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ) -> int: """Creates a new record or updates the existing one with the specified data. @@ -315,6 +326,11 @@ def upsert( update_condition: Only update if this SQL expression evaluates to true. + update_values: + Optionally, values/expressions to use when rows + conflict. If not specified, all columns specified + in the rows are updated with the values you specified. + Returns: The primary key of the row that was created/updated. """ @@ -324,6 +340,7 @@ def upsert( ConflictAction.UPDATE, index_predicate=index_predicate, update_condition=update_condition, + update_values=update_values, ) kwargs = {**fields, "using": using} @@ -336,6 +353,7 @@ def upsert_and_get( index_predicate: Optional[Union[Expression, Q, str]] = None, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Creates a new record or updates the existing one with the specified data and then gets the row. @@ -358,6 +376,11 @@ def upsert_and_get( update_condition: Only update if this SQL expression evaluates to true. + update_values: + Optionally, values/expressions to use when rows + conflict. If not specified, all columns specified + in the rows are updated with the values you specified. + Returns: The model instance representing the row that was created/updated. @@ -368,6 +391,7 @@ def upsert_and_get( ConflictAction.UPDATE, index_predicate=index_predicate, update_condition=update_condition, + update_values=update_values, ) kwargs = {**fields, "using": using} @@ -381,6 +405,7 @@ def bulk_upsert( return_model: bool = False, using: Optional[str] = None, update_condition: Optional[Union[Expression, Q, str]] = None, + update_values: Optional[Dict[str, Union[Any, Expression]]] = None, ): """Creates a set of new records or updates the existing ones with the specified data. @@ -407,6 +432,11 @@ def bulk_upsert( update_condition: Only update if this SQL expression evaluates to true. + update_values: + Optionally, values/expressions to use when rows + conflict. If not specified, all columns specified + in the rows are updated with the values you specified. 
+ Returns: A list of either the dicts of the rows upserted, including the pk or the models of the rows upserted @@ -417,7 +447,9 @@ def bulk_upsert( ConflictAction.UPDATE, index_predicate=index_predicate, update_condition=update_condition, + update_values=update_values, ) + return self.bulk_insert(rows, return_model, using=using) def _create_model_instance( @@ -505,7 +537,11 @@ def _build_insert_compiler( ) # get the fields to be used during update/insert - insert_fields, update_fields = self._get_upsert_fields(first_row) + insert_fields, update_values = self._get_upsert_fields(first_row) + + # allow the user to override what should happen on update + if self.update_values is not None: + update_values = self.update_values # build a normal insert query query = PostgresInsertQuery(self.model) @@ -513,7 +549,7 @@ def _build_insert_compiler( query.conflict_target = self.conflict_target query.conflict_update_condition = self.conflict_update_condition query.index_predicate = self.index_predicate - query.values(objs, insert_fields, update_fields) + query.insert_on_conflict_values(objs, insert_fields, update_values) compiler = query.get_compiler(using) return compiler @@ -578,13 +614,13 @@ def _get_upsert_fields(self, kwargs): model_instance = self.model(**kwargs) insert_fields = [] - update_fields = [] + update_values = {} for field in model_instance._meta.local_concrete_fields: has_default = field.default != NOT_PROVIDED if field.name in kwargs or field.column in kwargs: insert_fields.append(field) - update_fields.append(field) + update_values[field.name] = ExcludedCol(field.column) continue elif has_default: insert_fields.append(field) @@ -595,13 +631,13 @@ def _get_upsert_fields(self, kwargs): # instead of a concrete field, we have to handle that if field.primary_key is True and "pk" in kwargs: insert_fields.append(field) - update_fields.append(field) + update_values[field.name] = ExcludedCol(field.column) continue if self._is_magical_field(model_instance, field, is_insert=True): insert_fields.append(field) if self._is_magical_field(model_instance, field, is_insert=False): - update_fields.append(field) + update_values[field.name] = ExcludedCol(field.column) - return insert_fields, update_fields + return insert_fields, update_values diff --git a/psqlextra/sql.py b/psqlextra/sql.py index 2a5b418e..3ceb5966 100644 --- a/psqlextra/sql.py +++ b/psqlextra/sql.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import django @@ -154,10 +154,14 @@ def __init__(self, *args, **kwargs): self.conflict_action = ConflictAction.UPDATE self.conflict_update_condition = None self.index_predicate = None - - self.update_fields = [] - - def values(self, objs, insert_fields, update_fields=[]): + self.update_values = {} + + def insert_on_conflict_values( + self, + objs: List, + insert_fields: List, + update_values: Dict[str, Union[Any, Expression]] = {}, + ): """Sets the values to be used in this query. Insert fields are fields that are definitely @@ -176,12 +180,13 @@ def values(self, objs, insert_fields, update_fields=[]): insert_fields: The fields to use in the INSERT statement - update_fields: - The fields to only use in the UPDATE statement. + update_values: + Expressions/values to use when a conflict + occurs and an UPDATE is performed. 
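+
+        Example (an illustrative sketch only; the field list and the
+        ``ExcludedCol`` expression here are hypothetical):
+
+            query.insert_on_conflict_values(
+                objs,
+                insert_fields,
+                update_values={"name": ExcludedCol("name")},
+            )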
""" self.insert_values(insert_fields, objs, raw=False) - self.update_fields = update_fields + self.update_values = update_values def get_compiler(self, using=None, connection=None): if using: diff --git a/psqlextra/types.py b/psqlextra/types.py index a325fd9e..f1118075 100644 --- a/psqlextra/types.py +++ b/psqlextra/types.py @@ -28,6 +28,9 @@ class ConflictAction(Enum): def all(cls) -> List["ConflictAction"]: return [choice for choice in cls] + def __str__(self) -> str: + return self.value + class PostgresPartitioningMethod(StrEnum): """Methods of partitioning supported by PostgreSQL 11.x native support for diff --git a/tests/test_upsert.py b/tests/test_upsert.py index b9176da1..a53561ce 100644 --- a/tests/test_upsert.py +++ b/tests/test_upsert.py @@ -2,7 +2,7 @@ import pytest from django.db import models -from django.db.models import Q +from django.db.models import F, Q from django.db.models.expressions import CombinedExpression, Value from psqlextra.expressions import ExcludedCol @@ -144,6 +144,54 @@ def test_upsert_with_update_condition(): assert obj1.active +def test_upsert_with_update_values(): + """Tests that the default update values can be overriden with custom + expressions.""" + + model = get_fake_model( + { + "name": models.TextField(unique=True), + "count": models.IntegerField(default=0), + } + ) + + obj1 = model.objects.create(name="joe") + + model.objects.upsert( + conflict_target=["name"], + fields=dict(name="joe"), + update_values=dict( + count=F("count") + 1, + ), + ) + + obj1.refresh_from_db() + assert obj1.count == 1 + + +def test_upsert_with_update_values_empty(): + """Tests that an upsert with an empty dict turns into ON CONFLICT DO + NOTHING.""" + + model = get_fake_model( + { + "name": models.TextField(unique=True), + "count": models.IntegerField(default=0), + } + ) + + obj1 = model.objects.create(name="joe") + + model.objects.upsert( + conflict_target=["name"], + fields=dict(name="joe"), + update_values={}, + ) + + obj1.refresh_from_db() + assert obj1.count == 0 + + @pytest.mark.skipif( django.VERSION < (3, 1), reason="requires django 3.1 or newer" ) @@ -200,7 +248,7 @@ def from_db_value(self, value, expression, connection): assert obj.title == "bye" -def test_upsert_bulk(): +def test_bulk_upsert(): """Tests whether bulk_upsert works properly.""" model = get_fake_model( @@ -337,3 +385,28 @@ def __iter__(self): for index, obj in enumerate(objs, 1): assert isinstance(obj, model) assert obj.id == index + + +def test_bulk_upsert_update_values(): + model = get_fake_model( + { + "name": models.CharField(max_length=255, unique=True), + "count": models.IntegerField(default=0), + } + ) + + model.objects.bulk_create( + [ + model(name="joe"), + model(name="john"), + ] + ) + + objs = model.objects.bulk_upsert( + conflict_target=["name"], + rows=[], + return_model=True, + update_values=dict(count=F("count") + 1), + ) + + assert all([obj for obj in objs if obj.count == 1]) From 92ae69039ac0addce740a0a4092761094d33c5e1 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Mon, 7 Aug 2023 15:57:10 +0200 Subject: [PATCH 37/43] Tolerate new columns being added to tables during upsert Up until now, upserting to a table which had a column added not know to Django would make the query crash. This commit introduces a more robust mechanism to constructing model instances from query results that tolerates a column being added at the end of the table. 
--- psqlextra/compiler.py | 39 +-- psqlextra/introspect/__init__.py | 8 + psqlextra/introspect/fields.py | 21 ++ psqlextra/introspect/models.py | 170 +++++++++++++ psqlextra/query.py | 127 ++++------ setup.py | 2 + tests/test_introspect.py | 417 +++++++++++++++++++++++++++++++ tests/test_upsert.py | 67 ++++- 8 files changed, 744 insertions(+), 107 deletions(-) create mode 100644 psqlextra/introspect/__init__.py create mode 100644 psqlextra/introspect/fields.py create mode 100644 psqlextra/introspect/models.py create mode 100644 tests/test_introspect.py diff --git a/psqlextra/compiler.py b/psqlextra/compiler.py index f9760f45..36aad204 100644 --- a/psqlextra/compiler.py +++ b/psqlextra/compiler.py @@ -3,7 +3,7 @@ import sys from collections.abc import Iterable -from typing import Tuple, Union +from typing import TYPE_CHECKING, Tuple, Union, cast import django @@ -12,11 +12,13 @@ from django.db.models import Expression, Model, Q from django.db.models.fields.related import RelatedField from django.db.models.sql import compiler as django_compiler -from django.db.utils import ProgrammingError from .expressions import HStoreValue from .types import ConflictAction +if TYPE_CHECKING: + from .sql import PostgresInsertQuery + def append_caller_to_sql(sql): """Append the caller to SQL queries. @@ -161,6 +163,8 @@ def as_sql(self, *args, **kwargs): class PostgresInsertOnConflictCompiler(django_compiler.SQLInsertCompiler): # type: ignore [name-defined] """Compiler for SQL INSERT statements.""" + query: "PostgresInsertQuery" + def __init__(self, *args, **kwargs): """Initializes a new instance of :see:PostgresInsertOnConflictCompiler.""" @@ -169,6 +173,7 @@ def __init__(self, *args, **kwargs): def as_sql(self, return_id=False, *args, **kwargs): """Builds the SQL INSERT statement.""" + queries = [ self._rewrite_insert(sql, params, return_id) for sql, params in super().as_sql(*args, **kwargs) @@ -176,28 +181,6 @@ def as_sql(self, return_id=False, *args, **kwargs): return queries - def execute_sql(self, return_id=False): - # execute all the generate queries - with self.connection.cursor() as cursor: - rows = [] - for sql, params in self.as_sql(return_id): - cursor.execute(sql, params) - try: - rows.extend(cursor.fetchall()) - except ProgrammingError: - pass - description = cursor.description - - # create a mapping between column names and column value - return [ - { - column.name: row[column_index] - for column_index, column in enumerate(description) - if row - } - for row in rows - ] - def _rewrite_insert(self, sql, params, return_id=False): """Rewrites a formed SQL INSERT query to include the ON CONFLICT clause. @@ -209,9 +192,9 @@ def _rewrite_insert(self, sql, params, return_id=False): params: The parameters passed to the query. - returning: - What to put in the `RETURNING` clause - of the resulting query. + return_id: + Whether to only return the ID or all + columns. Returns: A tuple of the rewritten SQL query and new params. @@ -284,7 +267,7 @@ def _build_set_statement(self) -> Tuple[str, tuple]: # the compiler and the queries. 
from .sql import PostgresUpdateQuery
 
-        query = self.query.chain(PostgresUpdateQuery)
+        query = cast(
+            PostgresUpdateQuery, self.query.chain(PostgresUpdateQuery)
+        )
         query.add_update_values(self.query.update_values)
 
         sql, params = query.get_compiler(self.connection.alias).as_sql()
diff --git a/psqlextra/introspect/__init__.py b/psqlextra/introspect/__init__.py
new file mode 100644
index 00000000..bd85935f
--- /dev/null
+++ b/psqlextra/introspect/__init__.py
@@ -0,0 +1,8 @@
+from .fields import inspect_model_local_concrete_fields
+from .models import model_from_cursor, models_from_cursor
+
+__all__ = [
+    "models_from_cursor",
+    "model_from_cursor",
+    "inspect_model_local_concrete_fields",
+]
diff --git a/psqlextra/introspect/fields.py b/psqlextra/introspect/fields.py
new file mode 100644
index 00000000..27ef28f7
--- /dev/null
+++ b/psqlextra/introspect/fields.py
@@ -0,0 +1,21 @@
+from typing import List, Type
+
+from django.db.models import Field, Model
+
+
+def inspect_model_local_concrete_fields(model: Type[Model]) -> List[Field]:
+    """Gets a complete list of local and concrete fields on a model. These
+    are fields that map directly to a database column on the table backing
+    the model.
+
+    This is similar to Django's `Meta.local_concrete_fields`, which is a
+    private API. This method utilizes only public APIs.
+    """
+
+    local_concrete_fields = []
+
+    for field in model._meta.get_fields(include_parents=False):
+        if isinstance(field, Field) and field.column and not field.many_to_many:
+            local_concrete_fields.append(field)
+
+    return local_concrete_fields
diff --git a/psqlextra/introspect/models.py b/psqlextra/introspect/models.py
new file mode 100644
index 00000000..cad84e9f
--- /dev/null
+++ b/psqlextra/introspect/models.py
@@ -0,0 +1,170 @@
+from typing import (
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    cast,
+)
+
+from django.core.exceptions import FieldDoesNotExist
+from django.db import connection, models
+from django.db.models import Field, Model
+from django.db.models.expressions import Expression
+
+from .fields import inspect_model_local_concrete_fields
+
+TModel = TypeVar("TModel", bound=models.Model)
+
+
+def _construct_model(
+    model: Type[TModel],
+    columns: Iterable[str],
+    values: Iterable[Any],
+    *,
+    apply_converters: bool = True
+) -> TModel:
+    fields_by_name_and_column = {}
+    for field in inspect_model_local_concrete_fields(model):
+        fields_by_name_and_column[field.attname] = field
+
+        if field.db_column:
+            fields_by_name_and_column[field.db_column] = field
+
+    indexable_columns = list(columns)
+
+    row = {}
+
+    for index, value in enumerate(values):
+        column = indexable_columns[index]
+        try:
+            field = cast(Field, model._meta.get_field(column))
+        except FieldDoesNotExist:
+            field = fields_by_name_and_column[column]
+
+        field_column_expression = field.get_col(model._meta.db_table)
+
+        if apply_converters:
+            converters = cast(Expression, field).get_db_converters(
+                connection
+            ) + connection.ops.get_db_converters(field_column_expression)
+
+            converted_value = value
+            for converter in converters:
+                converted_value = converter(
+                    converted_value,
+                    field_column_expression,
+                    connection,
+                )
+        else:
+            converted_value = value
+
+        row[field.attname] = converted_value
+
+    instance = model(**row)
+    instance._state.adding = False
+    instance._state.db = connection.alias
+
+    return instance
+
+
+def models_from_cursor(
+    model: Type[TModel], cursor, *, related_fields: List[str] = []
+) -> Generator[TModel, None, None]:
+    """Fetches all rows from a cursor and converts the values into model
+    instances.
+
+    This is roughly what Django does internally when you do queries. This
+    goes further than `Model.from_db` as it also applies converters to make
+    sure that values are converted into their Python equivalent.
+
+    Use this when you've outgrown the ORM and are writing performant
+    queries yourself, and you need to map the results back into ORM
+    objects.
+
+    Arguments:
+        model:
+            Model to construct.
+
+        cursor:
+            Cursor to read the rows from.
+
+        related_fields:
+            List of ForeignKey/OneToOneField names that were joined
+            into the raw query. Use this to achieve the same thing
+            that Django's `.select_related()` does.
+
+            Field names should be specified in the order that they
+            are SELECT'd in.
+    """
+
+    columns = [col[0] for col in cursor.description]
+    field_offset = len(inspect_model_local_concrete_fields(model))
+
+    rows = cursor.fetchmany()
+
+    while rows:
+        for values in rows:
+            instance = _construct_model(
+                model, columns[:field_offset], values[:field_offset]
+            )
+
+            for index, related_field_name in enumerate(related_fields):
+                related_model = model._meta.get_field(
+                    related_field_name
+                ).related_model
+                if not related_model:
+                    continue
+
+                related_field_count = len(
+                    inspect_model_local_concrete_fields(related_model)
+                )
+
+                # autopep8: off
+                related_columns = columns[
+                    field_offset : field_offset + related_field_count  # noqa
+                ]
+                related_values = values[
+                    field_offset : field_offset + related_field_count  # noqa
+                ]
+                # autopep8: on
+
+                if (
+                    not related_columns
+                    or not related_values
+                    or all([value is None for value in related_values])
+                ):
+                    continue
+
+                related_instance = _construct_model(
+                    cast(Type[Model], related_model),
+                    related_columns,
+                    related_values,
+                )
+                instance._state.fields_cache[related_field_name] = related_instance  # type: ignore
+
+                field_offset += len(
+                    inspect_model_local_concrete_fields(related_model)
+                )
+
+            yield instance
+
+        rows = cursor.fetchmany()
+
+
+def model_from_cursor(
+    model: Type[TModel], cursor, *, related_fields: List[str] = []
+) -> Optional[TModel]:
+    return next(
+        models_from_cursor(model, cursor, related_fields=related_fields), None
+    )
+
+
+def model_from_dict(
+    model: Type[TModel], row: Dict[str, Any], *, apply_converters: bool = True
+) -> TModel:
+    return _construct_model(
+        model, row.keys(), row.values(), apply_converters=apply_converters
+    )
diff --git a/psqlextra/query.py b/psqlextra/query.py
index c68b8ee9..a014d9a8 100644
--- a/psqlextra/query.py
+++ b/psqlextra/query.py
@@ -14,11 +14,13 @@
 )
 
 from django.core.exceptions import SuspiciousOperation
-from django.db import connections, models, router
+from django.db import models, router
+from django.db.backends.utils import CursorWrapper
 from django.db.models import Expression, Q, QuerySet
 from django.db.models.fields import NOT_PROVIDED
 
 from .expressions import ExcludedCol
+from .introspect import model_from_cursor, models_from_cursor
 from .sql import PostgresInsertQuery, PostgresQuery
 from .types import ConflictAction
 
@@ -146,7 +148,7 @@ def on_conflict(
 
     def bulk_insert(
         self,
-        rows: Iterable[dict],
+        rows: Iterable[Dict[str, Any]],
         return_model: bool = False,
         using: Optional[str] = None,
     ):
@@ -205,14 +207,17 @@ def is_empty(r):
             deduped_rows.append(row)
 
         compiler = self._build_insert_compiler(deduped_rows, using=using)
-        objs = compiler.execute_sql(return_id=not return_model)
-        if return_model:
-            return [
-                self._create_model_instance(dict(row, **obj), compiler.using)
-                for row, obj in
zip(deduped_rows, objs) - ] - return [dict(row, **obj) for row, obj in zip(deduped_rows, objs)] + with compiler.connection.cursor() as cursor: + for sql, params in compiler.as_sql(return_id=not return_model): + cursor.execute(sql, params) + + if return_model: + return list(models_from_cursor(self.model, cursor)) + + return self._consume_cursor_as_dicts( + cursor, original_rows=deduped_rows + ) def insert(self, using: Optional[str] = None, **fields): """Creates a new record in the database. @@ -233,17 +238,20 @@ def insert(self, using: Optional[str] = None, **fields): """ if self.conflict_target or self.conflict_action: - compiler = self._build_insert_compiler([fields], using=using) - rows = compiler.execute_sql(return_id=True) - if not self.model or not self.model.pk: return None - _, pk_db_column = self.model._meta.pk.get_attname_column() # type: ignore[union-attr] - if not rows or len(rows) == 0: - return None + compiler = self._build_insert_compiler([fields], using=using) + + with compiler.connection.cursor() as cursor: + for sql, params in compiler.as_sql(return_id=True): + cursor.execute(sql, params) + + row = cursor.fetchone() + if not row: + return None - return rows[0][pk_db_column] + return row[0] # no special action required, use the standard Django create(..) return super().create(**fields).pk @@ -271,30 +279,12 @@ def insert_and_get(self, using: Optional[str] = None, **fields): return super().create(**fields) compiler = self._build_insert_compiler([fields], using=using) - rows = compiler.execute_sql(return_id=False) - - if not rows: - return None - - columns = rows[0] - - # get a list of columns that are officially part of the model and - # preserve the fact that the attribute name - # might be different than the database column name - model_columns = {} - for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] - model_columns[field.column] = field.attname - # strip out any columns/fields returned by the db that - # are not present in the model - model_init_fields = {} - for column_name, column_value in columns.items(): - try: - model_init_fields[model_columns[column_name]] = column_value - except KeyError: - pass + with compiler.connection.cursor() as cursor: + for sql, params in compiler.as_sql(return_id=False): + cursor.execute(sql, params) - return self._create_model_instance(model_init_fields, compiler.using) + return model_from_cursor(self.model, cursor) def upsert( self, @@ -452,43 +442,23 @@ def bulk_upsert( return self.bulk_insert(rows, return_model, using=using) - def _create_model_instance( - self, field_values: dict, using: str, apply_converters: bool = True - ): - """Creates a new instance of the model with the specified field. - - Use this after the row was inserted into the database. The new - instance will marked as "saved". 
- """ - - converted_field_values = field_values.copy() - - if apply_converters: - connection = connections[using] - - for field in self.model._meta.local_concrete_fields: # type: ignore[attr-defined] - if field.attname not in converted_field_values: - continue - - # converters can be defined on the field, or by - # the database back-end we're using - field_column = field.get_col(self.model._meta.db_table) - converters = field.get_db_converters( - connection - ) + connection.ops.get_db_converters(field_column) - - for converter in converters: - converted_field_values[field.attname] = converter( - converted_field_values[field.attname], - field_column, - connection, - ) - - instance = self.model(**converted_field_values) - instance._state.db = using - instance._state.adding = False - - return instance + @staticmethod + def _consume_cursor_as_dicts( + cursor: CursorWrapper, *, original_rows: Iterable[Dict[str, Any]] + ) -> List[dict]: + cursor_description = cursor.description + + return [ + { + **original_row, + **{ + column.name: row[column_index] + for column_index, column in enumerate(cursor_description) + if row + }, + } + for original_row, row in zip(original_rows, cursor) + ] def _build_insert_compiler( self, rows: Iterable[Dict], using: Optional[str] = None @@ -532,9 +502,10 @@ def _build_insert_compiler( ).format(index) ) - objs.append( - self._create_model_instance(row, using, apply_converters=False) - ) + obj = self.model(**row.copy()) + obj._state.db = using + obj._state.adding = False + objs.append(obj) # get the fields to be used during update/insert insert_fields, update_values = self._get_upsert_fields(first_row) diff --git a/setup.py b/setup.py index dd2ef9b1..b3217fb3 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,8 @@ def run(self): "pytest-benchmark==3.4.1", "pytest-django==4.4.0", "pytest-cov==3.0.0", + "pytest-lazy-fixture==0.6.3", + "pytest-freezegun==0.4.2", "tox==3.24.4", "freezegun==1.1.0", "coveralls==3.3.0", diff --git a/tests/test_introspect.py b/tests/test_introspect.py new file mode 100644 index 00000000..21e842ac --- /dev/null +++ b/tests/test_introspect.py @@ -0,0 +1,417 @@ +import django +import pytest + +from django.contrib.postgres.fields import ArrayField +from django.db import connection, models +from django.test.utils import CaptureQueriesContext +from django.utils import timezone + +from psqlextra.introspect import model_from_cursor, models_from_cursor + +from .fake_model import get_fake_model + +django_31_skip_reason = "Django < 3.1 does not support JSON fields which are required for these tests" + + +@pytest.fixture +def mocked_model_varying_fields(): + return get_fake_model( + { + "title": models.TextField(null=True), + "updated_at": models.DateTimeField(null=True), + "content": models.JSONField(null=True), + "items": ArrayField(models.TextField(), null=True), + } + ) + + +@pytest.fixture +def mocked_model_single_field(): + return get_fake_model( + { + "name": models.TextField(), + } + ) + + +@pytest.fixture +def mocked_model_foreign_keys( + mocked_model_varying_fields, mocked_model_single_field +): + return get_fake_model( + { + "varying_fields": models.ForeignKey( + mocked_model_varying_fields, null=True, on_delete=models.CASCADE + ), + "single_field": models.ForeignKey( + mocked_model_single_field, null=True, on_delete=models.CASCADE + ), + } + ) + + +@pytest.fixture +def mocked_model_varying_fields_instance(freezer, mocked_model_varying_fields): + return mocked_model_varying_fields.objects.create( + title="hello world", + 
updated_at=timezone.now(), + content={"a": 1}, + items=["a", "b"], + ) + + +@pytest.fixture +def models_from_cursor_wrapper_multiple(): + def _wrapper(*args, **kwargs): + return list(models_from_cursor(*args, **kwargs))[0] + + return _wrapper + + +@pytest.fixture +def models_from_cursor_wrapper_single(): + return model_from_cursor + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_applies_converters( + mocked_model_varying_fields, + mocked_model_varying_fields_instance, + models_from_cursor_wrapper, +): + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_varying_fields.objects.all().query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == mocked_model_varying_fields_instance.id + assert queried_instance.title == mocked_model_varying_fields_instance.title + assert ( + queried_instance.updated_at + == mocked_model_varying_fields_instance.updated_at + ) + assert ( + queried_instance.content == mocked_model_varying_fields_instance.content + ) + assert queried_instance.items == mocked_model_varying_fields_instance.items + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_handles_field_order( + mocked_model_varying_fields, + mocked_model_varying_fields_instance, + models_from_cursor_wrapper, +): + with connection.cursor() as cursor: + cursor.execute( + f'SELECT content, items, id, title, updated_at FROM "{mocked_model_varying_fields._meta.db_table}"', + tuple(), + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == mocked_model_varying_fields_instance.id + assert queried_instance.title == mocked_model_varying_fields_instance.title + assert ( + queried_instance.updated_at + == mocked_model_varying_fields_instance.updated_at + ) + assert ( + queried_instance.content == mocked_model_varying_fields_instance.content + ) + assert queried_instance.items == mocked_model_varying_fields_instance.items + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_handles_partial_fields( + mocked_model_varying_fields, + mocked_model_varying_fields_instance, + models_from_cursor_wrapper, +): + with connection.cursor() as cursor: + cursor.execute( + f'SELECT id FROM "{mocked_model_varying_fields._meta.db_table}"', + tuple(), + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == mocked_model_varying_fields_instance.id + assert queried_instance.title is None + assert queried_instance.updated_at is None + assert queried_instance.content is None + assert queried_instance.items is None + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + 
"models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_handles_null( + mocked_model_varying_fields, models_from_cursor_wrapper +): + instance = mocked_model_varying_fields.objects.create() + + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_varying_fields.objects.all().query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_varying_fields, cursor + ) + + assert queried_instance.id == instance.id + assert queried_instance.title is None + assert queried_instance.updated_at is None + assert queried_instance.content is None + assert queried_instance.items is None + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_foreign_key( + mocked_model_single_field, + mocked_model_foreign_keys, + models_from_cursor_wrapper, +): + instance = mocked_model_foreign_keys.objects.create( + varying_fields=None, + single_field=mocked_model_single_field.objects.create(name="test"), + ) + + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_foreign_keys.objects.all().query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_foreign_keys, cursor + ) + + with CaptureQueriesContext(connection) as ctx: + assert queried_instance.id == instance.id + assert queried_instance.varying_fields_id is None + assert queried_instance.varying_fields is None + assert queried_instance.single_field_id == instance.single_field_id + assert queried_instance.single_field.id == instance.single_field.id + assert queried_instance.single_field.name == instance.single_field.name + + assert len(ctx.captured_queries) == 1 + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +@pytest.mark.parametrize( + "models_from_cursor_wrapper", + [ + pytest.lazy_fixture("models_from_cursor_wrapper_multiple"), + pytest.lazy_fixture("models_from_cursor_wrapper_single"), + ], +) +def test_models_from_cursor_related_fields( + mocked_model_varying_fields, + mocked_model_single_field, + mocked_model_foreign_keys, + models_from_cursor_wrapper, +): + instance = mocked_model_foreign_keys.objects.create( + varying_fields=mocked_model_varying_fields.objects.create( + title="test", updated_at=timezone.now() + ), + single_field=mocked_model_single_field.objects.create(name="test"), + ) + + with connection.cursor() as cursor: + cursor.execute( + *mocked_model_foreign_keys.objects.select_related( + "varying_fields", "single_field" + ) + .all() + .query.sql_with_params() + ) + queried_instance = models_from_cursor_wrapper( + mocked_model_foreign_keys, + cursor, + related_fields=["varying_fields", "single_field"], + ) + + with CaptureQueriesContext(connection) as ctx: + assert queried_instance.id == instance.id + + assert queried_instance.varying_fields_id == instance.varying_fields_id + assert queried_instance.varying_fields.id == instance.varying_fields.id + assert ( + queried_instance.varying_fields.title + == instance.varying_fields.title + ) + assert ( + queried_instance.varying_fields.updated_at + == instance.varying_fields.updated_at + ) + assert ( + queried_instance.varying_fields.content + == instance.varying_fields.content + ) + 
+        assert (
+            queried_instance.varying_fields.items
+            == instance.varying_fields.items
+        )
+
+        assert queried_instance.single_field_id == instance.single_field_id
+        assert queried_instance.single_field.id == instance.single_field.id
+        assert queried_instance.single_field.name == instance.single_field.name
+
+        assert len(ctx.captured_queries) == 0
+
+
+@pytest.mark.skipif(
+    django.VERSION < (3, 1),
+    reason=django_31_skip_reason,
+)
+@pytest.mark.parametrize(
+    "models_from_cursor_wrapper",
+    [
+        pytest.lazy_fixture("models_from_cursor_wrapper_multiple"),
+        pytest.lazy_fixture("models_from_cursor_wrapper_single"),
+    ],
+)
+@pytest.mark.parametrize(
+    "selected", [True, False], ids=["selected", "not_selected"]
+)
+def test_models_from_cursor_related_fields_optional(
+    mocked_model_varying_fields,
+    mocked_model_foreign_keys,
+    models_from_cursor_wrapper,
+    selected,
+):
+    instance = mocked_model_foreign_keys.objects.create(
+        varying_fields=mocked_model_varying_fields.objects.create(
+            title="test", updated_at=timezone.now()
+        ),
+        single_field=None,
+    )
+
+    with connection.cursor() as cursor:
+        select_related = ["varying_fields"]
+        if selected:
+            select_related.append("single_field")
+
+        cursor.execute(
+            *mocked_model_foreign_keys.objects.select_related(*select_related)
+            .all()
+            .query.sql_with_params()
+        )
+        queried_instance = models_from_cursor_wrapper(
+            mocked_model_foreign_keys,
+            cursor,
+            related_fields=["varying_fields", "single_field"],
+        )
+
+    assert queried_instance.id == instance.id
+    assert queried_instance.varying_fields_id == instance.varying_fields_id
+    assert queried_instance.single_field_id == instance.single_field_id
+
+    with CaptureQueriesContext(connection) as ctx:
+        assert queried_instance.varying_fields.id == instance.varying_fields.id
+        assert (
+            queried_instance.varying_fields.title
+            == instance.varying_fields.title
+        )
+        assert (
+            queried_instance.varying_fields.updated_at
+            == instance.varying_fields.updated_at
+        )
+        assert (
+            queried_instance.varying_fields.content
+            == instance.varying_fields.content
+        )
+        assert (
+            queried_instance.varying_fields.items
+            == instance.varying_fields.items
+        )
+
+        assert queried_instance.single_field is None
+
+        assert len(ctx.captured_queries) == 0
+
+
+@pytest.mark.skipif(
+    django.VERSION < (3, 1),
+    reason=django_31_skip_reason,
+)
+def test_models_from_cursor_generator_efficiency(
+    mocked_model_varying_fields, mocked_model_single_field
+):
+    mocked_model_single_field.objects.create(name="a")
+    mocked_model_single_field.objects.create(name="b")
+
+    with connection.cursor() as cursor:
+        cursor.execute(
+            *mocked_model_single_field.objects.all().query.sql_with_params()
+        )
+
+        instances_generator = models_from_cursor(
+            mocked_model_single_field, cursor
+        )
+        assert cursor.rownumber == 0
+
+        next(instances_generator)
+        assert cursor.rownumber == 1
+
+        next(instances_generator)
+        assert cursor.rownumber == 2
+
+        assert not next(instances_generator, None)
+        assert cursor.rownumber == 2
diff --git a/tests/test_upsert.py b/tests/test_upsert.py
index a53561ce..3aa62079 100644
--- a/tests/test_upsert.py
+++ b/tests/test_upsert.py
@@ -1,7 +1,7 @@
 import django
 import pytest
 
-from django.db import models
+from django.db import connection, models
 from django.db.models import F, Q
 from django.db.models.expressions import CombinedExpression, Value
 
@@ -410,3 +410,68 @@ def test_bulk_upsert_update_values():
     )
 
     assert all([obj.count == 1 for obj in objs])
+
+
+@pytest.mark.parametrize("return_model", [True, False])
+def test_bulk_upsert_extra_columns_in_schema(return_model):
+    """Tests that extra columns returned by the database that aren't
+    known to Django don't make the bulk upsert crash."""
+
+    model = get_fake_model(
+        {
+            "name": models.CharField(max_length=255, unique=True),
+        }
+    )
+
+    with connection.cursor() as cursor:
+        cursor.execute(
+            f"ALTER TABLE {model._meta.db_table} ADD COLUMN new_name text NOT NULL DEFAULT %s",
+            ("newjoe",),
+        )
+
+    objs = model.objects.bulk_upsert(
+        conflict_target=["name"],
+        rows=[
+            dict(name="joe"),
+        ],
+        return_model=return_model,
+    )
+
+    assert len(objs) == 1
+
+    if return_model:
+        assert objs[0].name == "joe"
+    else:
+        assert objs[0]["name"] == "joe"
+        assert sorted(list(objs[0].keys())) == ["id", "name"]
+
+
+def test_upsert_extra_columns_in_schema():
+    """Tests that extra columns returned by the database that aren't
+    known to Django don't make the upsert crash."""
+
+    model = get_fake_model(
+        {
+            "name": models.CharField(max_length=255, unique=True),
+        }
+    )
+
+    with connection.cursor() as cursor:
+        cursor.execute(
+            f"ALTER TABLE {model._meta.db_table} ADD COLUMN new_name text NOT NULL DEFAULT %s",
+            ("newjoe",),
+        )
+
+    obj_id = model.objects.upsert(
+        conflict_target=["name"],
+        fields=dict(name="joe"),
+    )
+
+    assert obj_id == 1
+
+    obj = model.objects.upsert_and_get(
+        conflict_target=["name"],
+        fields=dict(name="joe"),
+    )
+
+    assert obj.name == "joe"

From d0b4df009f4f3645f160cd1ffcfddd3b47e232b3 Mon Sep 17 00:00:00 2001
From: Swen Kooij
Date: Fri, 11 Aug 2023 12:46:18 +0200
Subject: [PATCH 38/43] Make `ExcludedCol` work with fields and use them when
 constructing SET clause

The recent changes to add support for custom update values had a side
effect that upserts with PostGIS-related fields would break. They would
break while building the `SET` clause. Django would try to figure out
the right placeholder for the expression, even though none is required.
Since there was no associated field information, it couldn't figure it
out.

By passing the field information, we ensure we always build the SET
clause correctly.
---
 psqlextra/expressions.py | 18 +++++++++++++++---
 psqlextra/query.py       |  6 +++---
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/psqlextra/expressions.py b/psqlextra/expressions.py
index d9c6bb54..20486dfa 100644
--- a/psqlextra/expressions.py
+++ b/psqlextra/expressions.py
@@ -1,4 +1,6 @@
-from django.db.models import CharField, expressions
+from typing import Union
+
+from django.db.models import CharField, Field, expressions
 
 
 class HStoreValue(expressions.Expression):
@@ -215,8 +217,18 @@ class ExcludedCol(expressions.Expression):
     See: https://www.postgresql.org/docs/current/sql-insert.html#SQL-ON-CONFLICT
     """
 
-    def __init__(self, name: str):
-        self.name = name
+    def __init__(self, field_or_name: Union[Field, str]):
+
+        # We support both field instances and plain field names here. We
+        # prefer fields because when the expression is compiled, it might
+        # need the field information to figure out the correct placeholder,
+        # even though that isn't required for this particular expression.
+ if isinstance(field_or_name, Field): + super().__init__(field_or_name) + self.name = field_or_name.column + else: + super().__init__(None) + self.name = field_or_name def as_sql(self, compiler, connection): quoted_name = connection.ops.quote_name(self.name) diff --git a/psqlextra/query.py b/psqlextra/query.py index a014d9a8..2d24b5ae 100644 --- a/psqlextra/query.py +++ b/psqlextra/query.py @@ -591,7 +591,7 @@ def _get_upsert_fields(self, kwargs): has_default = field.default != NOT_PROVIDED if field.name in kwargs or field.column in kwargs: insert_fields.append(field) - update_values[field.name] = ExcludedCol(field.column) + update_values[field.name] = ExcludedCol(field) continue elif has_default: insert_fields.append(field) @@ -602,13 +602,13 @@ def _get_upsert_fields(self, kwargs): # instead of a concrete field, we have to handle that if field.primary_key is True and "pk" in kwargs: insert_fields.append(field) - update_values[field.name] = ExcludedCol(field.column) + update_values[field.name] = ExcludedCol(field) continue if self._is_magical_field(model_instance, field, is_insert=True): insert_fields.append(field) if self._is_magical_field(model_instance, field, is_insert=False): - update_values[field.name] = ExcludedCol(field.column) + update_values[field.name] = ExcludedCol(field) return insert_fields, update_values From e5503cb3f3c1b7959bd55253d3a79296f4c8f0ef Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Fri, 25 Aug 2023 11:54:12 +0200 Subject: [PATCH 39/43] Tolerate unknown fields coming from JOIN'd data --- psqlextra/introspect/models.py | 15 +++++++----- tests/test_introspect.py | 45 ++++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/psqlextra/introspect/models.py b/psqlextra/introspect/models.py index cad84e9f..61a478dd 100644 --- a/psqlextra/introspect/models.py +++ b/psqlextra/introspect/models.py @@ -28,11 +28,11 @@ def _construct_model( apply_converters: bool = True ) -> TModel: fields_by_name_and_column = {} - for field in inspect_model_local_concrete_fields(model): - fields_by_name_and_column[field.attname] = field + for concrete_field in inspect_model_local_concrete_fields(model): + fields_by_name_and_column[concrete_field.attname] = concrete_field - if field.db_column: - fields_by_name_and_column[field.db_column] = field + if concrete_field.db_column: + fields_by_name_and_column[concrete_field.db_column] = concrete_field indexable_columns = list(columns) @@ -41,9 +41,12 @@ def _construct_model( for index, value in enumerate(values): column = indexable_columns[index] try: - field = cast(Field, model._meta.get_field(column)) + field: Optional[Field] = cast(Field, model._meta.get_field(column)) except FieldDoesNotExist: - field = fields_by_name_and_column[column] + field = fields_by_name_and_column.get(column) + + if not field: + continue field_column_expression = field.get_col(model._meta.db_table) diff --git a/tests/test_introspect.py b/tests/test_introspect.py index 21e842ac..5e5a9ffc 100644 --- a/tests/test_introspect.py +++ b/tests/test_introspect.py @@ -415,3 +415,48 @@ def test_models_from_cursor_generator_efficiency( assert not next(instances_generator, None) assert cursor.rownumber == 2 + + +@pytest.mark.skipif( + django.VERSION < (3, 1), + reason=django_31_skip_reason, +) +def test_models_from_cursor_tolerates_additional_columns( + mocked_model_foreign_keys, mocked_model_varying_fields +): + with connection.cursor() as cursor: + cursor.execute( + f"ALTER TABLE {mocked_model_foreign_keys._meta.db_table} ADD COLUMN 
new_col text DEFAULT NULL"
+        )
+        cursor.execute(
+            f"ALTER TABLE {mocked_model_varying_fields._meta.db_table} ADD COLUMN new_col text DEFAULT NULL"
+        )
+
+    instance = mocked_model_foreign_keys.objects.create(
+        varying_fields=mocked_model_varying_fields.objects.create(
+            title="test", updated_at=timezone.now()
+        ),
+        single_field=None,
+    )
+
+    with connection.cursor() as cursor:
+        cursor.execute(
+            f"""
+            SELECT fk_t.*, vf_t.* FROM {mocked_model_foreign_keys._meta.db_table} fk_t
+            INNER JOIN {mocked_model_varying_fields._meta.db_table} vf_t ON vf_t.id = fk_t.varying_fields_id
+            """
+        )
+
+        queried_instances = list(
+            models_from_cursor(
+                mocked_model_foreign_keys,
+                cursor,
+                related_fields=["varying_fields"],
+            )
+        )
+
+    assert len(queried_instances) == 1
+    assert queried_instances[0].id == instance.id
+    assert (
+        queried_instances[0].varying_fields.id == instance.varying_fields.id
+    )

From 65b4688ab304cb35a013ef54170f3d4dc1070da8 Mon Sep 17 00:00:00 2001
From: Sebastian Willing
Date: Sun, 19 Nov 2023 21:20:07 +0100
Subject: [PATCH 40/43] Stop updating on conflict if `update_condition` is
 False but not None

Some users of this library set `update_condition=0` on `upsert` to
avoid updating anything on conflict. The `upsert` documentation says:

> update_condition:
>     Only update if this SQL expression evaluates to true.

Previously, a value evaluating to Python `False` was silently ignored,
even though the documentation implies that no update will be done.

[#186513018]
---
 psqlextra/query.py   |  4 +++-
 tests/test_upsert.py | 30 ++++++++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/psqlextra/query.py b/psqlextra/query.py
index 2d24b5ae..65a20c50 100644
--- a/psqlextra/query.py
+++ b/psqlextra/query.py
@@ -327,7 +327,9 @@ def upsert(
 
         self.on_conflict(
             conflict_target,
-            ConflictAction.UPDATE,
+            ConflictAction.UPDATE
+            if (update_condition or update_condition is None)
+            else ConflictAction.NOTHING,
             index_predicate=index_predicate,
             update_condition=update_condition,
             update_values=update_values,
diff --git a/tests/test_upsert.py b/tests/test_upsert.py
index 3aa62079..a9e567b2 100644
--- a/tests/test_upsert.py
+++ b/tests/test_upsert.py
@@ -4,6 +4,7 @@
 from django.db import connection, models
 from django.db.models import F, Q
 from django.db.models.expressions import CombinedExpression, Value
+from django.test.utils import CaptureQueriesContext
 
 from psqlextra.expressions import ExcludedCol
 from psqlextra.fields import HStoreField
@@ -144,6 +145,35 @@ def test_upsert_with_update_condition():
     assert obj1.active
 
 
+@pytest.mark.parametrize("update_condition_value", [0, False])
+def test_upsert_with_update_condition_false(update_condition_value):
+    """Tests that a falsy (but non-None) update condition turns the upsert
+    into ON CONFLICT DO NOTHING instead of performing an update."""
+
+    model = get_fake_model(
+        {
+            "name": models.TextField(unique=True),
+            "priority": models.IntegerField(),
+            "active": models.BooleanField(),
+        }
+    )
+
+    obj1 = model.objects.create(name="joe", priority=1, active=False)
+
+    with CaptureQueriesContext(connection) as ctx:
+        upsert_result = model.objects.upsert(
+            conflict_target=["name"],
+            update_condition=update_condition_value,
+            fields=dict(name="joe", priority=2, active=True),
+        )
+        assert upsert_result is None
+        assert len(ctx) == 1
+        assert 'ON CONFLICT ("name") DO NOTHING' in ctx[0]["sql"]
+
+    obj1.refresh_from_db()
+    assert obj1.priority == 1
+    assert not obj1.active
+
+
 def test_upsert_with_update_values():
     """Tests that the default update values can be overridden with custom
     expressions."""

From
c4aab405823331e17f4683e144e12e64e7541630 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Tue, 6 Feb 2024 16:23:26 +0100 Subject: [PATCH 41/43] Prepare for Django 5.x support --- README.md | 2 +- psqlextra/sql.py | 11 +++++++++-- setup.py | 2 +- tox.ini | 3 ++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 8127b254..17037d87 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ | :memo: | **License** | [![License](https://img.shields.io/:license-mit-blue.svg)](http://doge.mit-license.org) | | :package: | **PyPi** | [![PyPi](https://badge.fury.io/py/django-postgres-extra.svg)](https://pypi.python.org/pypi/django-postgres-extra) | | :four_leaf_clover: | **Code coverage** | [![Coverage Status](https://coveralls.io/repos/github/SectorLabs/django-postgres-extra/badge.svg?branch=coveralls)](https://coveralls.io/github/SectorLabs/django-postgres-extra?branch=master) | -| | **Django Versions** | 2.0, 2.1, 2.2, 3.0, 3.1, 3.2, 4.0, 4.1, 4.2 | +| | **Django Versions** | 2.0, 2.1, 2.2, 3.0, 3.1, 3.2, 4.0, 4.1, 4.2, 5.0 | | | **Python Versions** | 3.6, 3.7, 3.8, 3.9, 3.10, 3.11 | | | **Psycopg Versions** | 2, 3 | | :book: | **Documentation** | [Read The Docs](https://django-postgres-extra.readthedocs.io/en/master/) | diff --git a/psqlextra/sql.py b/psqlextra/sql.py index 3ceb5966..b2655088 100644 --- a/psqlextra/sql.py +++ b/psqlextra/sql.py @@ -64,8 +64,15 @@ def rename_annotations(self, annotations) -> None: new_annotations[new_name or old_name] = annotation if new_name and self.annotation_select_mask: - self.annotation_select_mask.discard(old_name) - self.annotation_select_mask.add(new_name) + # It's a set in all versions prior to Django 5.x + # and a list in Django 5.x and newer. + # https://github.com/django/django/commit/d6b6e5d0fd4e6b6d0183b4cf6e4bd4f9afc7bf67 + if isinstance(self.annotation_select_mask, set): + self.annotation_select_mask.discard(old_name) + self.annotation_select_mask.add(new_name) + elif isinstance(self.annotation_select_mask, list): + self.annotation_select_mask.remove(old_name) + self.annotation_select_mask.append(new_name) self.annotations.clear() self.annotations.update(new_annotations) diff --git a/setup.py b/setup.py index b3217fb3..c3431e27 100644 --- a/setup.py +++ b/setup.py @@ -68,7 +68,7 @@ def run(self): ], python_requires=">=3.6", install_requires=[ - "Django>=2.0,<5.0", + "Django>=2.0,<6.0", "python-dateutil>=2.8.0,<=3.0.0", ], extras_require={ diff --git a/tox.ini b/tox.ini index 3e229d0d..70a0e8ce 100644 --- a/tox.ini +++ b/tox.ini @@ -3,7 +3,7 @@ envlist = {py36,py37}-dj{20,21,22,30,31,32}-psycopg{28,29} {py38,py39,py310}-dj{21,22,30,31,32,40}-psycopg{28,29} {py38,py39,py310,py311}-dj{41}-psycopg{28,29} - {py38,py39,py310,py311}-dj{42}-psycopg{28,29,31} + {py310,py311}-dj{42,50}-psycopg{28,29,31} [testenv] deps = @@ -16,6 +16,7 @@ deps = dj40: Django~=4.0.0 dj41: Django~=4.1.0 dj42: Django~=4.2.0 + dj50: Django~=5.0.1 psycopg28: psycopg2[binary]~=2.8 psycopg29: psycopg2[binary]~=2.9 psycopg31: psycopg[binary]~=3.1 From 43a6f222b3ab85661a15ddf783c7171b76769fc8 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Wed, 7 Feb 2024 09:07:21 +0100 Subject: [PATCH 42/43] Switch CircleCI to PyPi API token for publishing --- .circleci/config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 64e60c30..92d9093b 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -142,8 +142,8 @@ jobs: name: Publish package command: > python -m twine upload - 
--username "${PYPI_REPO_USERNAME}" - --password "${PYPI_REPO_PASSWORD}" + --username "__token__" + --password "${PYPI_API_TOKEN}" --verbose --non-interactive --disable-progress-bar From 13b5672cfe1aed0ec10dcb0b3f4b382d22719de7 Mon Sep 17 00:00:00 2001 From: Swen Kooij Date: Tue, 18 Jun 2024 12:24:46 +0200 Subject: [PATCH 43/43] Fix automatic hstore extension creation not working on Django 4.2 or newer The following change broke the auto setup: https://github.com/django/django/commit/d3e746ace5eeea07216da97d9c3801f2fdc43223 This breaks because the call to `pscygop2.extras.register_hstore` is now conditional. Before, it would be called multiple times with empty OIDS, when eventually our auto registration would kick in and psycopg2 would fetch the OIDs itself. --- psqlextra/backend/base.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/psqlextra/backend/base.py b/psqlextra/backend/base.py index 5c788a05..c8ae73c5 100644 --- a/psqlextra/backend/base.py +++ b/psqlextra/backend/base.py @@ -3,6 +3,10 @@ from typing import TYPE_CHECKING from django.conf import settings +from django.contrib.postgres.signals import ( + get_hstore_oids, + register_type_handlers, +) from django.db import ProgrammingError from . import base_impl @@ -94,3 +98,22 @@ def prepare_database(self): "or add the extension manually.", exc_info=True, ) + return + + # Clear old (non-existent), stale oids. + get_hstore_oids.cache_clear() + + # Verify that we (and Django) can find the OIDs + # for hstore. + oids, _ = get_hstore_oids(self.alias) + if not oids: + logger.warning( + '"hstore" extension was created, but we cannot find the oids' + "in the database. Something went wrong.", + ) + return + + # We must trigger Django into registering the type handlers now + # so that any subsequent code can properly use the newly + # registered types. + register_type_handlers(self)