Skip to content

Commit f522555

Browse files
committed
[soc2010/query-refactor] Implemented count() (and by extension the Count() aggregate on the primary key).
git-svn-id: http://code.djangoproject.com/svn/django/branches/soc2010/query-refactor@13353 bcc190cf-cafb-0310-a4f2-bffc1f526a37
1 parent 8f441f0 commit f522555

File tree

5 files changed

+44
-6
lines changed

5 files changed

+44
-6
lines changed

django/contrib/mongodb/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ def flush(self, style, only_django=False):
4646
tables = self.connection.introspection.table_names()
4747
for table in tables:
4848
self.connection.db.drop_collection(table)
49+
50+
def check_aggregate_support(self, aggregate):
51+
# TODO: this really should use the generic aggregates, not the SQL ones
52+
from django.db.models.sql.aggregates import Count
53+
return isinstance(aggregate, Count)
4954

5055
class DatabaseWrapper(BaseDatabaseWrapper):
5156
def __init__(self, *args, **kwargs):

django/contrib/mongodb/compiler.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ def make_atom(self, lhs, lookup_type, value_annotation, params_or_value):
3232
column = "_id"
3333
return column, params[0]
3434

35-
def build_query(self):
36-
assert not self.query.aggregates
37-
assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) == 1
38-
assert self.query.default_cols
35+
def build_query(self, aggregates=False):
36+
assert len([a for a in self.query.alias_map if self.query.alias_refcount[a]]) <= 1
37+
if not aggregates:
38+
assert self.query.default_cols
3939
assert not self.query.distinct
4040
assert not self.query.extra
4141
assert not self.query.having
@@ -60,6 +60,17 @@ def has_results(self):
6060
return False
6161
else:
6262
return True
63+
64+
def get_aggregates(self):
65+
assert len(self.query.aggregates) == 1
66+
agg = self.query.aggregates.values()[0]
67+
assert (
68+
isinstance(agg, self.query.aggregates_module.Count) and (
69+
agg.col == "*" or
70+
isinstance(agg.col, tuple) and agg.col == (self.query.model._meta.db_table, self.query.model._meta.pk.column)
71+
)
72+
)
73+
return [self.build_query(aggregates=True).count()]
6374

6475

6576
class SQLInsertCompiler(SQLCompiler):

django/db/models/sql/compiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,10 @@ def has_results(self):
675675
self.query.clear_ordering(True)
676676
self.query.set_limits(high=1)
677677
return bool(self.execute_sql(SINGLE))
678-
678+
679+
def get_aggregates(self):
680+
return self.execute_sql(SINGLE)
681+
679682
def results_iter(self):
680683
"""
681684
Returns an iterator over the results from executing this query.

django/db/models/sql/query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def get_aggregation(self, using):
363363
query.related_select_cols = []
364364
query.related_select_fields = []
365365

366-
result = query.get_compiler(using).execute_sql(SINGLE)
366+
result = query.get_compiler(using).get_aggregates()
367367
if result is None:
368368
result = [None for q in query.aggregate_select.items()]
369369

tests/regressiontests/mongodb/tests.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from django.db.models import Count
12
from django.test import TestCase
23

34
from models import Artist
@@ -25,3 +26,21 @@ def test_update(self):
2526

2627
l = Artist.objects.get(pk=pk)
2728
self.assertTrue(not l.good)
29+
30+
def test_count(self):
31+
Artist.objects.create(name="Billy Joel", good=True)
32+
Artist.objects.create(name="John Mellencamp", good=True)
33+
Artist.objects.create(name="Warren Zevon", good=True)
34+
Artist.objects.create(name="Matisyahu", good=True)
35+
Artist.objects.create(name="Gary US Bonds", good=True)
36+
37+
self.assertEqual(Artist.objects.count(), 5)
38+
self.assertEqual(Artist.objects.filter(good=True).count(), 5)
39+
40+
Artist.objects.create(name="Bon Iver", good=False)
41+
42+
self.assertEqual(Artist.objects.count(), 6)
43+
self.assertEqual(Artist.objects.filter(good=True).count(), 5)
44+
self.assertEqual(Artist.objects.filter(good=False).count(), 1)
45+
46+
self.assertEqual(Artist.objects.aggregate(c=Count("pk")), {"c": 6})

0 commit comments

Comments
 (0)