Skip to content

Commit e99b30d

Browse files
mammiquetomchristie
authored andcommitted
Fix rest_framework.filters.OrderingFilter doesn't pass context to ser… (encode#4543)
* Fix rest_framework.filters.OrderingFilter doesn't pass context to serializers encode#4541 * encode#4541 Additional fix for remove_invalid_fields()
1 parent 4ff9e96 commit e99b30d

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

rest_framework/filters.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def get_ordering(self, request, queryset, view):
252252
params = request.query_params.get(self.ordering_param)
253253
if params:
254254
fields = [param.strip() for param in params.split(',')]
255-
ordering = self.remove_invalid_fields(queryset, fields, view)
255+
ordering = self.remove_invalid_fields(queryset, fields, view, request)
256256
if ordering:
257257
return ordering
258258

@@ -265,7 +265,7 @@ def get_default_ordering(self, view):
265265
return (ordering,)
266266
return ordering
267267

268-
def get_default_valid_fields(self, queryset, view):
268+
def get_default_valid_fields(self, queryset, view, context={}):
269269
# If `ordering_fields` is not specified, then we determine a default
270270
# based on the serializer class, if one exists on the view.
271271
if hasattr(view, 'get_serializer_class'):
@@ -288,16 +288,16 @@ def get_default_valid_fields(self, queryset, view):
288288

289289
return [
290290
(field.source or field_name, field.label)
291-
for field_name, field in serializer_class().fields.items()
291+
for field_name, field in serializer_class(context=context).fields.items()
292292
if not getattr(field, 'write_only', False) and not field.source == '*'
293293
]
294294

295-
def get_valid_fields(self, queryset, view):
295+
def get_valid_fields(self, queryset, view, context={}):
296296
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
297297

298298
if valid_fields is None:
299299
# Default to allowing filtering on serializer fields
300-
return self.get_default_valid_fields(queryset, view)
300+
return self.get_default_valid_fields(queryset, view, context)
301301

302302
elif valid_fields == '__all__':
303303
# View explicitly allows filtering on any model field
@@ -316,8 +316,8 @@ def get_valid_fields(self, queryset, view):
316316

317317
return valid_fields
318318

319-
def remove_invalid_fields(self, queryset, fields, view):
320-
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view)]
319+
def remove_invalid_fields(self, queryset, fields, view, request):
320+
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
321321
return [term for term in fields if term.lstrip('-') in valid_fields]
322322

323323
def filter_queryset(self, request, queryset, view):
@@ -332,15 +332,16 @@ def get_template_context(self, request, queryset, view):
332332
current = self.get_ordering(request, queryset, view)
333333
current = None if current is None else current[0]
334334
options = []
335-
for key, label in self.get_valid_fields(queryset, view):
336-
options.append((key, '%s - %s' % (label, _('ascending'))))
337-
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
338-
return {
335+
context = {
339336
'request': request,
340337
'current': current,
341338
'param': self.ordering_param,
342-
'options': options,
343339
}
340+
for key, label in self.get_valid_fields(queryset, view, context):
341+
options.append((key, '%s - %s' % (label, _('ascending'))))
342+
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
343+
context['options'] = options
344+
return context
344345

345346
def to_html(self, request, queryset, view):
346347
template = loader.get_template(self.template)

0 commit comments

Comments
 (0)