Skip to content

Commit 4564831

Browse files
bpo-32953: Dataclasses: frozen should not be inherited for non-dataclass derived classes (GH-6147) (GH-6148)
If a non-dataclass derives from a frozen dataclass, allow attributes to be set. Require either all of the dataclasses in a class hierarchy to be frozen, or all non-frozen. Store `@dataclass` parameters on the class object under `__dataclass_params__`. This is needed to detect frozen base classes. (cherry picked from commit f199bc6) Co-authored-by: Eric V. Smith <ericvsmith@users.noreply.github.com>
1 parent 3c0a5a7 commit 4564831

3 files changed

Lines changed: 168 additions & 55 deletions

File tree

Lib/dataclasses.py

Lines changed: 89 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,11 @@ class _MISSING_TYPE:
171171

172172
# The name of an attribute on the class where we store the Field
173173
# objects. Also used to check if a class is a Data Class.
174-
_MARKER = '__dataclass_fields__'
174+
_FIELDS = '__dataclass_fields__'
175+
176+
# The name of an attribute on the class that stores the parameters to
177+
# @dataclass.
178+
_PARAMS = '__dataclass_params__'
175179

176180
# The name of the function, that if it exists, is called at the end of
177181
# __init__.
@@ -192,7 +196,7 @@ class InitVar(metaclass=_InitVarMeta):
192196
# name and type are filled in after the fact, not in __init__. They're
193197
# not known at the time this class is instantiated, but it's
194198
# convenient if they're available later.
195-
# When cls._MARKER is filled in with a list of Field objects, the name
199+
# When cls._FIELDS is filled in with a list of Field objects, the name
196200
# and type fields will have been populated.
197201
class Field:
198202
__slots__ = ('name',
@@ -236,6 +240,32 @@ def __repr__(self):
236240
')')
237241

238242

243+
class _DataclassParams:
244+
__slots__ = ('init',
245+
'repr',
246+
'eq',
247+
'order',
248+
'unsafe_hash',
249+
'frozen',
250+
)
251+
def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
252+
self.init = init
253+
self.repr = repr
254+
self.eq = eq
255+
self.order = order
256+
self.unsafe_hash = unsafe_hash
257+
self.frozen = frozen
258+
259+
def __repr__(self):
260+
return ('_DataclassParams('
261+
f'init={self.init},'
262+
f'repr={self.repr},'
263+
f'eq={self.eq},'
264+
f'order={self.order},'
265+
f'unsafe_hash={self.unsafe_hash},'
266+
f'frozen={self.frozen}'
267+
')')
268+
239269
# This function is used instead of exposing Field creation directly,
240270
# so that a type checker can be told (via overloads) that this is a
241271
# function whose type depends on its parameters.
@@ -285,6 +315,7 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
285315
args = ','.join(args)
286316
body = '\n'.join(f' {b}' for b in body)
287317

318+
# Compute the text of the entire function.
288319
txt = f'def {name}({args}){return_annotation}:\n{body}'
289320

290321
exec(txt, globals, locals)
@@ -432,12 +463,29 @@ def _repr_fn(fields):
432463
')"'])
433464

434465

435-
def _frozen_setattr(self, name, value):
436-
raise FrozenInstanceError(f'cannot assign to field {name!r}')
437-
438-
439-
def _frozen_delattr(self, name):
440-
raise FrozenInstanceError(f'cannot delete field {name!r}')
466+
def _frozen_get_del_attr(cls, fields):
467+
# XXX: globals is modified on the first call to _create_fn, then the
468+
# modified version is used in the second call. Is this okay?
469+
globals = {'cls': cls,
470+
'FrozenInstanceError': FrozenInstanceError}
471+
if fields:
472+
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
473+
else:
474+
# Special case for the zero-length tuple.
475+
fields_str = '()'
476+
return (_create_fn('__setattr__',
477+
('self', 'name', 'value'),
478+
(f'if type(self) is cls or name in {fields_str}:',
479+
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
480+
f'super(cls, self).__setattr__(name, value)'),
481+
globals=globals),
482+
_create_fn('__delattr__',
483+
('self', 'name'),
484+
(f'if type(self) is cls or name in {fields_str}:',
485+
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
486+
f'super(cls, self).__delattr__(name)'),
487+
globals=globals),
488+
)
441489

442490

443491
def _cmp_fn(name, op, self_tuple, other_tuple):
@@ -583,23 +631,32 @@ def _set_new_attribute(cls, name, value):
583631
# version of this table.
584632

585633

586-
def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
634+
def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
587635
# Now that dicts retain insertion order, there's no reason to use
588636
# an ordered dict. I am leveraging that ordering here, because
589637
# derived class fields overwrite base class fields, but the order
590638
# is defined by the base class, which is found first.
591639
fields = {}
592640

641+
setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
642+
unsafe_hash, frozen))
643+
593644
# Find our base classes in reverse MRO order, and exclude
594645
# ourselves. In reversed order so that more derived classes
595646
# override earlier field definitions in base classes.
647+
# As long as we're iterating over them, see if any are frozen.
648+
any_frozen_base = False
649+
has_dataclass_bases = False
596650
for b in cls.__mro__[-1:0:-1]:
597651
# Only process classes that have been processed by our
598-
# decorator. That is, they have a _MARKER attribute.
599-
base_fields = getattr(b, _MARKER, None)
652+
# decorator. That is, they have a _FIELDS attribute.
653+
base_fields = getattr(b, _FIELDS, None)
600654
if base_fields:
655+
has_dataclass_bases = True
601656
for f in base_fields.values():
602657
fields[f.name] = f
658+
if getattr(b, _PARAMS).frozen:
659+
any_frozen_base = True
603660

604661
# Now find fields in our class. While doing so, validate some
605662
# things, and set the default values (as class attributes)
@@ -623,20 +680,21 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
623680
else:
624681
setattr(cls, f.name, f.default)
625682

626-
# We're inheriting from a frozen dataclass, but we're not frozen.
627-
if cls.__setattr__ is _frozen_setattr and not frozen:
628-
raise TypeError('cannot inherit non-frozen dataclass from a '
629-
'frozen one')
683+
# Check rules that apply if we are derived from any dataclasses.
684+
if has_dataclass_bases:
685+
# Raise an exception if any of our bases are frozen, but we're not.
686+
if any_frozen_base and not frozen:
687+
raise TypeError('cannot inherit non-frozen dataclass from a '
688+
'frozen one')
630689

631-
# We're inheriting from a non-frozen dataclass, but we're frozen.
632-
if (hasattr(cls, _MARKER) and cls.__setattr__ is not _frozen_setattr
633-
and frozen):
634-
raise TypeError('cannot inherit frozen dataclass from a '
635-
'non-frozen one')
690+
# Raise an exception if we're frozen, but none of our bases are.
691+
if not any_frozen_base and frozen:
692+
raise TypeError('cannot inherit frozen dataclass from a '
693+
'non-frozen one')
636694

637-
# Remember all of the fields on our class (including bases). This
695+
# Remember all of the fields on our class (including bases). This also
638696
# marks this class as being a dataclass.
639-
setattr(cls, _MARKER, fields)
697+
setattr(cls, _FIELDS, fields)
640698

641699
# Was this class defined with an explicit __hash__? Note that if
642700
# __eq__ is defined in this class, then python will automatically
@@ -704,10 +762,10 @@ def _process_class(cls, repr, eq, order, unsafe_hash, init, frozen):
704762
'functools.total_ordering')
705763

706764
if frozen:
707-
for name, fn in [('__setattr__', _frozen_setattr),
708-
('__delattr__', _frozen_delattr)]:
709-
if _set_new_attribute(cls, name, fn):
710-
raise TypeError(f'Cannot overwrite attribute {name} '
765+
# XXX: Which fields are frozen? InitVar? ClassVar? hashed-only?
766+
for fn in _frozen_get_del_attr(cls, field_list):
767+
if _set_new_attribute(cls, fn.__name__, fn):
768+
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
711769
f'in class {cls.__name__}')
712770

713771
# Decide if/how we're going to create a hash function.
@@ -759,7 +817,7 @@ def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
759817
"""
760818

761819
def wrap(cls):
762-
return _process_class(cls, repr, eq, order, unsafe_hash, init, frozen)
820+
return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen)
763821

764822
# See if we're being called as @dataclass or @dataclass().
765823
if _cls is None:
@@ -779,7 +837,7 @@ def fields(class_or_instance):
779837

780838
# Might it be worth caching this, per class?
781839
try:
782-
fields = getattr(class_or_instance, _MARKER)
840+
fields = getattr(class_or_instance, _FIELDS)
783841
except AttributeError:
784842
raise TypeError('must be called with a dataclass type or instance')
785843

@@ -790,13 +848,13 @@ def fields(class_or_instance):
790848

791849
def _is_dataclass_instance(obj):
792850
"""Returns True if obj is an instance of a dataclass."""
793-
return not isinstance(obj, type) and hasattr(obj, _MARKER)
851+
return not isinstance(obj, type) and hasattr(obj, _FIELDS)
794852

795853

796854
def is_dataclass(obj):
797855
"""Returns True if obj is a dataclass or an instance of a
798856
dataclass."""
799-
return hasattr(obj, _MARKER)
857+
return hasattr(obj, _FIELDS)
800858

801859

802860
def asdict(obj, *, dict_factory=dict):
@@ -953,7 +1011,7 @@ class C:
9531011
# It's an error to have init=False fields in 'changes'.
9541012
# If a field is not in 'changes', read its value from the provided obj.
9551013

956-
for f in getattr(obj, _MARKER).values():
1014+
for f in getattr(obj, _FIELDS).values():
9571015
if not f.init:
9581016
# Error if this field is specified in changes.
9591017
if f.name in changes:

Lib/test/test_dataclasses.py

Lines changed: 75 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,41 +2476,92 @@ class D(C):
24762476
d = D(0, 10)
24772477
with self.assertRaises(FrozenInstanceError):
24782478
d.i = 5
2479+
with self.assertRaises(FrozenInstanceError):
2480+
d.j = 6
24792481
self.assertEqual(d.i, 0)
2482+
self.assertEqual(d.j, 10)
2483+
2484+
# Test both ways: with an intermediate normal (non-dataclass)
2485+
# class and without an intermediate class.
2486+
def test_inherit_nonfrozen_from_frozen(self):
2487+
for intermediate_class in [True, False]:
2488+
with self.subTest(intermediate_class=intermediate_class):
2489+
@dataclass(frozen=True)
2490+
class C:
2491+
i: int
24802492

2481-
def test_inherit_from_nonfrozen_from_frozen(self):
2482-
@dataclass(frozen=True)
2483-
class C:
2484-
i: int
2493+
if intermediate_class:
2494+
class I(C): pass
2495+
else:
2496+
I = C
24852497

2486-
with self.assertRaisesRegex(TypeError,
2487-
'cannot inherit non-frozen dataclass from a frozen one'):
2488-
@dataclass
2489-
class D(C):
2490-
pass
2498+
with self.assertRaisesRegex(TypeError,
2499+
'cannot inherit non-frozen dataclass from a frozen one'):
2500+
@dataclass
2501+
class D(I):
2502+
pass
24912503

2492-
def test_inherit_from_frozen_from_nonfrozen(self):
2493-
@dataclass
2494-
class C:
2495-
i: int
2504+
def test_inherit_frozen_from_nonfrozen(self):
2505+
for intermediate_class in [True, False]:
2506+
with self.subTest(intermediate_class=intermediate_class):
2507+
@dataclass
2508+
class C:
2509+
i: int
24962510

2497-
with self.assertRaisesRegex(TypeError,
2498-
'cannot inherit frozen dataclass from a non-frozen one'):
2499-
@dataclass(frozen=True)
2500-
class D(C):
2501-
pass
2511+
if intermediate_class:
2512+
class I(C): pass
2513+
else:
2514+
I = C
2515+
2516+
with self.assertRaisesRegex(TypeError,
2517+
'cannot inherit frozen dataclass from a non-frozen one'):
2518+
@dataclass(frozen=True)
2519+
class D(I):
2520+
pass
25022521

25032522
def test_inherit_from_normal_class(self):
2504-
class C:
2505-
pass
2523+
for intermediate_class in [True, False]:
2524+
with self.subTest(intermediate_class=intermediate_class):
2525+
class C:
2526+
pass
2527+
2528+
if intermediate_class:
2529+
class I(C): pass
2530+
else:
2531+
I = C
2532+
2533+
@dataclass(frozen=True)
2534+
class D(I):
2535+
i: int
2536+
2537+
d = D(10)
2538+
with self.assertRaises(FrozenInstanceError):
2539+
d.i = 5
2540+
2541+
def test_non_frozen_normal_derived(self):
2542+
# See bpo-32953.
25062543

25072544
@dataclass(frozen=True)
2508-
class D(C):
2509-
i: int
2545+
class D:
2546+
x: int
2547+
y: int = 10
25102548

2511-
d = D(10)
2549+
class S(D):
2550+
pass
2551+
2552+
s = S(3)
2553+
self.assertEqual(s.x, 3)
2554+
self.assertEqual(s.y, 10)
2555+
s.cached = True
2556+
2557+
# But can't change the frozen attributes.
25122558
with self.assertRaises(FrozenInstanceError):
2513-
d.i = 5
2559+
s.x = 5
2560+
with self.assertRaises(FrozenInstanceError):
2561+
s.y = 5
2562+
self.assertEqual(s.x, 3)
2563+
self.assertEqual(s.y, 10)
2564+
self.assertEqual(s.cached, True)
25142565

25152566

25162567
if __name__ == '__main__':
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
If a non-dataclass inherits from a frozen dataclass, allow attributes to be
2+
added to the derived class. Only attributes from the frozen dataclass
3+
cannot be assigned to. Require all dataclasses in a hierarchy to be either
4+
all frozen or all non-frozen.

0 commit comments

Comments
 (0)