diff --git a/rest_framework/utils/serializer_helpers.py b/rest_framework/utils/serializer_helpers.py index 4cd2ada314..54068f5fb0 100644 --- a/rest_framework/utils/serializer_helpers.py +++ b/rest_framework/utils/serializer_helpers.py @@ -1,3 +1,4 @@ +import sys from collections import OrderedDict from collections.abc import Mapping, MutableMapping @@ -28,6 +29,22 @@ def __reduce__(self): # but preserve the raw data. return (dict, (dict(self),)) + if sys.version_info >= (3, 9): + # These are basically copied from OrderedDict, with `serializer` added. + def __or__(self, other): + if not isinstance(other, dict): + return NotImplemented + new = self.__class__(self, serializer=self.serializer) + new.update(other) + return new + + def __ror__(self, other): + if not isinstance(other, dict): + return NotImplemented + new = self.__class__(other, serializer=self.serializer) + new.update(self) + return new + class ReturnList(list): """ diff --git a/tests/test_serializer.py b/tests/test_serializer.py index afefd70e1c..c4c29ba4ad 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -740,3 +740,25 @@ class TestSerializer(A, B): 'f4': serializers.CharField, 'f5': serializers.CharField, } + + +class Test8301Regression: + @pytest.mark.skipif( + sys.version_info < (3, 9), + reason="dictionary union operator requires Python 3.9 or higher", + ) + def test_ReturnDict_merging(self): + # Serializer.data returns ReturnDict, this is essentially a test for that. + + class TestSerializer(serializers.Serializer): + char = serializers.CharField() + + s = TestSerializer(data={'char': 'x'}) + assert s.is_valid() + assert s.data | {} == {'char': 'x'} + assert s.data | {'other': 'y'} == {'char': 'x', 'other': 'y'} + assert {} | s.data == {'char': 'x'} + assert {'other': 'y'} | s.data == {'char': 'x', 'other': 'y'} + + assert (s.data | {}).__class__ == s.data.__class__ + assert ({} | s.data).__class__ == s.data.__class__