diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index b4832b3690..acf2ecb932 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -375,6 +375,20 @@ operationIds. In order to work around this, you can override `get_operation_id_base()` to provide a different base for name part of the ID. +#### `get_serializer()` + +If the view has implemented `get_serializer()`, returns the result. + +#### `get_request_serializer()` + +By default returns `get_serializer()` but can be overridden to +differentiate between request and response objects. + +#### `get_response_serializer()` + +By default returns `get_serializer()` but can be overridden to +differentiate between request and response objects. + ### `AutoSchema.__init__()` kwargs `AutoSchema` provides a number of `__init__()` kwargs that can be used for diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index 4ecb7a65f1..5e9d59f8bf 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -192,15 +192,22 @@ def get_components(self, path, method): if method.lower() == 'delete': return {} - serializer = self.get_serializer(path, method) + request_serializer = self.get_request_serializer(path, method) + response_serializer = self.get_response_serializer(path, method) - if not isinstance(serializer, serializers.Serializer): - return {} + components = {} + + if isinstance(request_serializer, serializers.Serializer): + component_name = self.get_component_name(request_serializer) + content = self.map_serializer(request_serializer) + components.setdefault(component_name, content) - component_name = self.get_component_name(serializer) + if isinstance(response_serializer, serializers.Serializer): + component_name = self.get_component_name(response_serializer) + content = self.map_serializer(response_serializer) + components.setdefault(component_name, content) - content = self.map_serializer(serializer) - return {component_name: content} + return components def _to_camel_case(self, snake_str): components = snake_str.split('_') @@ -615,6 +622,20 @@ def get_serializer(self, path, method): .format(view.__class__.__name__, method, path)) return None + def get_request_serializer(self, path, method): + """ + Override this method if your view uses a different serializer for + handling request body. + """ + return self.get_serializer(path, method) + + def get_response_serializer(self, path, method): + """ + Override this method if your view uses a different serializer for + populating response data. + """ + return self.get_serializer(path, method) + def _get_reference(self, serializer): return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))} @@ -624,7 +645,7 @@ def get_request_body(self, path, method): self.request_media_types = self.map_parsers(path, method) - serializer = self.get_serializer(path, method) + serializer = self.get_request_serializer(path, method) if not isinstance(serializer, serializers.Serializer): item_schema = {} @@ -648,7 +669,7 @@ def get_responses(self, path, method): self.response_media_types = self.map_renderers(path, method) - serializer = self.get_serializer(path, method) + serializer = self.get_response_serializer(path, method) if not isinstance(serializer, serializers.Serializer): item_schema = {} diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 542c377b15..f7f4ccec37 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -711,6 +711,91 @@ def get_operation_id_base(self, path, method, action): operationId = inspector.get_operation_id(path, method) assert operationId == 'listItem' + def test_different_request_response_objects(self): + class RequestSerializer(serializers.Serializer): + text = serializers.CharField() + + class ResponseSerializer(serializers.Serializer): + text = serializers.BooleanField() + + class CustomSchema(AutoSchema): + def get_request_serializer(self, path, method): + return RequestSerializer() + + def get_response_serializer(self, path, method): + return ResponseSerializer() + + path = '/' + method = 'POST' + view = create_view( + views.ExampleGenericAPIView, + method, + create_request(path), + ) + inspector = CustomSchema() + inspector.view = view + + components = inspector.get_components(path, method) + assert components == { + 'Request': { + 'properties': { + 'text': { + 'type': 'string' + } + }, + 'required': ['text'], + 'type': 'object' + }, + 'Response': { + 'properties': { + 'text': { + 'type': 'boolean' + } + }, + 'required': ['text'], + 'type': 'object' + } + } + + operation = inspector.get_operation(path, method) + assert operation == { + 'operationId': 'createExample', + 'description': '', + 'parameters': [], + 'requestBody': { + 'content': { + 'application/json': { + 'schema': { + '$ref': '#/components/schemas/Request' + } + }, + 'application/x-www-form-urlencoded': { + 'schema': { + '$ref': '#/components/schemas/Request' + } + }, + 'multipart/form-data': { + 'schema': { + '$ref': '#/components/schemas/Request' + } + } + } + }, + 'responses': { + '201': { + 'content': { + 'application/json': { + 'schema': { + '$ref': '#/components/schemas/Response' + } + } + }, + 'description': '' + } + }, + 'tags': [''] + } + def test_repeat_operation_ids(self): router = routers.SimpleRouter() router.register('account', views.ExampleGenericViewSet, basename="account")