diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 82b7eb374f..a132b031b8 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -7,6 +7,7 @@ from django.utils.datastructures import SortedDict from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type +from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ from rest_framework import ISO_8601 from rest_framework.compat import ( @@ -233,12 +234,42 @@ def get_value(self, dictionary): return self.default_empty_html if (ret == '') else ret return dictionary.get(self.field_name, empty) + @cached_property + def single_source_attr(self): + return len(self.source_attrs) == 1 + + simple_getattr_map = {} + + def init_simple_getattr(self, instance): + assert len(self.source_attrs) == 1 + try: + val = getattr(instance, self.source_attrs[0]) + if is_simple_callable(val): + self.simple_getattr_map[instance.__class__] = False + else: + self.simple_getattr_map[instance.__class__] = True + except Exception: + self.simple_getattr_map[instance.__class__] = False + def get_attribute(self, instance): """ Given the *outgoing* object instance, return the primitive value that should be used for this field. """ - return get_attribute(instance, self.source_attrs) + simple_getattr = False + if self.single_source_attr: + try: + simple_getattr = self.simple_getattr_map[instance.__class__] + except KeyError: + self.init_simple_getattr(instance) + simple_getattr = self.simple_getattr_map[instance.__class__] + if simple_getattr: + try: + return getattr(instance, self.source_attrs[0]) + except ObjectDoesNotExist: + return None + else: + return get_attribute(instance, self.source_attrs) def get_default(self): """ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 6aab020efd..d7b7a8b624 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -163,6 +163,12 @@ def __init__(self, *args, **kwargs): super(ReturnDict, self).__init__(*args, **kwargs) +class NonSortedReturnDict(dict): + def __init__(self, *args, **kwargs): + self.serialier = kwargs.pop('serializer') + super(NonSortedReturnDict, self).__init__(*args, **kwargs) + + class ReturnList(list): """ Return object from `serialier.data` for the `SerializerList` class. @@ -277,11 +283,13 @@ def _get_declared_fields(cls, bases, attrs): def __new__(cls, name, bases, attrs): attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) + attrs['dict_class'] = ReturnDict if attrs.get('sorted_output', True) else NonSortedReturnDict return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) @six.add_metaclass(SerializerMetaclass) class Serializer(BaseSerializer): + def __init__(self, *args, **kwargs): super(Serializer, self).__init__(*args, **kwargs) @@ -297,13 +305,13 @@ def _get_base_fields(self): def get_initial(self): if self._initial_data is not None: - return ReturnDict([ + return self.dict_class([ (field_name, field.get_value(self._initial_data)) for field_name, field in self.fields.items() if field.get_value(self._initial_data) is not empty ], serializer=self) - return ReturnDict([ + return self.dict_class([ (field.field_name, field.get_initial()) for field in self.fields.values() if not field.write_only @@ -368,7 +376,7 @@ def to_internal_value(self, data): Dict of native values <- Dict of primitive datatypes. """ ret = {} - errors = ReturnDict(serializer=self) + errors = self.dict_class(serializer=self) fields = [ field for field in self.fields.values() if (not field.read_only) or (field.default is not empty) @@ -393,20 +401,33 @@ def to_internal_value(self, data): return ret + @cached_property + def repr_fields(self): + return [f for f in self.fields.values() if not f.write_only] + + transform_map = {} + + def init_transform_map(self, field): + transform = getattr(self, 'transform_' + field.field_name, None) + self.transform_map[field.field_name] = transform + return transform + def to_representation(self, instance): """ Object instance -> Dict of primitive datatypes. """ - ret = ReturnDict(serializer=self) - fields = [field for field in self.fields.values() if not field.write_only] + ret = self.dict_class(serializer=self) - for field in fields: + for field in self.repr_fields: attribute = field.get_attribute(instance) if attribute is None: value = None else: value = field.to_representation(attribute) - transform_method = getattr(self, 'transform_' + field.field_name, None) + try: + transform_method = self.transform_map[field.field_name] + except KeyError: + transform_method = self.init_transform_map(field) if transform_method is not None: value = transform_method(value)