Skip to content

feat: Handle passing of arrays to in statements more efficiently in SQLAlchemy 1.4 and higher #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Aug 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions sqlalchemy_bigquery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,39 @@ def visit_bindparam(
skip_bind_expression=False,
**kwargs,
):
type_ = bindparam.type
unnest = False
if (
bindparam.expanding
and not isinstance(type_, NullType)
and not literal_binds
):
# Normally, when performing an IN operation, like:
#
# foo IN (some_sequence)
#
# SQLAlchemy passes `foo` as a parameter and unpacks
# `some_sequence` and passes each element as a parameter.
# This mechanism is referred to as "expanding". It's
# inefficient and can't handle large arrays. (It's also
# very complicated, but that's not the issue we care about
# here. :) ) BigQuery lets us use arrays directly in this
# context, we just need to call UNNEST on an array when
# it's used in IN.
#
# So, if we get an `expanding` flag, and if we have a known type
# (and don't have literal binds, which are implemented in-line
# in the SQL), we turn off expanding and we set an unnest flag
# so that we add an UNNEST() call (below).
#
# The NullType/known-type check has to do with some extreme
# edge cases having to do with empty in-lists that get special
# hijinks from SQLAlchemy that we don't want to disturb. :)
if getattr(bindparam, "expand_op", None) is not None:
assert bindparam.expand_op.__name__.endswith("in_op") # in in
bindparam.expanding = False
unnest = True

param = super(BigQueryCompiler, self).visit_bindparam(
bindparam,
within_columns_clause,
Expand All @@ -491,7 +524,6 @@ def visit_bindparam(
**kwargs,
)

type_ = bindparam.type
if literal_binds or isinstance(type_, NullType):
return param

Expand All @@ -512,7 +544,6 @@ def visit_bindparam(
if bq_type[-1] == ">" and bq_type.startswith("ARRAY<"):
# Values get arrayified at a lower level.
bq_type = bq_type[6:-1]

bq_type = self.__remove_type_parameter(bq_type)

assert_(param != "%s", f"Unexpected param: {param}")
Expand All @@ -528,6 +559,9 @@ def visit_bindparam(
assert_(type_ is None)
param = f"%({name}:{bq_type})s"

if unnest:
param = f"UNNEST({param})"

return param


Expand Down
21 changes: 21 additions & 0 deletions tests/system/test_sqlalchemy_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,27 @@ class MyTable(Base):
assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected


@pytest.mark.skipif(
    packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"),
    reason="requires sqlalchemy 1.4 or higher",
)
def test_huge_in():
    """Check that an IN test against a 99999-element list executes.

    Selects ``-1 IN (0 .. 99998)`` through a ``bigquery://`` engine and
    expects a single row containing ``False``.  Skipped on SQLAlchemy
    versions older than 1.4.
    """
    engine = sqlalchemy.create_engine("bigquery://")
    query = sqlalchemy.select([sqlalchemy.literal(-1).in_(list(range(99999)))])

    # Use the connection as a context manager so it is closed even when the
    # statement fails.
    with engine.connect() as conn:
        try:
            rows = list(conn.execute(query))
        except Exception:
            # NOTE(review): presumably the original exception is dropped on
            # purpose so the failure output does not echo the huge parameter
            # list -- confirm before re-raising here.
            error = True
        else:
            error = False

    assert not error, "execution failed"
    # Checked outside the try block so that a wrong result is reported as a
    # wrong result rather than being folded into "execution failed".
    assert rows == [(False,)]


@pytest.mark.skipif(
packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"),
reason="unnest (and other table-valued-function) support required version 1.4",
Expand Down
17 changes: 13 additions & 4 deletions tests/unit/fauxdbi.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,20 @@ def __handle_problematic_literal_inserts(
else:
return operation

__handle_unnest = substitute_string_re_method(
r"UNNEST\(\[ ([^\]]+)? \]\)", # UNNEST([ ... ])
flags=re.IGNORECASE,
repl=r"(\1)",
@substitute_re_method(
r"""
UNNEST\(
(
\[ (?P<exp>[^\]]+)? \] # UNNEST([ ... ])
|
([?]) # UNNEST(?)
)
\)
""",
flags=re.IGNORECASE | re.VERBOSE,
)
def __handle_unnest(self, m):
return "(" + (m.group("exp") or "?") + ")"

def __handle_true_false(self, operation):
# Older sqlite versions, like those used on the CI servers
Expand Down
144 changes: 71 additions & 73 deletions tests/unit/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from conftest import (
setup_table,
sqlalchemy_version,
sqlalchemy_1_3_or_higher,
sqlalchemy_1_4_or_higher,
sqlalchemy_before_1_4,
Expand Down Expand Up @@ -214,18 +215,6 @@ def test_disable_quote(faux_conn):
assert faux_conn.test_data["execute"][-1][0] == ("SELECT `t`.foo \nFROM `t`")


def _normalize_in_params(query, params):
# We have to normalize parameter names, because they
# change with sqlalchemy versions.
newnames = sorted(
((p, f"p_{i}") for i, p in enumerate(sorted(params))), key=lambda i: -len(i[0])
)
for old, new in newnames:
query = query.replace(old, new)

return query, {new: params[old] for old, new in newnames}


@sqlalchemy_before_1_4
def test_select_in_lit_13(faux_conn):
[[isin]] = faux_conn.execute(
Expand All @@ -240,66 +229,74 @@ def test_select_in_lit_13(faux_conn):


@sqlalchemy_1_4_or_higher
def test_select_in_lit(faux_conn):
[[isin]] = faux_conn.execute(
sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])])
)
assert isin
assert _normalize_in_params(*faux_conn.test_data["execute"][-1]) == (
"SELECT %(p_0:INT64)s IN "
"UNNEST([ %(p_1:INT64)s, %(p_2:INT64)s, %(p_3:INT64)s ]) AS `anon_1`",
{"p_1": 1, "p_2": 2, "p_3": 3, "p_0": 1},
def test_select_in_lit(faux_conn, last_query):
faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])]))
last_query(
"SELECT %(param_1:INT64)s IN UNNEST(%(param_2:INT64)s) AS `anon_1`",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the type of param_2 be ARRAY<INT64>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because the BQ DB API does special handling of arrays.

It sees that we have a scalar type of INT64 and that we have a sequence, and then creates an ArrayQueryParameter.

It happens that since I added struct support, passing ARRAY<INT64> would probably work (because I have to handle structs of arrays). But just using INT64 works too.

{"param_1": 1, "param_2": [1, 2, 3]},
)


def test_select_in_param(faux_conn):
def test_select_in_param(faux_conn, last_query):
[[isin]] = faux_conn.execute(
sqlalchemy.select(
[sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))]
),
dict(q=[1, 2, 3]),
)
assert isin
assert faux_conn.test_data["execute"][-1] == (
"SELECT %(param_1:INT64)s IN UNNEST("
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
") AS `anon_1`",
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
)
if sqlalchemy_version >= packaging.version.parse("1.4"):
last_query(
"SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`",
{"param_1": 1, "q": [1, 2, 3]},
)
else:
assert isin
last_query(
"SELECT %(param_1:INT64)s IN UNNEST("
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
") AS `anon_1`",
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
)


def test_select_in_param1(faux_conn):
def test_select_in_param1(faux_conn, last_query):
[[isin]] = faux_conn.execute(
sqlalchemy.select(
[sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))]
),
dict(q=[1]),
)
assert isin
assert faux_conn.test_data["execute"][-1] == (
"SELECT %(param_1:INT64)s IN UNNEST(" "[ %(q_1:INT64)s ]" ") AS `anon_1`",
{"param_1": 1, "q_1": 1},
)
if sqlalchemy_version >= packaging.version.parse("1.4"):
last_query(
"SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`",
{"param_1": 1, "q": [1]},
)
else:
assert isin
last_query(
"SELECT %(param_1:INT64)s IN UNNEST(" "[ %(q_1:INT64)s ]" ") AS `anon_1`",
{"param_1": 1, "q_1": 1},
)


@sqlalchemy_1_3_or_higher
def test_select_in_param_empty(faux_conn):
def test_select_in_param_empty(faux_conn, last_query):
[[isin]] = faux_conn.execute(
sqlalchemy.select(
[sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))]
),
dict(q=[]),
)
assert not isin
assert faux_conn.test_data["execute"][-1] == (
"SELECT %(param_1:INT64)s IN(NULL) AND (1 != 1) AS `anon_1`"
if (
packaging.version.parse(sqlalchemy.__version__)
>= packaging.version.parse("1.4")
if sqlalchemy_version >= packaging.version.parse("1.4"):
last_query(
"SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`",
{"param_1": 1, "q": []},
)
else:
assert not isin
last_query(
"SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`", {"param_1": 1}
)
else "SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`",
{"param_1": 1},
)


@sqlalchemy_before_1_4
Expand All @@ -316,53 +313,54 @@ def test_select_notin_lit13(faux_conn):


@sqlalchemy_1_4_or_higher
def test_select_notin_lit(faux_conn):
[[isnotin]] = faux_conn.execute(
sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])])
def test_select_notin_lit(faux_conn, last_query):
faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])]))
last_query(
"SELECT (%(param_1:INT64)s NOT IN UNNEST(%(param_2:INT64)s)) AS `anon_1`",
{"param_1": 0, "param_2": [1, 2, 3]},
)
assert isnotin

assert _normalize_in_params(*faux_conn.test_data["execute"][-1]) == (
"SELECT (%(p_0:INT64)s NOT IN "
"UNNEST([ %(p_1:INT64)s, %(p_2:INT64)s, %(p_3:INT64)s ])) AS `anon_1`",
{"p_0": 0, "p_1": 1, "p_2": 2, "p_3": 3},
)


def test_select_notin_param(faux_conn):
def test_select_notin_param(faux_conn, last_query):
[[isnotin]] = faux_conn.execute(
sqlalchemy.select(
[sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))]
),
dict(q=[1, 2, 3]),
)
assert not isnotin
assert faux_conn.test_data["execute"][-1] == (
"SELECT (%(param_1:INT64)s NOT IN UNNEST("
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
")) AS `anon_1`",
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
)
if sqlalchemy_version >= packaging.version.parse("1.4"):
last_query(
"SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`",
{"param_1": 1, "q": [1, 2, 3]},
)
else:
assert not isnotin
last_query(
"SELECT (%(param_1:INT64)s NOT IN UNNEST("
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
")) AS `anon_1`",
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
)


@sqlalchemy_1_3_or_higher
def test_select_notin_param_empty(faux_conn):
def test_select_notin_param_empty(faux_conn, last_query):
[[isnotin]] = faux_conn.execute(
sqlalchemy.select(
[sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))]
),
dict(q=[]),
)
assert isnotin
assert faux_conn.test_data["execute"][-1] == (
"SELECT (%(param_1:INT64)s NOT IN(NULL) OR (1 = 1)) AS `anon_1`"
if (
packaging.version.parse(sqlalchemy.__version__)
>= packaging.version.parse("1.4")
if sqlalchemy_version >= packaging.version.parse("1.4"):
last_query(
"SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`",
{"param_1": 1, "q": []},
)
else:
assert isnotin
last_query(
"SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`", {"param_1": 1}
)
else "SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`",
{"param_1": 1},
)


def test_literal_binds_kwarg_with_an_IN_operator_252(faux_conn):
Expand Down