Skip to content

Commit cbb4ee3

Browse files
Ilya Gurov and larkee authored
feat: use DML batches in executemany() method (googleapis#412)
* feat: use mutations for executemany() inserts
* add unit test and fix parsing
* add use_mutations flag into Connection class
* use three-values flag for use_mutations
* update docstrings
* use batch DMLs for executemany() method
* prepare args before inserting into SQL statement
* erase mutation mentions
* next step
* next step
* next step
* fixes
* add unit tests for UPDATE and DELETE statements
* don't propagate errors to users on retry
* lint fixes
* use run_in_transaction
* refactor the tests code
* fix merge conflict
* fix the unit test
* revert some changes
* use executemany for test data insert

Co-authored-by: larkee <31196561+larkee@users.noreply.github.com>
1 parent a2b81be commit cbb4ee3

File tree

4 files changed

+395
-27
lines changed

4 files changed

+395
-27
lines changed

google/cloud/spanner_dbapi/connection.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT
3333
from google.cloud.spanner_dbapi.version import PY_VERSION
3434

35+
from google.rpc.code_pb2 import ABORTED
36+
3537

3638
AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode"
3739
MAX_INTERNAL_RETRIES = 50
@@ -175,25 +177,41 @@ def _rerun_previous_statements(self):
175177
from the last transaction.
176178
"""
177179
for statement in self._statements:
178-
res_iter, retried_checksum = self.run_statement(statement, retried=True)
179-
# executing all the completed statements
180-
if statement != self._statements[-1]:
181-
for res in res_iter:
182-
retried_checksum.consume_result(res)
183-
184-
_compare_checksums(statement.checksum, retried_checksum)
185-
# executing the failed statement
180+
if isinstance(statement, list):
181+
statements, checksum = statement
182+
183+
transaction = self.transaction_checkout()
184+
status, res = transaction.batch_update(statements)
185+
186+
if status.code == ABORTED:
187+
self.connection._transaction = None
188+
raise Aborted(status.details)
189+
190+
retried_checksum = ResultsChecksum()
191+
retried_checksum.consume_result(res)
192+
retried_checksum.consume_result(status.code)
193+
194+
_compare_checksums(checksum, retried_checksum)
186195
else:
187-
# streaming up to the failed result or
188-
# to the end of the streaming iterator
189-
while len(retried_checksum) < len(statement.checksum):
190-
try:
191-
res = next(iter(res_iter))
196+
res_iter, retried_checksum = self.run_statement(statement, retried=True)
197+
# executing all the completed statements
198+
if statement != self._statements[-1]:
199+
for res in res_iter:
192200
retried_checksum.consume_result(res)
193-
except StopIteration:
194-
break
195201

196-
_compare_checksums(statement.checksum, retried_checksum)
202+
_compare_checksums(statement.checksum, retried_checksum)
203+
# executing the failed statement
204+
else:
205+
# streaming up to the failed result or
206+
# to the end of the streaming iterator
207+
while len(retried_checksum) < len(statement.checksum):
208+
try:
209+
res = next(iter(res_iter))
210+
retried_checksum.consume_result(res)
211+
except StopIteration:
212+
break
213+
214+
_compare_checksums(statement.checksum, retried_checksum)
197215

198216
def transaction_checkout(self):
199217
"""Get a Cloud Spanner transaction.

google/cloud/spanner_dbapi/cursor.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141
from google.cloud.spanner_dbapi.utils import PeekIterator
4242
from google.cloud.spanner_dbapi.utils import StreamedManyResultSets
4343

44+
from google.rpc.code_pb2 import ABORTED, OK
45+
4446
_UNSET_COUNT = -1
4547

4648
ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
@@ -156,6 +158,15 @@ def _do_execute_update(self, transaction, sql, params):
156158

157159
return result
158160

161+
def _do_batch_update(self, transaction, statements, many_result_set):
162+
status, res = transaction.batch_update(statements)
163+
many_result_set.add_iter(res)
164+
165+
if status.code == ABORTED:
166+
raise Aborted(status.details)
167+
elif status.code != OK:
168+
raise OperationalError(status.details)
169+
159170
def execute(self, sql, args=None):
160171
"""Prepares and executes a Spanner database operation.
161172
@@ -258,9 +269,50 @@ def executemany(self, operation, seq_of_params):
258269

259270
many_result_set = StreamedManyResultSets()
260271

261-
for params in seq_of_params:
262-
self.execute(operation, params)
263-
many_result_set.add_iter(self._itr)
272+
if classification in (parse_utils.STMT_INSERT, parse_utils.STMT_UPDATING):
273+
statements = []
274+
275+
for params in seq_of_params:
276+
sql, params = parse_utils.sql_pyformat_args_to_spanner(
277+
operation, params
278+
)
279+
statements.append((sql, params, get_param_types(params)))
280+
281+
if self.connection.autocommit:
282+
self.connection.database.run_in_transaction(
283+
self._do_batch_update, statements, many_result_set
284+
)
285+
else:
286+
retried = False
287+
while True:
288+
try:
289+
transaction = self.connection.transaction_checkout()
290+
291+
res_checksum = ResultsChecksum()
292+
if not retried:
293+
self.connection._statements.append(
294+
(statements, res_checksum)
295+
)
296+
297+
status, res = transaction.batch_update(statements)
298+
many_result_set.add_iter(res)
299+
res_checksum.consume_result(res)
300+
res_checksum.consume_result(status.code)
301+
302+
if status.code == ABORTED:
303+
self.connection._transaction = None
304+
raise Aborted(status.details)
305+
elif status.code != OK:
306+
raise OperationalError(status.details)
307+
break
308+
except Aborted:
309+
self.connection.retry_transaction()
310+
retried = True
311+
312+
else:
313+
for params in seq_of_params:
314+
self.execute(operation, params)
315+
many_result_set.add_iter(self._itr)
264316

265317
self._result_set = many_result_set
266318
self._itr = many_result_set

tests/system/test_system_dbapi.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -343,20 +343,20 @@ def test_execute_many(self):
343343
conn = Connection(Config.INSTANCE, self._db)
344344
cursor = conn.cursor()
345345

346-
cursor.execute(
346+
cursor.executemany(
347347
"""
348348
INSERT INTO contacts (contact_id, first_name, last_name, email)
349-
VALUES (1, 'first-name', 'last-name', 'test.email@example.com'),
350-
(2, 'first-name2', 'last-name2', 'test.email2@example.com')
351-
"""
349+
VALUES (%s, %s, %s, %s)
350+
""",
351+
[
352+
(1, "first-name", "last-name", "test.email@example.com"),
353+
(2, "first-name2", "last-name2", "test.email2@example.com"),
354+
],
352355
)
353356
conn.commit()
354357

355358
cursor.executemany(
356-
"""
357-
SELECT * FROM contacts WHERE contact_id = @a1
358-
""",
359-
({"a1": 1}, {"a1": 2}),
359+
"""SELECT * FROM contacts WHERE contact_id = @a1""", ({"a1": 1}, {"a1": 2}),
360360
)
361361
res = cursor.fetchall()
362362
conn.commit()

0 commit comments

Comments
 (0)