From ff075ef82790538fb785f4b675607dedd3a2a78f Mon Sep 17 00:00:00 2001 From: khamaileon Date: Mon, 18 Apr 2022 09:25:55 +0200 Subject: [PATCH] Allow custom serializer class for request and response --- docs/api-guide/viewsets.md | 43 +++++++++++++++++++++++++++++++ rest_framework/generics.py | 31 ++++++++++++++++++++++ rest_framework/schemas/openapi.py | 30 +++++++++++++++++++-- tests/test_generics.py | 42 ++++++++++++++++++++++++++++++ 4 files changed, 144 insertions(+), 2 deletions(-) diff --git a/docs/api-guide/viewsets.md b/docs/api-guide/viewsets.md index 4179725078..d8c410a493 100644 --- a/docs/api-guide/viewsets.md +++ b/docs/api-guide/viewsets.md @@ -295,6 +295,8 @@ As with `ModelViewSet`, you'll normally need to provide at least the `queryset` Again, as with `ModelViewSet`, you can use any of the standard attributes and method overrides available to `GenericAPIView`. +--- + # Custom ViewSet base classes You may need to provide custom `ViewSet` classes that do not have the full set of `ModelViewSet` actions, or that customize the behavior in some other way. @@ -321,3 +323,44 @@ By creating your own base `ViewSet` classes, you can provide common behavior tha [cite]: https://guides.rubyonrails.org/action_controller_overview.html [routers]: routers.md + +--- + +# Custom serializer for request and response + +It is possible to define at the view level (or for each custom method via the @action decorator) a custom serialization class for each request and response. +To do this you need to define `request_serializer_response` and `response_serializer_response` and call them via `get_request_serializer` and `get_response_serializer`. + + + class UserViewSet(viewsets.ModelViewSet): + """ + A viewset that provides the standard actions + """ + queryset = User.objects.all() + serializer_class = UserSerializer + + @action( + detail=True, + methods=['post'], + request_serializer_class=PasswordSerializer + ) + def set_password(self, request, pk=None): + user = self.get_object() + serializer = self.get_request_serializer(data=request.data) + if serializer.is_valid(): + user.set_password(serializer.validated_data['password']) + user.save() + return Response({'status': 'password set'}) + else: + return Response(serializer.errors, + status=status.HTTP_400_BAD_REQUEST) + + @action( + detail=True, + methods=['get'], + request_response_class=ExtendedUserSerializer + ) + def complete_profile(self, request, pk=None): + user = self.get_object() + response_serializer = self.get_response_serializer(user) + return Response(response_serializer.data) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 55cfafda44..3de280aa9e 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -33,6 +33,8 @@ class GenericAPIView(views.APIView): # for all subsequent requests. queryset = None serializer_class = None + request_serializer_class = None + response_serializer_class = None # If you want to use object lookups other than pk, set 'lookup_field'. # For more complex lookup requirements override `get_object()`. @@ -109,6 +111,23 @@ def get_serializer(self, *args, **kwargs): kwargs.setdefault('context', self.get_serializer_context()) return serializer_class(*args, **kwargs) + def get_request_serializer(self, *args, **kwargs): + """ + Return the serializer instance that should be used for validating and + deserializing input. + """ + serializer_class = self.get_request_serializer_class() + kwargs.setdefault('context', self.get_serializer_context()) + return serializer_class(*args, **kwargs) + + def get_response_serializer(self, *args, **kwargs): + """ + Return the serializer instance that should be used for serializing output. + """ + serializer_class = self.get_response_serializer_class() + kwargs.setdefault('context', self.get_serializer_context()) + return serializer_class(*args, **kwargs) + def get_serializer_class(self): """ Return the class to use for the serializer. @@ -127,6 +146,18 @@ def get_serializer_class(self): return self.serializer_class + def get_request_serializer_class(self): + """ + Return the class to use as input serializer. + """ + return self.request_serializer_class or self.get_serializer_class() + + def get_response_serializer_class(self): + """ + Returns the class to use as output serializer. + """ + return self.response_serializer_class or self.get_serializer_class() + def get_serializer_context(self): """ Extra context provided to the serializer class. diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 5e9d59f8bf..27a2e91326 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -627,14 +627,40 @@ def get_request_serializer(self, path, method): Override this method if your view uses a different serializer for handling request body. """ - return self.get_serializer(path, method) + view = self.view + + if not hasattr(view, "get_request_serializer"): + return self.get_serializer(path, method) + + try: + return view.get_request_serializer() + except exceptions.APIException: + warnings.warn( + "{}.get_request_serializer() raised an exception during " + "schema generation. Serializer fields will not be " + "generated for {} {}.".format(view.__class__.__name__, method, path) + ) + return None def get_response_serializer(self, path, method): """ Override this method if your view uses a different serializer for populating response data. """ - return self.get_serializer(path, method) + view = self.view + + if not hasattr(view, "get_response_serializer"): + return self.get_serializer(path, method) + + try: + return view.get_response_serializer() + except exceptions.APIException: + warnings.warn( + "{}.get_response_serializer() raised an exception during " + "schema generation. Serializer fields will not be " + "generated for {} {}.".format(view.__class__.__name__, method, path) + ) + return None def _get_reference(self, serializer): return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))} diff --git a/tests/test_generics.py b/tests/test_generics.py index 2907d27733..27ab6eb4d0 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -692,3 +692,45 @@ def list(self, request): serializer = response.serializer assert serializer.context is context + + def test_get_request_serializer_class(self): + class View(generics.GenericAPIView): + request_serializer_class = BasicSerializer + + view = View() + assert view.get_request_serializer_class() == BasicSerializer + + def test_get_response_serializer_class(self): + class TestResponseSerializerView(generics.GenericAPIView): + response_serializer_class = BasicSerializer + + view = TestResponseSerializerView() + assert view.get_response_serializer_class() == BasicSerializer + + def test_get_request_serializer(self): + class View(generics.ListAPIView): + request_serializer_class = BasicSerializer + + def list(self, request): + response = Response() + response.serializer = self.get_request_serializer() + return response + + view = View.as_view() + request = factory.get('/') + response = view(request) + assert isinstance(response.serializer, BasicSerializer) + + def test_get_response_serializer(self): + class View(generics.ListAPIView): + response_serializer_class = BasicSerializer + + def list(self, request): + response = Response() + response.serializer = self.get_response_serializer() + return response + + view = View.as_view() + request = factory.get('/') + response = view(request) + assert isinstance(response.serializer, BasicSerializer)