Skip to content

Commit ff075ef

Browse files
author
khamaileon
committed
Allow custom serializer class for request and response
1 parent 86673a3 commit ff075ef

File tree

4 files changed

+144
-2
lines changed

4 files changed

+144
-2
lines changed

docs/api-guide/viewsets.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,8 @@ As with `ModelViewSet`, you'll normally need to provide at least the `queryset`
295295

296296
Again, as with `ModelViewSet`, you can use any of the standard attributes and method overrides available to `GenericAPIView`.
297297

298+
---
299+
298300
# Custom ViewSet base classes
299301

300302
You may need to provide custom `ViewSet` classes that do not have the full set of `ModelViewSet` actions, or that customize the behavior in some other way.
@@ -321,3 +323,44 @@ By creating your own base `ViewSet` classes, you can provide common behavior tha
321323

322324
[cite]: https://guides.rubyonrails.org/action_controller_overview.html
323325
[routers]: routers.md
326+
327+
---
328+
329+
# Custom serializer for request and response
330+
331+
It is possible to define at the view level (or for each custom method via the @action decorator) a custom serialization class for each request and response.
332+
To do this you need to define `request_serializer_response` and `response_serializer_response` and call them via `get_request_serializer` and `get_response_serializer`.
333+
334+
335+
class UserViewSet(viewsets.ModelViewSet):
336+
"""
337+
A viewset that provides the standard actions
338+
"""
339+
queryset = User.objects.all()
340+
serializer_class = UserSerializer
341+
342+
@action(
343+
detail=True,
344+
methods=['post'],
345+
request_serializer_class=PasswordSerializer
346+
)
347+
def set_password(self, request, pk=None):
348+
user = self.get_object()
349+
serializer = self.get_request_serializer(data=request.data)
350+
if serializer.is_valid():
351+
user.set_password(serializer.validated_data['password'])
352+
user.save()
353+
return Response({'status': 'password set'})
354+
else:
355+
return Response(serializer.errors,
356+
status=status.HTTP_400_BAD_REQUEST)
357+
358+
@action(
359+
detail=True,
360+
methods=['get'],
361+
request_response_class=ExtendedUserSerializer
362+
)
363+
def complete_profile(self, request, pk=None):
364+
user = self.get_object()
365+
response_serializer = self.get_response_serializer(user)
366+
return Response(response_serializer.data)

rest_framework/generics.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class GenericAPIView(views.APIView):
3333
# for all subsequent requests.
3434
queryset = None
3535
serializer_class = None
36+
request_serializer_class = None
37+
response_serializer_class = None
3638

3739
# If you want to use object lookups other than pk, set 'lookup_field'.
3840
# For more complex lookup requirements override `get_object()`.
@@ -109,6 +111,23 @@ def get_serializer(self, *args, **kwargs):
109111
kwargs.setdefault('context', self.get_serializer_context())
110112
return serializer_class(*args, **kwargs)
111113

114+
def get_request_serializer(self, *args, **kwargs):
115+
"""
116+
Return the serializer instance that should be used for validating and
117+
deserializing input.
118+
"""
119+
serializer_class = self.get_request_serializer_class()
120+
kwargs.setdefault('context', self.get_serializer_context())
121+
return serializer_class(*args, **kwargs)
122+
123+
def get_response_serializer(self, *args, **kwargs):
124+
"""
125+
Return the serializer instance that should be used for serializing output.
126+
"""
127+
serializer_class = self.get_response_serializer_class()
128+
kwargs.setdefault('context', self.get_serializer_context())
129+
return serializer_class(*args, **kwargs)
130+
112131
def get_serializer_class(self):
113132
"""
114133
Return the class to use for the serializer.
@@ -127,6 +146,18 @@ def get_serializer_class(self):
127146

128147
return self.serializer_class
129148

149+
def get_request_serializer_class(self):
150+
"""
151+
Return the class to use as input serializer.
152+
"""
153+
return self.request_serializer_class or self.get_serializer_class()
154+
155+
def get_response_serializer_class(self):
156+
"""
157+
Returns the class to use as output serializer.
158+
"""
159+
return self.response_serializer_class or self.get_serializer_class()
160+
130161
def get_serializer_context(self):
131162
"""
132163
Extra context provided to the serializer class.

rest_framework/schemas/openapi.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -627,14 +627,40 @@ def get_request_serializer(self, path, method):
627627
Override this method if your view uses a different serializer for
628628
handling request body.
629629
"""
630-
return self.get_serializer(path, method)
630+
view = self.view
631+
632+
if not hasattr(view, "get_request_serializer"):
633+
return self.get_serializer(path, method)
634+
635+
try:
636+
return view.get_request_serializer()
637+
except exceptions.APIException:
638+
warnings.warn(
639+
"{}.get_request_serializer() raised an exception during "
640+
"schema generation. Serializer fields will not be "
641+
"generated for {} {}.".format(view.__class__.__name__, method, path)
642+
)
643+
return None
631644

632645
def get_response_serializer(self, path, method):
633646
"""
634647
Override this method if your view uses a different serializer for
635648
populating response data.
636649
"""
637-
return self.get_serializer(path, method)
650+
view = self.view
651+
652+
if not hasattr(view, "get_response_serializer"):
653+
return self.get_serializer(path, method)
654+
655+
try:
656+
return view.get_response_serializer()
657+
except exceptions.APIException:
658+
warnings.warn(
659+
"{}.get_response_serializer() raised an exception during "
660+
"schema generation. Serializer fields will not be "
661+
"generated for {} {}.".format(view.__class__.__name__, method, path)
662+
)
663+
return None
638664

639665
def _get_reference(self, serializer):
640666
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}

tests/test_generics.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,3 +692,45 @@ def list(self, request):
692692
serializer = response.serializer
693693

694694
assert serializer.context is context
695+
696+
def test_get_request_serializer_class(self):
697+
class View(generics.GenericAPIView):
698+
request_serializer_class = BasicSerializer
699+
700+
view = View()
701+
assert view.get_request_serializer_class() == BasicSerializer
702+
703+
def test_get_response_serializer_class(self):
704+
class TestResponseSerializerView(generics.GenericAPIView):
705+
response_serializer_class = BasicSerializer
706+
707+
view = TestResponseSerializerView()
708+
assert view.get_response_serializer_class() == BasicSerializer
709+
710+
def test_get_request_serializer(self):
711+
class View(generics.ListAPIView):
712+
request_serializer_class = BasicSerializer
713+
714+
def list(self, request):
715+
response = Response()
716+
response.serializer = self.get_request_serializer()
717+
return response
718+
719+
view = View.as_view()
720+
request = factory.get('/')
721+
response = view(request)
722+
assert isinstance(response.serializer, BasicSerializer)
723+
724+
def test_get_response_serializer(self):
725+
class View(generics.ListAPIView):
726+
response_serializer_class = BasicSerializer
727+
728+
def list(self, request):
729+
response = Response()
730+
response.serializer = self.get_response_serializer()
731+
return response
732+
733+
view = View.as_view()
734+
request = factory.get('/')
735+
response = view(request)
736+
assert isinstance(response.serializer, BasicSerializer)

0 commit comments

Comments
 (0)