diff --git a/CHANGELOG.md b/CHANGELOG.md index c5d5e074..c3a9c9f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,8 @@ any parts of the framework not mentioned in the documentation should generally b * Allow to define `select_related` per include using [select_for_includes](https://django-rest-framework-json-api.readthedocs.io/en/stable/usage.html#performance-improvements) * Reduce number of queries to calculate includes by using `select_related` when possible +* Use REST framework serializer functionality to extract includes. This adds support like using + dotted notations in source attribute in `ResourceRelatedField`. ### Fixed diff --git a/example/tests/unit/test_renderers.py b/example/tests/unit/test_renderers.py index 6126cff0..1452aac8 100644 --- a/example/tests/unit/test_renderers.py +++ b/example/tests/unit/test_renderers.py @@ -10,9 +10,11 @@ # serializers class RelatedModelSerializer(serializers.ModelSerializer): + blog = serializers.ReadOnlyField(source='entry.blog') + class Meta: model = Comment - fields = ('id',) + fields = ('id', 'blog') class DummyTestSerializer(serializers.ModelSerializer): @@ -137,3 +139,13 @@ class EmptyRelationshipViewSet(views.ReadOnlyModelViewSet): assert 'relationships' in result['data'] assert 'bio' in result['data']['relationships'] assert result['data']['relationships']['bio'] == {'data': None} + + +@pytest.mark.django_db +def test_extract_relation_instance(comment): + serializer = RelatedModelSerializer(instance=comment) + + got = JSONRenderer.extract_relation_instance( + field=serializer.fields['blog'], resource_instance=comment + ) + assert got == comment.entry.blog diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 0023399a..2d1430a7 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -8,6 +8,8 @@ from django.db.models import Manager from django.utils import encoding, six from rest_framework import relations, renderers +from rest_framework.fields import SkipField, get_attribute +from rest_framework.relations import PKOnlyObject from rest_framework.serializers import BaseSerializer, ListSerializer, Serializer from rest_framework.settings import api_settings @@ -297,34 +299,20 @@ def extract_relationships(cls, fields, resource, resource_instance): return utils._format_object(data) @classmethod - def extract_relation_instance(cls, field_name, field, resource_instance, serializer): + def extract_relation_instance(cls, field, resource_instance): """ Determines what instance represents given relation and extracts it. - Relation instance is determined by given field_name or source configured on - field. As fallback is a serializer method called with name of field's source. + Relation instance is determined exactly same way as it determined + in parent serializer """ - relation_instance = None - try: - relation_instance = getattr(resource_instance, field_name) - except AttributeError: - try: - # For ManyRelatedFields if `related_name` is not set - # we need to access `foo_set` from `source` - relation_instance = getattr(resource_instance, field.child_relation.source) - except AttributeError: - if hasattr(serializer, field.source): - serializer_method = getattr(serializer, field.source) - relation_instance = serializer_method(resource_instance) - else: - # case when source is a simple remap on resource_instance - try: - relation_instance = getattr(resource_instance, field.source) - except AttributeError: - pass - - return relation_instance + res = field.get_attribute(resource_instance) + if isinstance(res, PKOnlyObject): + return get_attribute(resource_instance, field.source_attrs) + return res + except SkipField: + return None @classmethod def extract_included(cls, fields, resource, resource_instance, included_resources, @@ -363,7 +351,7 @@ def extract_included(cls, fields, resource, resource_instance, included_resource continue relation_instance = cls.extract_relation_instance( - field_name, field, resource_instance, current_serializer + field, resource_instance ) if isinstance(relation_instance, Manager): relation_instance = relation_instance.all()