@@ -252,7 +252,7 @@ def get_ordering(self, request, queryset, view):
252
252
params = request .query_params .get (self .ordering_param )
253
253
if params :
254
254
fields = [param .strip () for param in params .split (',' )]
255
- ordering = self .remove_invalid_fields (queryset , fields , view )
255
+ ordering = self .remove_invalid_fields (queryset , fields , view , request )
256
256
if ordering :
257
257
return ordering
258
258
@@ -265,7 +265,7 @@ def get_default_ordering(self, view):
265
265
return (ordering ,)
266
266
return ordering
267
267
268
- def get_default_valid_fields (self , queryset , view ):
268
+ def get_default_valid_fields (self , queryset , view , context = {} ):
269
269
# If `ordering_fields` is not specified, then we determine a default
270
270
# based on the serializer class, if one exists on the view.
271
271
if hasattr (view , 'get_serializer_class' ):
@@ -288,16 +288,16 @@ def get_default_valid_fields(self, queryset, view):
288
288
289
289
return [
290
290
(field .source or field_name , field .label )
291
- for field_name , field in serializer_class ().fields .items ()
291
+ for field_name , field in serializer_class (context = context ).fields .items ()
292
292
if not getattr (field , 'write_only' , False ) and not field .source == '*'
293
293
]
294
294
295
- def get_valid_fields (self , queryset , view ):
295
+ def get_valid_fields (self , queryset , view , context = {} ):
296
296
valid_fields = getattr (view , 'ordering_fields' , self .ordering_fields )
297
297
298
298
if valid_fields is None :
299
299
# Default to allowing filtering on serializer fields
300
- return self .get_default_valid_fields (queryset , view )
300
+ return self .get_default_valid_fields (queryset , view , context )
301
301
302
302
elif valid_fields == '__all__' :
303
303
# View explicitly allows filtering on any model field
@@ -316,8 +316,8 @@ def get_valid_fields(self, queryset, view):
316
316
317
317
return valid_fields
318
318
319
- def remove_invalid_fields (self , queryset , fields , view ):
320
- valid_fields = [item [0 ] for item in self .get_valid_fields (queryset , view )]
319
+ def remove_invalid_fields (self , queryset , fields , view , request ):
320
+ valid_fields = [item [0 ] for item in self .get_valid_fields (queryset , view , { 'request' : request } )]
321
321
return [term for term in fields if term .lstrip ('-' ) in valid_fields ]
322
322
323
323
def filter_queryset (self , request , queryset , view ):
@@ -332,15 +332,16 @@ def get_template_context(self, request, queryset, view):
332
332
current = self .get_ordering (request , queryset , view )
333
333
current = None if current is None else current [0 ]
334
334
options = []
335
- for key , label in self .get_valid_fields (queryset , view ):
336
- options .append ((key , '%s - %s' % (label , _ ('ascending' ))))
337
- options .append (('-' + key , '%s - %s' % (label , _ ('descending' ))))
338
- return {
335
+ context = {
339
336
'request' : request ,
340
337
'current' : current ,
341
338
'param' : self .ordering_param ,
342
- 'options' : options ,
343
339
}
340
+ for key , label in self .get_valid_fields (queryset , view , context ):
341
+ options .append ((key , '%s - %s' % (label , _ ('ascending' ))))
342
+ options .append (('-' + key , '%s - %s' % (label , _ ('descending' ))))
343
+ context ['options' ] = options
344
+ return context
344
345
345
346
def to_html (self , request , queryset , view ):
346
347
template = loader .get_template (self .template )
0 commit comments