Skip to content

Commit c284c27

Browse files
committed
Rebuild dataclass fields before schema generation
Backport of: #11949
1 parent 5e6d1dc commit c284c27

File tree

5 files changed

+112
-11
lines changed

5 files changed

+112
-11
lines changed

pydantic/_internal/_dataclasses.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ class PydanticDataclass(StandardDataclass, typing.Protocol):
5555
__pydantic_serializer__: ClassVar[SchemaSerializer]
5656
__pydantic_validator__: ClassVar[SchemaValidator | PluggableSchemaValidator]
5757

58+
@classmethod
59+
def __pydantic_fields_complete__(cls) -> bool: ...
60+
5861
else:
5962
# See PyCharm issues https://youtrack.jetbrains.com/issue/PY-21915
6063
# and https://youtrack.jetbrains.com/issue/PY-51428

pydantic/_internal/_fields.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
from ..fields import FieldInfo
3434
from ..main import BaseModel
35-
from ._dataclasses import StandardDataclass
35+
from ._dataclasses import PydanticDataclass, StandardDataclass
3636
from ._decorators import DecoratorInfos
3737

3838

@@ -380,7 +380,7 @@ def collect_dataclass_fields(
380380
continue
381381

382382
globalns, localns = ns_resolver.types_namespace
383-
ann_type, _ = _typing_extra.try_eval_type(dataclass_field.type, globalns, localns)
383+
ann_type, evaluated = _typing_extra.try_eval_type(dataclass_field.type, globalns, localns)
384384

385385
if _typing_extra.is_classvar_annotation(ann_type):
386386
continue
@@ -407,10 +407,16 @@ def collect_dataclass_fields(
407407
field_info = FieldInfo_.from_annotated_attribute(
408408
ann_type, dataclass_field.default, _source=AnnotationSource.DATACLASS
409409
)
410+
field_info._original_assignment = dataclass_field.default
410411
else:
411412
field_info = FieldInfo_.from_annotated_attribute(
412413
ann_type, dataclass_field, _source=AnnotationSource.DATACLASS
413414
)
415+
field_info._original_assignment = dataclass_field
416+
417+
if not evaluated:
418+
field_info._complete = False
419+
field_info._original_annotation = ann_type
414420

415421
fields[ann_name] = field_info
416422

@@ -439,6 +445,51 @@ def collect_dataclass_fields(
439445
return fields
440446

441447

448+
def rebuild_dataclass_fields(
449+
cls: type[PydanticDataclass],
450+
*,
451+
config_wrapper: ConfigWrapper,
452+
ns_resolver: NsResolver,
453+
typevars_map: Mapping[TypeVar, Any],
454+
) -> dict[str, FieldInfo]:
455+
"""Rebuild the (already present) dataclass fields by trying to reevaluate annotations.
456+
457+
This function should be called whenever a dataclass with incomplete fields is encountered.
458+
459+
Raises:
460+
NameError: If one of the annotations failed to evaluate.
461+
462+
Note:
463+
This function *doesn't* mutate the dataclass fields in place, as it can be called during
464+
schema generation, where you don't want to mutate other dataclass's fields.
465+
"""
466+
FieldInfo_ = import_cached_field_info()
467+
468+
rebuilt_fields: dict[str, FieldInfo] = {}
469+
with ns_resolver.push(cls):
470+
for f_name, field_info in cls.__pydantic_fields__.items():
471+
if field_info._complete:
472+
rebuilt_fields[f_name] = field_info
473+
else:
474+
existing_desc = field_info.description
475+
ann = _typing_extra.eval_type(
476+
field_info._original_annotation,
477+
*ns_resolver.types_namespace,
478+
)
479+
ann = _generics.replace_types(ann, typevars_map)
480+
new_field = FieldInfo_.from_annotated_attribute(
481+
ann,
482+
field_info._original_assignment,
483+
_source=AnnotationSource.DATACLASS,
484+
)
485+
486+
# The description might come from the docstring if `use_attribute_docstrings` was `True`:
487+
new_field.description = new_field.description if new_field.description is not None else existing_desc
488+
rebuilt_fields[f_name] = new_field
489+
490+
return rebuilt_fields
491+
492+
442493
def is_valid_field_name(name: str) -> bool:
443494
return not name.startswith('_')
444495

pydantic/_internal/_generate_schema.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,12 @@
8787
inspect_validator,
8888
)
8989
from ._docs_extraction import extract_docstrings_from_cls
90-
from ._fields import collect_dataclass_fields, rebuild_model_fields, takes_validated_data_argument
90+
from ._fields import (
91+
collect_dataclass_fields,
92+
rebuild_dataclass_fields,
93+
rebuild_model_fields,
94+
takes_validated_data_argument,
95+
)
9196
from ._forward_ref import PydanticRecursiveRef
9297
from ._generics import get_standard_typevars_map, replace_types
9398
from ._import_utils import import_cached_base_model, import_cached_field_info
@@ -1912,14 +1917,27 @@ def _dataclass_schema(
19121917

19131918
with self._ns_resolver.push(dataclass), self._config_wrapper_stack.push(config):
19141919
if is_pydantic_dataclass(dataclass):
1915-
# Copy the field info instances to avoid mutating the `FieldInfo` instances
1916-
# of the generic dataclass generic origin (e.g. `apply_typevars_map` below).
1917-
# Note that we don't apply `deepcopy` on `__pydantic_fields__` because we
1918-
# don't want to copy the `FieldInfo` attributes:
1919-
fields = {f_name: copy(field_info) for f_name, field_info in dataclass.__pydantic_fields__.items()}
1920-
if typevars_map:
1921-
for field in fields.values():
1922-
field.apply_typevars_map(typevars_map, *self._types_namespace)
1920+
if dataclass.__pydantic_fields_complete__():
1921+
# Copy the field info instances to avoid mutating the `FieldInfo` instances
1922+
# of the generic dataclass generic origin (e.g. `apply_typevars_map` below).
1923+
# Note that we don't apply `deepcopy` on `__pydantic_fields__` because we
1924+
# don't want to copy the `FieldInfo` attributes:
1925+
fields = {
1926+
f_name: copy(field_info) for f_name, field_info in dataclass.__pydantic_fields__.items()
1927+
}
1928+
if typevars_map:
1929+
for field in fields.values():
1930+
field.apply_typevars_map(typevars_map, *self._types_namespace)
1931+
else:
1932+
try:
1933+
fields = rebuild_dataclass_fields(
1934+
dataclass,
1935+
config_wrapper=self._config_wrapper,
1936+
ns_resolver=self._ns_resolver,
1937+
typevars_map=typevars_map or {},
1938+
)
1939+
except NameError as e:
1940+
raise PydanticUndefinedAnnotation.from_name_error(e) from e
19231941
else:
19241942
fields = collect_dataclass_fields(
19251943
dataclass,

pydantic/dataclasses.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]:
272272
cls.__doc__ = original_doc
273273
cls.__module__ = original_cls.__module__
274274
cls.__qualname__ = original_cls.__qualname__
275+
cls.__pydantic_fields_complete__ = classmethod(_pydantic_fields_complete)
275276
cls.__pydantic_complete__ = False # `complete_dataclass` will set it to `True` if successful.
276277
# TODO `parent_namespace` is currently None, but we could do the same thing as Pydantic models:
277278
# fetch the parent ns using `parent_frame_namespace` (if the dataclass was defined in a function),
@@ -282,6 +283,14 @@ def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]:
282283
return create_dataclass if _cls is None else create_dataclass(_cls)
283284

284285

286+
def _pydantic_fields_complete(cls: type[PydanticDataclass]) -> bool:
287+
"""Return whether the fields where successfully collected (i.e. type hints were successfully resolves).
288+
289+
This is a private property, not meant to be used outside Pydantic.
290+
"""
291+
return all(field_info._complete for field_info in cls.__pydantic_fields__.values())
292+
293+
285294
__getattr__ = getattr_migration(__name__)
286295

287296
if sys.version_info < (3, 11):

tests/test_dataclasses.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3086,3 +3086,23 @@ class A:
30863086
a: int
30873087

30883088
assert 'a' in A.__pydantic_fields__ # pyright: ignore[reportAttributeAccessIssue]
3089+
3090+
3091+
def test_dataclass_fields_rebuilt_before_schema_generation() -> None:
3092+
"""https://github.com/pydantic/pydantic/issues/11947"""
3093+
3094+
def update_schema(schema: dict[str, Any]) -> None:
3095+
schema['test'] = schema['title']
3096+
3097+
@pydantic.dataclasses.dataclass
3098+
class A:
3099+
a: """Annotated[
3100+
Forward,
3101+
Field(field_title_generator=lambda name, _: name, json_schema_extra=update_schema)
3102+
]""" = True
3103+
3104+
Forward = bool
3105+
3106+
ta = TypeAdapter(A)
3107+
3108+
assert ta.json_schema()['properties']['a']['test'] == 'a'

0 commit comments

Comments
 (0)