Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions psqlextra/introspect/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Optional,
Type,
TypeVar,
Union,
cast,
)

Expand Down Expand Up @@ -115,9 +116,10 @@ def models_from_cursor(
)

for index, related_field_name in enumerate(related_fields):
related_model = model._meta.get_field(
related_field_name
).related_model
related_model = cast(
Union[Type[Model], None],
model._meta.get_field(related_field_name).related_model,
)
if not related_model:
continue

Expand Down
10 changes: 10 additions & 0 deletions psqlextra/sql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import OrderedDict
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Tuple, Union

import django
Expand All @@ -7,6 +8,7 @@
from django.db import connections, models
from django.db.models import Expression, sql
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import Ref

from .compiler import PostgresInsertOnConflictCompiler
from .compiler import SQLUpdateCompiler as PostgresUpdateCompiler
Expand Down Expand Up @@ -74,6 +76,14 @@ def rename_annotations(self, annotations) -> None:
self.annotation_select_mask.remove(old_name)
self.annotation_select_mask.append(new_name)

if isinstance(self.group_by, Iterable):
for statement in self.group_by:
if not isinstance(statement, Ref):
continue

if statement.refs in annotations: # type: ignore[attr-defined]
statement.refs = annotations[statement.refs] # type: ignore[attr-defined]

self.annotations.clear()
self.annotations.update(new_annotations)

Expand Down
5 changes: 4 additions & 1 deletion settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
'default': dj_database_url.config(default='postgres:///psqlextra'),
}

DATABASES['default']['ENGINE'] = 'psqlextra.backend'
DATABASES['default']['ENGINE'] = 'tests.psqlextra_test_backend'

LANGUAGE_CODE = 'en'
LANGUAGES = (
Expand All @@ -24,3 +24,6 @@
'psqlextra',
'tests',
)

USE_TZ = True
TIME_ZONE = 'UTC'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def run(self):
"docformatter==1.4",
"mypy==1.2.0; python_version > '3.6'",
"mypy==0.971; python_version <= '3.6'",
"django-stubs==1.16.0; python_version > '3.6'",
"django-stubs==4.2.7; python_version > '3.6'",
"django-stubs==1.9.0; python_version <= '3.6'",
"typing-extensions==4.5.0; python_version > '3.6'",
"typing-extensions==4.1.0; python_version <= '3.6'",
Expand Down
Empty file.
23 changes: 23 additions & 0 deletions tests/psqlextra_test_backend/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from datetime import timezone

import django

from django.conf import settings

from psqlextra.backend.base import DatabaseWrapper as PSQLExtraDatabaseWrapper


class DatabaseWrapper(PSQLExtraDatabaseWrapper):
# Works around the compatibility issue of Django <3.0 and psycopg2.9
# in combination with USE_TZ
#
# See: https://github.com/psycopg/psycopg2/issues/1293#issuecomment-862835147
if django.VERSION < (3, 1):

def create_cursor(self, name=None):
cursor = super().create_cursor(name)
cursor.tzinfo_factory = (
lambda offset: timezone.utc if settings.USE_TZ else None
)

return cursor
39 changes: 38 additions & 1 deletion tests/test_query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from datetime import datetime, timezone

from django.db import connection, models
from django.db.models import Case, F, Q, Value, When
from django.db.models import Case, F, Min, Q, Value, When
from django.db.models.functions.datetime import TruncSecond
from django.test.utils import CaptureQueriesContext, override_settings

from psqlextra.expressions import HStoreRef
Expand Down Expand Up @@ -96,6 +99,40 @@ def test_query_annotate_in_expression():
assert result.is_he_henk == "really henk"


def test_query_annotate_group_by():
"""Tests whether annotations with GROUP BY clauses are properly renamed
when the annotation overwrites a field name."""

model = get_fake_model(
{
"name": models.TextField(),
"timestamp": models.DateTimeField(null=False),
"value": models.IntegerField(),
}
)

timestamp = datetime(2024, 1, 1, 0, 0, 0, 0, tzinfo=timezone.utc)

model.objects.create(name="me", timestamp=timestamp, value=1)

result = (
model.objects.values("name")
.annotate(
timestamp=TruncSecond("timestamp", tzinfo=timezone.utc),
value=Min("value"),
)
.values_list(
"name",
"value",
"timestamp",
)
.order_by("name")
.first()
)

assert result == ("me", 1, timestamp)


def test_query_hstore_value_update_f_ref():
"""Tests whether F(..) expressions can be used in hstore values when
performing update queries."""
Expand Down