Skip to content

Commit b78d100

Browse files
orftimgraham
authored andcommitted
Fixed #27849 -- Added filtering support to aggregates.
1 parent 489421b commit b78d100

File tree

13 files changed

+290
-55
lines changed

13 files changed

+290
-55
lines changed

django/contrib/postgres/aggregates/statistics.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99

1010
class StatAggregate(Aggregate):
11-
def __init__(self, y, x, output_field=FloatField()):
11+
def __init__(self, y, x, output_field=FloatField(), filter=None):
1212
if not x or not y:
1313
raise ValueError('Both y and x must be provided.')
14-
super().__init__(y, x, output_field=output_field)
14+
super().__init__(y, x, output_field=output_field, filter=filter)
1515

1616
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
1717
return super().resolve_expression(query, allow_joins, reuse, summarize)
@@ -22,9 +22,9 @@ class Corr(StatAggregate):
2222

2323

2424
class CovarPop(StatAggregate):
25-
def __init__(self, y, x, sample=False):
25+
def __init__(self, y, x, sample=False, filter=None):
2626
self.function = 'COVAR_SAMP' if sample else 'COVAR_POP'
27-
super().__init__(y, x)
27+
super().__init__(y, x, filter=filter)
2828

2929

3030
class RegrAvgX(StatAggregate):
@@ -38,8 +38,8 @@ class RegrAvgY(StatAggregate):
3838
class RegrCount(StatAggregate):
3939
function = 'REGR_COUNT'
4040

41-
def __init__(self, y, x):
42-
super().__init__(y=y, x=x, output_field=IntegerField())
41+
def __init__(self, y, x, filter=None):
42+
super().__init__(y=y, x=x, output_field=IntegerField(), filter=filter)
4343

4444
def convert_value(self, value, expression, connection):
4545
if value is None:

django/db/backends/base/features.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,10 @@ class BaseDatabaseFeatures:
229229
supports_select_difference = True
230230
supports_slicing_ordering_in_compound = False
231231

232+
# Does the database support SQL 2003 FILTER (WHERE ...) in aggregate
233+
# expressions?
234+
supports_aggregate_filter_clause = False
235+
232236
# Does the backend support indexing a TextField?
233237
supports_index_on_text_field = True
234238

django/db/backends/postgresql/features.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
5050
END;
5151
$$ LANGUAGE plpgsql;"""
5252

53+
@cached_property
54+
def supports_aggregate_filter_clause(self):
55+
return self.connection.pg_version >= 90400
56+
5357
@cached_property
5458
def has_select_for_update_skip_locked(self):
5559
return self.connection.pg_version >= 90500

django/db/models/aggregates.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
Classes to represent the definitions of aggregate functions.
33
"""
44
from django.core.exceptions import FieldError
5-
from django.db.models.expressions import Func, Star
5+
from django.db.models.expressions import Case, Func, Star, When
66
from django.db.models.fields import DecimalField, FloatField, IntegerField
7+
from django.db.models.query_utils import Q
78

89
__all__ = [
910
'Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance',
@@ -13,12 +14,36 @@
1314
class Aggregate(Func):
1415
contains_aggregate = True
1516
name = None
17+
filter_template = '%s FILTER (WHERE %%(filter)s)'
18+
19+
def __init__(self, *args, filter=None, **kwargs):
20+
self.filter = filter
21+
super().__init__(*args, **kwargs)
22+
23+
def get_source_fields(self):
24+
# Don't return the filter expression since it's not a source field.
25+
return [e._output_field_or_none for e in super().get_source_expressions()]
26+
27+
def get_source_expressions(self):
28+
source_expressions = super().get_source_expressions()
29+
if self.filter:
30+
source_expressions += [self.filter]
31+
return source_expressions
32+
33+
def set_source_expressions(self, exprs):
34+
if self.filter:
35+
self.filter = exprs.pop()
36+
return super().set_source_expressions(exprs)
1637

1738
def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
1839
# Aggregates are not allowed in UPDATE queries, so ignore for_save
1940
c = super().resolve_expression(query, allow_joins, reuse, summarize)
41+
if c.filter:
42+
c.filter = c.filter.resolve_expression(query, allow_joins, reuse, summarize)
2043
if not summarize:
21-
expressions = c.get_source_expressions()
44+
# Call Aggregate.get_source_expressions() to avoid
45+
# returning self.filter and including that in this loop.
46+
expressions = super(Aggregate, c).get_source_expressions()
2247
for index, expr in enumerate(expressions):
2348
if expr.contains_aggregate:
2449
before_resolved = self.get_source_expressions()[index]
@@ -36,6 +61,29 @@ def default_alias(self):
3661
def get_group_by_cols(self):
3762
return []
3863

64+
def as_sql(self, compiler, connection, **extra_context):
65+
if self.filter:
66+
if connection.features.supports_aggregate_filter_clause:
67+
filter_sql, filter_params = self.filter.as_sql(compiler, connection)
68+
template = self.filter_template % extra_context.get('template', self.template)
69+
sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql)
70+
return sql, params + filter_params
71+
else:
72+
copy = self.copy()
73+
copy.filter = None
74+
condition = When(Q())
75+
source_expressions = copy.get_source_expressions()
76+
condition.set_source_expressions([self.filter, source_expressions[0]])
77+
copy.set_source_expressions([Case(condition)] + source_expressions[1:])
78+
return super(Aggregate, copy).as_sql(compiler, connection, **extra_context)
79+
return super().as_sql(compiler, connection, **extra_context)
80+
81+
def _get_repr_options(self):
82+
options = super()._get_repr_options()
83+
if self.filter:
84+
options.update({'filter': self.filter})
85+
return options
86+
3987

4088
class Avg(Aggregate):
4189
function = 'AVG'
@@ -52,7 +100,7 @@ def as_oracle(self, compiler, connection):
52100
expression = self.get_source_expressions()[0]
53101
from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
54102
return compiler.compile(
55-
SecondsToInterval(Avg(IntervalToSeconds(expression)))
103+
SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
56104
)
57105
return super().as_sql(compiler, connection)
58106

@@ -62,16 +110,19 @@ class Count(Aggregate):
62110
name = 'Count'
63111
template = '%(function)s(%(distinct)s%(expressions)s)'
64112

65-
def __init__(self, expression, distinct=False, **extra):
113+
def __init__(self, expression, distinct=False, filter=None, **extra):
66114
if expression == '*':
67115
expression = Star()
116+
if isinstance(expression, Star) and filter is not None:
117+
raise ValueError('Star cannot be used with filter. Please specify a field.')
68118
super().__init__(
69119
expression, distinct='DISTINCT ' if distinct else '',
70-
output_field=IntegerField(), **extra
120+
output_field=IntegerField(), filter=filter, **extra
71121
)
72122

73123
def _get_repr_options(self):
74-
return {'distinct': self.extra['distinct'] != ''}
124+
options = super()._get_repr_options()
125+
return dict(options, distinct=self.extra['distinct'] != '')
75126

76127
def convert_value(self, value, expression, connection):
77128
if value is None:
@@ -97,7 +148,8 @@ def __init__(self, expression, sample=False, **extra):
97148
super().__init__(expression, output_field=FloatField(), **extra)
98149

99150
def _get_repr_options(self):
100-
return {'sample': self.function == 'STDDEV_SAMP'}
151+
options = super()._get_repr_options()
152+
return dict(options, sample=self.function == 'STDDEV_SAMP')
101153

102154
def convert_value(self, value, expression, connection):
103155
if value is None:
@@ -127,7 +179,8 @@ def __init__(self, expression, sample=False, **extra):
127179
super().__init__(expression, output_field=FloatField(), **extra)
128180

129181
def _get_repr_options(self):
130-
return {'sample': self.function == 'VAR_SAMP'}
182+
options = super()._get_repr_options()
183+
return dict(options, sample=self.function == 'VAR_SAMP')
131184

132185
def convert_value(self, value, expression, connection):
133186
if value is None:

docs/ref/contrib/postgres/aggregates.txt

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ General-purpose aggregation functions
2222
``ArrayAgg``
2323
------------
2424

25-
.. class:: ArrayAgg(expression, distinct=False, **extra)
25+
.. class:: ArrayAgg(expression, distinct=False, filter=None, **extra)
2626

2727
Returns a list of values, including nulls, concatenated into an array.
2828

@@ -36,39 +36,39 @@ General-purpose aggregation functions
3636
``BitAnd``
3737
----------
3838

39-
.. class:: BitAnd(expression, **extra)
39+
.. class:: BitAnd(expression, filter=None, **extra)
4040

4141
Returns an ``int`` of the bitwise ``AND`` of all non-null input values, or
4242
``None`` if all values are null.
4343

4444
``BitOr``
4545
---------
4646

47-
.. class:: BitOr(expression, **extra)
47+
.. class:: BitOr(expression, filter=None, **extra)
4848

4949
Returns an ``int`` of the bitwise ``OR`` of all non-null input values, or
5050
``None`` if all values are null.
5151

5252
``BoolAnd``
5353
-----------
5454

55-
.. class:: BoolAnd(expression, **extra)
55+
.. class:: BoolAnd(expression, filter=None, **extra)
5656

5757
Returns ``True``, if all input values are true, ``None`` if all values are
5858
null or if there are no values, otherwise ``False`` .
5959

6060
``BoolOr``
6161
----------
6262

63-
.. class:: BoolOr(expression, **extra)
63+
.. class:: BoolOr(expression, filter=None, **extra)
6464

6565
Returns ``True`` if at least one input value is true, ``None`` if all
6666
values are null or if there are no values, otherwise ``False``.
6767

6868
``JSONBAgg``
6969
------------
7070

71-
.. class:: JSONBAgg(expressions, **extra)
71+
.. class:: JSONBAgg(expressions, filter=None, **extra)
7272

7373
.. versionadded:: 1.11
7474

@@ -77,7 +77,7 @@ General-purpose aggregation functions
7777
``StringAgg``
7878
-------------
7979

80-
.. class:: StringAgg(expression, delimiter, distinct=False)
80+
.. class:: StringAgg(expression, delimiter, distinct=False, filter=None)
8181

8282
Returns the input values concatenated into a string, separated by
8383
the ``delimiter`` string.
@@ -105,15 +105,15 @@ field or an expression returning a numeric data. Both are required.
105105
``Corr``
106106
--------
107107

108-
.. class:: Corr(y, x)
108+
.. class:: Corr(y, x, filter=None)
109109

110110
Returns the correlation coefficient as a ``float``, or ``None`` if there
111111
aren't any matching rows.
112112

113113
``CovarPop``
114114
------------
115115

116-
.. class:: CovarPop(y, x, sample=False)
116+
.. class:: CovarPop(y, x, sample=False, filter=None)
117117

118118
Returns the population covariance as a ``float``, or ``None`` if there
119119
aren't any matching rows.
@@ -129,31 +129,31 @@ field or an expression returning a numeric data. Both are required.
129129
``RegrAvgX``
130130
------------
131131

132-
.. class:: RegrAvgX(y, x)
132+
.. class:: RegrAvgX(y, x, filter=None)
133133

134134
Returns the average of the independent variable (``sum(x)/N``) as a
135135
``float``, or ``None`` if there aren't any matching rows.
136136

137137
``RegrAvgY``
138138
------------
139139

140-
.. class:: RegrAvgY(y, x)
140+
.. class:: RegrAvgY(y, x, filter=None)
141141

142142
Returns the average of the dependent variable (``sum(y)/N``) as a
143143
``float``, or ``None`` if there aren't any matching rows.
144144

145145
``RegrCount``
146146
-------------
147147

148-
.. class:: RegrCount(y, x)
148+
.. class:: RegrCount(y, x, filter=None)
149149

150150
Returns an ``int`` of the number of input rows in which both expressions
151151
are not null.
152152

153153
``RegrIntercept``
154154
-----------------
155155

156-
.. class:: RegrIntercept(y, x)
156+
.. class:: RegrIntercept(y, x, filter=None)
157157

158158
Returns the y-intercept of the least-squares-fit linear equation determined
159159
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
@@ -162,15 +162,15 @@ field or an expression returning a numeric data. Both are required.
162162
``RegrR2``
163163
----------
164164

165-
.. class:: RegrR2(y, x)
165+
.. class:: RegrR2(y, x, filter=None)
166166

167167
Returns the square of the correlation coefficient as a ``float``, or
168168
``None`` if there aren't any matching rows.
169169

170170
``RegrSlope``
171171
-------------
172172

173-
.. class:: RegrSlope(y, x)
173+
.. class:: RegrSlope(y, x, filter=None)
174174

175175
Returns the slope of the least-squares-fit linear equation determined
176176
by the ``(x, y)`` pairs as a ``float``, or ``None`` if there aren't any
@@ -179,15 +179,15 @@ field or an expression returning a numeric data. Both are required.
179179
``RegrSXX``
180180
-----------
181181

182-
.. class:: RegrSXX(y, x)
182+
.. class:: RegrSXX(y, x, filter=None)
183183

184184
Returns ``sum(x^2) - sum(x)^2/N`` ("sum of squares" of the independent
185185
variable) as a ``float``, or ``None`` if there aren't any matching rows.
186186

187187
``RegrSXY``
188188
-----------
189189

190-
.. class:: RegrSXY(y, x)
190+
.. class:: RegrSXY(y, x, filter=None)
191191

192192
Returns ``sum(x*y) - sum(x) * sum(y)/N`` ("sum of products" of independent
193193
times dependent variable) as a ``float``, or ``None`` if there aren't any
@@ -196,7 +196,7 @@ field or an expression returning a numeric data. Both are required.
196196
``RegrSYY``
197197
-----------
198198

199-
.. class:: RegrSYY(y, x)
199+
.. class:: RegrSYY(y, x, filter=None)
200200

201201
Returns ``sum(y^2) - sum(y)^2/N`` ("sum of squares" of the dependent
202202
variable) as a ``float``, or ``None`` if there aren't any matching rows.

0 commit comments

Comments
 (0)