Skip to content

Allow custom serializer class for request and response #8347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions docs/api-guide/viewsets.md
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,8 @@ As with `ModelViewSet`, you'll normally need to provide at least the `queryset`

Again, as with `ModelViewSet`, you can use any of the standard attributes and method overrides available to `GenericAPIView`.

---

# Custom ViewSet base classes

You may need to provide custom `ViewSet` classes that do not have the full set of `ModelViewSet` actions, or that customize the behavior in some other way.
Expand All @@ -321,3 +323,44 @@ By creating your own base `ViewSet` classes, you can provide common behavior tha

[cite]: https://guides.rubyonrails.org/action_controller_overview.html
[routers]: routers.md

---

# Custom serializer for request and response

It is possible to define at the view level (or for each custom method via the @action decorator) a custom serialization class for each request and response.
To do this you need to define `request_serializer_response` and `response_serializer_response` and call them via `get_request_serializer` and `get_response_serializer`.


class UserViewSet(viewsets.ModelViewSet):
"""
A viewset that provides the standard actions
"""
queryset = User.objects.all()
serializer_class = UserSerializer

@action(
detail=True,
methods=['post'],
request_serializer_class=PasswordSerializer
)
def set_password(self, request, pk=None):
user = self.get_object()
serializer = self.get_request_serializer(data=request.data)
if serializer.is_valid():
user.set_password(serializer.validated_data['password'])
user.save()
return Response({'status': 'password set'})
else:
return Response(serializer.errors,
status=status.HTTP_400_BAD_REQUEST)

@action(
detail=True,
methods=['get'],
request_response_class=ExtendedUserSerializer
)
def complete_profile(self, request, pk=None):
user = self.get_object()
response_serializer = self.get_response_serializer(user)
return Response(response_serializer.data)
31 changes: 31 additions & 0 deletions rest_framework/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class GenericAPIView(views.APIView):
# for all subsequent requests.
queryset = None
serializer_class = None
request_serializer_class = None
response_serializer_class = None

# If you want to use object lookups other than pk, set 'lookup_field'.
# For more complex lookup requirements override `get_object()`.
Expand Down Expand Up @@ -109,6 +111,23 @@ def get_serializer(self, *args, **kwargs):
kwargs.setdefault('context', self.get_serializer_context())
return serializer_class(*args, **kwargs)

def get_request_serializer(self, *args, **kwargs):
"""
Return the serializer instance that should be used for validating and
deserializing input.
"""
serializer_class = self.get_request_serializer_class()
kwargs.setdefault('context', self.get_serializer_context())
return serializer_class(*args, **kwargs)

def get_response_serializer(self, *args, **kwargs):
"""
Return the serializer instance that should be used for serializing output.
"""
serializer_class = self.get_response_serializer_class()
kwargs.setdefault('context', self.get_serializer_context())
return serializer_class(*args, **kwargs)

def get_serializer_class(self):
"""
Return the class to use for the serializer.
Expand All @@ -127,6 +146,18 @@ def get_serializer_class(self):

return self.serializer_class

def get_request_serializer_class(self):
"""
Return the class to use as input serializer.
"""
return self.request_serializer_class or self.get_serializer_class()

def get_response_serializer_class(self):
"""
Returns the class to use as output serializer.
"""
return self.response_serializer_class or self.get_serializer_class()

def get_serializer_context(self):
"""
Extra context provided to the serializer class.
Expand Down
30 changes: 28 additions & 2 deletions rest_framework/schemas/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,14 +627,40 @@ def get_request_serializer(self, path, method):
Override this method if your view uses a different serializer for
handling request body.
"""
return self.get_serializer(path, method)
view = self.view

if not hasattr(view, "get_request_serializer"):
return self.get_serializer(path, method)

try:
return view.get_request_serializer()
except exceptions.APIException:
warnings.warn(
"{}.get_request_serializer() raised an exception during "
"schema generation. Serializer fields will not be "
"generated for {} {}.".format(view.__class__.__name__, method, path)
)
return None

def get_response_serializer(self, path, method):
"""
Override this method if your view uses a different serializer for
populating response data.
"""
return self.get_serializer(path, method)
view = self.view

if not hasattr(view, "get_response_serializer"):
return self.get_serializer(path, method)

try:
return view.get_response_serializer()
except exceptions.APIException:
warnings.warn(
"{}.get_response_serializer() raised an exception during "
"schema generation. Serializer fields will not be "
"generated for {} {}.".format(view.__class__.__name__, method, path)
)
return None

def _get_reference(self, serializer):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}
Expand Down
42 changes: 42 additions & 0 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,3 +692,45 @@ def list(self, request):
serializer = response.serializer

assert serializer.context is context

def test_get_request_serializer_class(self):
class View(generics.GenericAPIView):
request_serializer_class = BasicSerializer

view = View()
assert view.get_request_serializer_class() == BasicSerializer

def test_get_response_serializer_class(self):
class TestResponseSerializerView(generics.GenericAPIView):
response_serializer_class = BasicSerializer

view = TestResponseSerializerView()
assert view.get_response_serializer_class() == BasicSerializer

def test_get_request_serializer(self):
class View(generics.ListAPIView):
request_serializer_class = BasicSerializer

def list(self, request):
response = Response()
response.serializer = self.get_request_serializer()
return response

view = View.as_view()
request = factory.get('/')
response = view(request)
assert isinstance(response.serializer, BasicSerializer)

def test_get_response_serializer(self):
class View(generics.ListAPIView):
response_serializer_class = BasicSerializer

def list(self, request):
response = Response()
response.serializer = self.get_response_serializer()
return response

view = View.as_view()
request = factory.get('/')
response = view(request)
assert isinstance(response.serializer, BasicSerializer)