diff --git a/pybigquery/__init__.py b/pybigquery/__init__.py index e69de29b..d844867f 100644 --- a/pybigquery/__init__.py +++ b/pybigquery/__init__.py @@ -0,0 +1,42 @@ +from sqlalchemy.dialects.postgresql import array +from sqlalchemy.sql import expression, operators, sqltypes + +__all__ = ["array", "struct"] + + +class STRUCT(sqltypes.Indexable, sqltypes.TypeEngine): + # NOTE: STRUCT names/types aren't currently supported. + + __visit_name__ = "STRUCT" + + class Comparator(sqltypes.Indexable.Comparator): + def _setup_getitem(self, index): + return operators.getitem, index, self.type + + comparator_factory = Comparator + + +class struct(expression.ClauseList, expression.ColumnElement): + """ Create a BigQuery struct literal from a collection of named expressions/clauses. + """ + # NOTE: Struct subfields aren't currently propagated/validated. + + __visit_name__ = "struct" + + def __init__(self, clauses, field=None, **kw): + self.field = field + self.type = STRUCT() + super().__init__(*clauses, **kw) + + def _bind_param(self, operator, obj, _assume_scalar=False, type_=None): + if operator is operators.getitem: + # TODO: + # - Validate field in clauses (or error if no clauses) + # - If the field is a sub-struct, return with all clauses, otherwise none. + return struct([], field=obj) + + def self_group(self, against=None): + if not self.field and against in (operators.getitem,): + return expression.Grouping(self) + else: + return self diff --git a/pybigquery/sqlalchemy_bigquery.py b/pybigquery/sqlalchemy_bigquery.py index 74573458..719c0221 100644 --- a/pybigquery/sqlalchemy_bigquery.py +++ b/pybigquery/sqlalchemy_bigquery.py @@ -133,6 +133,22 @@ def __init__(self, dialect, statement, column_keys=None, kwargs['compile_kwargs'] = util.immutabledict({'include_table': False}) super(BigQueryCompiler, self).__init__(dialect, statement, column_keys, inline, **kwargs) + def visit_array(self, element, **kw): + return "[%s]" % self.visit_clauselist(element, **kw) + + def visit_struct(self, element, within_columns_clause=True, **kw): + if element.field: + return self.preparer.quote_column(element.field) + kw["within_columns_clause"] = True + values = self.visit_clauselist(element, **kw) + return "struct(%s)" % values + + def visit_getitem_binary(self, binary, operator, **kw): + return "%s.%s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + def visit_select(self, *args, **kwargs): """ Use labels for every column.