Skip to content

Add distinction between request and response serializers for OpenAPI #7424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
14 changes: 14 additions & 0 deletions docs/api-guide/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,20 @@ operationIds.
In order to work around this, you can override `get_operation_id_base()` to
provide a different base for name part of the ID.

#### `get_serializer()`

If the view has implemented `get_serializer()`, returns the result.

#### `get_request_serializer()`

By default returns `get_serializer()` but can be overridden to
differentiate between request and response objects.

#### `get_response_serializer()`

By default returns `get_serializer()` but can be overridden to
differentiate between request and response objects.

### `AutoSchema.__init__()` kwargs

`AutoSchema` provides a number of `__init__()` kwargs that can be used for
Expand Down
37 changes: 29 additions & 8 deletions rest_framework/schemas/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,22 @@ def get_components(self, path, method):
if method.lower() == 'delete':
return {}

serializer = self.get_serializer(path, method)
request_serializer = self.get_request_serializer(path, method)
response_serializer = self.get_response_serializer(path, method)

if not isinstance(serializer, serializers.Serializer):
return {}
components = {}

if isinstance(request_serializer, serializers.Serializer):
component_name = self.get_component_name(request_serializer)
content = self.map_serializer(request_serializer)
components.setdefault(component_name, content)

component_name = self.get_component_name(serializer)
if isinstance(response_serializer, serializers.Serializer):
component_name = self.get_component_name(response_serializer)
content = self.map_serializer(response_serializer)
components.setdefault(component_name, content)

content = self.map_serializer(serializer)
return {component_name: content}
return components

def _to_camel_case(self, snake_str):
components = snake_str.split('_')
Expand Down Expand Up @@ -615,6 +622,20 @@ def get_serializer(self, path, method):
.format(view.__class__.__name__, method, path))
return None

def get_request_serializer(self, path, method):
"""
Override this method if your view uses a different serializer for
handling request body.
"""
return self.get_serializer(path, method)

def get_response_serializer(self, path, method):
"""
Override this method if your view uses a different serializer for
populating response data.
"""
return self.get_serializer(path, method)

def _get_reference(self, serializer):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}

Expand All @@ -624,7 +645,7 @@ def get_request_body(self, path, method):

self.request_media_types = self.map_parsers(path, method)

serializer = self.get_serializer(path, method)
serializer = self.get_request_serializer(path, method)

if not isinstance(serializer, serializers.Serializer):
item_schema = {}
Expand All @@ -648,7 +669,7 @@ def get_responses(self, path, method):

self.response_media_types = self.map_renderers(path, method)

serializer = self.get_serializer(path, method)
serializer = self.get_response_serializer(path, method)

if not isinstance(serializer, serializers.Serializer):
item_schema = {}
Expand Down
85 changes: 85 additions & 0 deletions tests/schemas/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,91 @@ def get_operation_id_base(self, path, method, action):
operationId = inspector.get_operation_id(path, method)
assert operationId == 'listItem'

def test_different_request_response_objects(self):
class RequestSerializer(serializers.Serializer):
text = serializers.CharField()

class ResponseSerializer(serializers.Serializer):
text = serializers.BooleanField()

class CustomSchema(AutoSchema):
def get_request_serializer(self, path, method):
return RequestSerializer()

def get_response_serializer(self, path, method):
return ResponseSerializer()

path = '/'
method = 'POST'
view = create_view(
views.ExampleGenericAPIView,
method,
create_request(path),
)
inspector = CustomSchema()
inspector.view = view

components = inspector.get_components(path, method)
assert components == {
'Request': {
'properties': {
'text': {
'type': 'string'
}
},
'required': ['text'],
'type': 'object'
},
'Response': {
'properties': {
'text': {
'type': 'boolean'
}
},
'required': ['text'],
'type': 'object'
}
}

operation = inspector.get_operation(path, method)
assert operation == {
'operationId': 'createExample',
'description': '',
'parameters': [],
'requestBody': {
'content': {
'application/json': {
'schema': {
'$ref': '#/components/schemas/Request'
}
},
'application/x-www-form-urlencoded': {
'schema': {
'$ref': '#/components/schemas/Request'
}
},
'multipart/form-data': {
'schema': {
'$ref': '#/components/schemas/Request'
}
}
}
},
'responses': {
'201': {
'content': {
'application/json': {
'schema': {
'$ref': '#/components/schemas/Response'
}
}
},
'description': ''
}
},
'tags': ['']
}

def test_repeat_operation_ids(self):
router = routers.SimpleRouter()
router.register('account', views.ExampleGenericViewSet, basename="account")
Expand Down