diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 1104aa29c1..47a4923a19 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -68,6 +68,13 @@ def update(self, request, *args, **kwargs): serializer = self.get_serializer(instance, data=request.data, partial=partial) serializer.is_valid(raise_exception=True) self.perform_update(serializer) + + if getattr(instance, '_prefetched_objects_cache', None): + # If 'prefetch_related' has been applied to a queryset, we need to + # refresh the instance from the database. + instance = self.get_object() + serializer = self.get_serializer(instance) + return Response(serializer.data) def perform_update(self, serializer): diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py new file mode 100644 index 0000000000..fc697adc1a --- /dev/null +++ b/tests/test_prefetch_related.py @@ -0,0 +1,41 @@ +from django.contrib.auth.models import Group, User +from django.test import TestCase + +from rest_framework import generics, serializers +from rest_framework.test import APIRequestFactory + +factory = APIRequestFactory() + + +class UserSerializer(serializers.ModelSerializer): + class Meta: + model = User + fields = ('id', 'username', 'email', 'groups') + + +class UserUpdate(generics.UpdateAPIView): + queryset = User.objects.all().prefetch_related('groups') + serializer_class = UserSerializer + + +class TestPrefetchRelatedUpdates(TestCase): + def setUp(self): + self.user = User.objects.create(username='tom', email='tom@example.com') + self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] + self.user.groups = self.groups + self.user.save() + + def test_prefetch_related_updates(self): + view = UserUpdate.as_view() + pk = self.user.pk + groups_pk = self.groups[0].pk + request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') + response = view(request, pk=pk) + assert User.objects.get(pk=pk).groups.count() == 1 + expected = { + 'id': pk, + 'username': 'new', + 'groups': [1], + 'email': 'tom@example.com' + } + assert response.data == expected