Skip to content

Commit 17dce75

Browse files
jacobtylerwallssarahboyce
authored andcommitted
Refs #36430, Refs #36416 -- Simplified batch size calculation in QuerySet.in_bulk().
1 parent 5efef36 commit 17dce75

File tree

3 files changed

+22
-21
lines changed

3 files changed

+22
-21
lines changed

django/db/models/query.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,10 +1182,8 @@ def in_bulk(self, id_list=None, *, field_name="pk"):
11821182
if not id_list:
11831183
return {}
11841184
filter_key = "{}__in".format(field_name)
1185-
max_params = connections[self.db].features.max_query_params or 0
1186-
num_fields = len(opts.pk_fields) if field_name == "pk" else 1
1187-
batch_size = max_params // num_fields
11881185
id_list = tuple(id_list)
1186+
batch_size = connections[self.db].ops.bulk_batch_size([opts.pk], id_list)
11891187
# If the database has a limit on the number of query parameters
11901188
# (e.g. SQLite), retrieve objects in batches if necessary.
11911189
if batch_size and batch_size < len(id_list):

tests/composite_pk/tests.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,20 +147,24 @@ def test_in_bulk(self):
147147
result = Comment.objects.in_bulk([self.comment.pk])
148148
self.assertEqual(result, {self.comment.pk: self.comment})
149149

150-
@unittest.mock.patch.object(
151-
type(connection.features), "max_query_params", new_callable=lambda: 10
152-
)
153-
def test_in_bulk_batching(self, mocked_max_query_params):
150+
def test_in_bulk_batching(self):
154151
Comment.objects.all().delete()
155-
num_requiring_batching = (connection.features.max_query_params // 2) + 1
156-
comments = [
157-
Comment(id=i, tenant=self.tenant, user=self.user)
158-
for i in range(1, num_requiring_batching + 1)
159-
]
160-
Comment.objects.bulk_create(comments)
161-
id_list = list(Comment.objects.values_list("pk", flat=True))
162-
with self.assertNumQueries(2):
163-
comment_dict = Comment.objects.in_bulk(id_list=id_list)
152+
batching_required = connection.features.max_query_params is not None
153+
expected_queries = 2 if batching_required else 1
154+
with unittest.mock.patch.object(
155+
type(connection.features), "max_query_params", 10
156+
):
157+
num_requiring_batching = (
158+
connection.ops.bulk_batch_size([Comment._meta.pk], []) + 1
159+
)
160+
comments = [
161+
Comment(id=i, tenant=self.tenant, user=self.user)
162+
for i in range(1, num_requiring_batching + 1)
163+
]
164+
Comment.objects.bulk_create(comments)
165+
id_list = list(Comment.objects.values_list("pk", flat=True))
166+
with self.assertNumQueries(expected_queries):
167+
comment_dict = Comment.objects.in_bulk(id_list=id_list)
164168
self.assertQuerySetEqual(comment_dict, id_list)
165169

166170
def test_iterator(self):

tests/lookup/tests.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,16 +260,15 @@ def test_in_bulk_preserve_ordering(self):
260260
{self.au1.id: self.a4, self.au2.id: self.a5},
261261
)
262262

263-
@skipUnlessDBFeature("can_distinct_on_fields")
264263
def test_in_bulk_preserve_ordering_with_batch_size(self):
265-
qs = Article.objects.order_by("author_id", "-pub_date").distinct("author_id")
264+
qs = Article.objects.all()
266265
with (
267-
mock.patch.object(connection.features.__class__, "max_query_params", 1),
266+
mock.patch.object(connection.features.__class__, "max_query_params", 2),
268267
self.assertNumQueries(2),
269268
):
270269
self.assertEqual(
271-
qs.in_bulk([self.au1.id, self.au2.id], field_name="author_id"),
272-
{self.au1.id: self.a4, self.au2.id: self.a5},
270+
list(qs.in_bulk([self.a5.id, self.a4.id, self.a3.id, self.a2.id])),
271+
[5, 4, 2, 3],
273272
)
274273

275274
@skipUnlessDBFeature("can_distinct_on_fields")

0 commit comments

Comments
 (0)