Closed
Description
Python 3.10.15
sqlalchemy_bigquery 1.12.0 (issue introduced in 1.11.0).
When creating a new FunctionElement whose custom compile function does not support the default
dialect, and then grouping by that element, we get an error because the default
dialect is used instead of the compiler's dialect.
Steps to reproduce
- Create a new FunctionElement
- Create a function that compiles it. Make sure it does not support the default dialect.
- Create a query that groups by a field using this function and adding a label to it
Code example
import sqlalchemy
import sqlalchemy.ext.compiler


class CustomLower(sqlalchemy.sql.functions.FunctionElement):
    name = "custom_lower"


@sqlalchemy.ext.compiler.compiles(CustomLower)
def compile_custom_intersect(element, compiler, **kwargs):
    # Only BigQuery is supported; any other dialect, including "default", fails.
    if compiler.dialect.name != "bigquery":
        raise sqlalchemy.exc.CompileError(
            f"custom_lower is not supported for dialect '{compiler.dialect.name}'"
        )

    clauses = list(element.clauses)
    field = compiler.process(clauses[0], **kwargs)
    return f"LOWER({field})"


db_metadata = sqlalchemy.MetaData()
engine = sqlalchemy.create_engine("bigquery://", future=True)

# Plain Table definition, standing in for the setup_table test helper.
table1 = sqlalchemy.Table(
    "table1",
    db_metadata,
    sqlalchemy.Column("foo", sqlalchemy.String),
    sqlalchemy.Column("bar", sqlalchemy.Integer),
)

lower_foo = CustomLower(table1.c.foo).label("some_label")
q = (
    sqlalchemy.select(lower_foo, sqlalchemy.func.max(table1.c.bar))
    .select_from(table1)
    .group_by(lower_foo)
)

# Compiling against the BigQuery engine raises the CompileError shown below.
print(q.compile(engine).string)
Error received
sqlalchemy.exc.CompileError: custom_lower is not supported for dialect 'default'
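This happens because, when checking whether the element contains one of the grouping_ops, we use str() to stringify it instead of compiling it with the compiler's dialect. str() always compiles with the default dialect, so the custom compile function raises.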
The solution should be in visit_label: we should not use str(column_label), but rather column_label.compile(dialect=self.dialect).string
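The difference between the two calls can be seen without BigQuery. The following is a minimal sketch, assuming only SQLAlchemy is installed and using the built-in SQLite dialect as a stand-in for the BigQuery one; the names DialectOnly and compile_dialect_only are hypothetical and exist only for this illustration.

```python
import sqlalchemy
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import CompileError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.functions import FunctionElement


class DialectOnly(FunctionElement):
    # Hypothetical element that only compiles for a single dialect.
    name = "dialect_only"
    inherit_cache = True


@compiles(DialectOnly)
def compile_dialect_only(element, compiler, **kwargs):
    # SQLite stands in for BigQuery here (assumption for the sketch).
    if compiler.dialect.name != "sqlite":
        raise CompileError(
            f"dialect_only is not supported for dialect '{compiler.dialect.name}'"
        )
    field = compiler.process(list(element.clauses)[0], **kwargs)
    return f"LOWER({field})"


label = DialectOnly(sqlalchemy.column("foo")).label("some_label")

# Compiling with an explicit dialect reaches the supported branch.
compiled = label.compile(dialect=sqlite.dialect()).string

# str() compiles with the default dialect, so the custom compiler raises.
try:
    str(label)
    str_failed = False
except CompileError:
    str_failed = True
```

In this sketch the explicit compile call succeeds while str() fails, which mirrors what happens inside visit_label when str(column_label) is used.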
Example fix
def visit_label(self, *args, within_group_by=False, **kwargs):
    # Use labels in GROUP BY clause.
    #
    # Flag set in the group_by_clause method. Works around missing
    # equivalent to supports_simple_order_by_label for group by.
    if within_group_by:
        column_label = args[0]
        sql_keywords = {"GROUPING SETS", "ROLLUP", "CUBE"}
        label_str = column_label.compile(dialect=self.dialect).string
        if not any(keyword in label_str for keyword in sql_keywords):
            kwargs["render_label_as_label"] = column_label

    return super(BigQueryCompiler, self).visit_label(*args, **kwargs)
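The keyword check in the fix can also be looked at in isolation. Below is a small sketch that operates on already-rendered SQL strings; the helper name should_render_label and the sample strings are hypothetical and not part of sqlalchemy_bigquery.

```python
GROUPING_KEYWORDS = {"GROUPING SETS", "ROLLUP", "CUBE"}


def should_render_label(rendered_sql: str) -> bool:
    """Return True when the GROUP BY entry can be rendered as its label.

    Mirrors the substring check in the example fix: expressions that
    contain a grouping keyword are rendered as-is, not as a label.
    """
    return not any(keyword in rendered_sql for keyword in GROUPING_KEYWORDS)


plain_result = should_render_label("LOWER(`foo`)")
rollup_result = should_render_label("ROLLUP(`foo`, `bar`)")
```

Because the check is a plain substring match, it only depends on the compiled string, which is why the string has to come from the compiler's dialect and not from str().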