From 7c9b816114bb98adb1975bdb92d06fc4a82a3761 Mon Sep 17 00:00:00 2001
From: Chelsea Lin
Date: Tue, 1 Jul 2025 10:52:54 -0700
Subject: [PATCH 01/13] refactor: subclass DerefOp as ResolvedDerefOp (#1874)

* refactor: subclass DerefOp as ResolvedDerefOp

* replace the `field` attribute by id, dtype, nullable

* final cleanup
---
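Note (illustrative sketch, not part of the diff): this refactor replaces the standalone SchemaFieldRefExpression with a ResolvedDerefOp subclass of DerefOp, so a resolved column reference is still a DerefOp and inherits its binding behavior instead of re-implementing it. A minimal, self-contained model of the pattern, with simplified types (plain strings stand in for bigframes' ColumnId and Dtype):

    import dataclasses

    @dataclasses.dataclass(frozen=True)
    class DerefOp:
        """Refers to a column by ID only; its dtype is unknown until a schema is bound."""

        id: str

        @property
        def is_resolved(self) -> bool:
            return False

    @dataclasses.dataclass(frozen=True)
    class ResolvedDerefOp(DerefOp):
        """A DerefOp resolved against a schema, so dtype and nullability are known."""

        dtype: str
        is_nullable: bool

        @property
        def is_resolved(self) -> bool:
            return True

    ref = ResolvedDerefOp(id="col_a", dtype="Int64", is_nullable=True)
    assert ref.is_resolved and isinstance(ref, DerefOp)  # usable anywhere a DerefOp is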
 bigframes/core/compile/polars/compiler.py    |  4 +-
 .../core/compile/sqlglot/scalar_compiler.py  |  7 ---
 bigframes/core/expression.py                 | 52 +++++--------
 bigframes/core/rewrite/schema_binding.py     | 13 +++++
 tests/unit/core/test_expression.py           |  4 +-
 5 files changed, 29 insertions(+), 51 deletions(-)

diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py
index 6b76f3f53d..40037735d4 100644
--- a/bigframes/core/compile/polars/compiler.py
+++ b/bigframes/core/compile/polars/compiler.py
@@ -120,9 +120,9 @@ def _(
     @compile_expression.register
     def _(
         self,
-        expression: ex.SchemaFieldRefExpression,
+        expression: ex.ResolvedDerefOp,
     ) -> pl.Expr:
-        return pl.col(expression.field.id.sql)
+        return pl.col(expression.id.sql)
 
     @compile_expression.register
     def _(
diff --git a/bigframes/core/compile/sqlglot/scalar_compiler.py b/bigframes/core/compile/sqlglot/scalar_compiler.py
index f553518300..0db507b0fa 100644
--- a/bigframes/core/compile/sqlglot/scalar_compiler.py
+++ b/bigframes/core/compile/sqlglot/scalar_compiler.py
@@ -42,13 +42,6 @@ def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression:
     return sge.ColumnDef(this=sge.to_identifier(expr.id.sql, quoted=True))
 
 
-@compile_scalar_expression.register
-def compile_field_ref_expression(
-    expr: expression.SchemaFieldRefExpression,
-) -> sge.Expression:
-    return sge.ColumnDef(this=sge.to_identifier(expr.field.id.sql, quoted=True))
-
-
 @compile_scalar_expression.register
 def compile_constant_expression(
     expr: expression.ScalarConstantExpression,
diff --git a/bigframes/core/expression.py b/bigframes/core/expression.py
index 40ba70c555..7b20e430ff 100644
--- a/bigframes/core/expression.py
+++ b/bigframes/core/expression.py
@@ -429,55 +429,27 @@ def transform_children(self, t: Callable[[Expression], Expression]) -> Expressio
 
 
 @dataclasses.dataclass(frozen=True)
-class SchemaFieldRefExpression(Expression):
-    """An expression representing a schema field.
-    This is essentially a DerefOp with input schema bound."""
+class ResolvedDerefOp(DerefOp):
+    """An expression that refers to a column by ID and is resolved against a bound schema."""
 
-    field: field.Field
+    dtype: dtypes.Dtype
+    is_nullable: bool
 
-    @property
-    def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
-        return (self.field.id,)
-
-    @property
-    def is_const(self) -> bool:
-        return False
-
-    @property
-    def nullable(self) -> bool:
-        return self.field.nullable
+    @classmethod
+    def from_field(cls, f: field.Field):
+        return cls(id=f.id, dtype=f.dtype, is_nullable=f.nullable)
 
     @property
     def is_resolved(self) -> bool:
         return True
 
     @property
-    def output_type(self) -> dtypes.ExpressionType:
-        return self.field.dtype
-
-    def bind_variables(
-        self, bindings: Mapping[str, Expression], allow_partial_bindings: bool = False
-    ) -> Expression:
-        return self
-
-    def bind_refs(
-        self,
-        bindings: Mapping[ids.ColumnId, Expression],
-        allow_partial_bindings: bool = False,
-    ) -> Expression:
-        if self.field.id in bindings.keys():
-            return bindings[self.field.id]
-        return self
-
-    @property
-    def is_bijective(self) -> bool:
-        return True
+    def nullable(self) -> bool:
+        return self.is_nullable
 
     @property
-    def is_identity(self) -> bool:
-        return True
-
-    def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
-        return self
+    def output_type(self) -> dtypes.ExpressionType:
+        return self.dtype
 
 
 @dataclasses.dataclass(frozen=True)
@@ -589,7 +561,7 @@ def bind_schema_fields(
         return expr
 
     expr_by_id = {
-        id: SchemaFieldRefExpression(field) for id, field in field_by_id.items()
+        id: ResolvedDerefOp.from_field(field) for id, field in field_by_id.items()
     }
     return expr.bind_refs(expr_by_id)
diff --git a/bigframes/core/rewrite/schema_binding.py b/bigframes/core/rewrite/schema_binding.py
index aa5cb986b9..af0593211c 100644
--- a/bigframes/core/rewrite/schema_binding.py
+++ b/bigframes/core/rewrite/schema_binding.py
@@ -52,4 +52,17 @@ def bind_schema_to_node(
 
         return dataclasses.replace(node, by=tuple(bound_bys))
 
+    if isinstance(node, nodes.JoinNode):
+        conditions = tuple(
+            (
+                ex.ResolvedDerefOp.from_field(node.left_child.field_by_id[left.id]),
+                ex.ResolvedDerefOp.from_field(node.right_child.field_by_id[right.id]),
+            )
+            for left, right in node.conditions
+        )
+        return dataclasses.replace(
+            node,
+            conditions=conditions,
+        )
+
     return node
diff --git a/tests/unit/core/test_expression.py b/tests/unit/core/test_expression.py
index 9534c8605a..4c3d233879 100644
--- a/tests/unit/core/test_expression.py
+++ b/tests/unit/core/test_expression.py
@@ -77,8 +77,8 @@ def test_deref_op_dtype_resolution():
 
 
 def test_field_ref_expr_dtype_resolution_short_circuit():
-    expression = ex.SchemaFieldRefExpression(
-        field.Field(ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE)
+    expression = ex.ResolvedDerefOp(
+        id=ids.ColumnId("mycol"), dtype=dtypes.INT_DTYPE, is_nullable=True
     )
     field_bindings = _create_field_bindings({"anotherCol": dtypes.STRING_DTYPE})
 

From c289f7061320ec6d9de099cab2416cc9f289baac Mon Sep 17 00:00:00 2001
From: Alicia Williams
Date: Tue, 1 Jul 2025 12:33:57 -0700
Subject: [PATCH 02/13] docs: update gsutil commands to gcloud commands (#1876)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Tim Sweña (Swast)
---
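Note (illustrative sketch, not part of the diff): the notebook swaps the legacy gsutil CLI for the gcloud storage CLI. For reference, a hypothetical programmatic equivalent using the google-cloud-storage Python client; the bucket name is invented:

    import uuid

    from google.cloud import storage  # pip install google-cloud-storage

    client = storage.Client()
    bucket_name = "code-samples-" + str(uuid.uuid1())

    # Equivalent of: gcloud storage buckets create gs://{bucket_name}
    bucket = client.create_bucket(bucket_name)

    # ... use the bucket ...

    # Equivalent of: gcloud storage rm gs://{bucket_name} --recursive
    bucket.delete(force=True)  # force=True also deletes the bucket's objects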
 .../generative_ai/bq_dataframes_llm_code_generation.ipynb | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/notebooks/generative_ai/bq_dataframes_llm_code_generation.ipynb b/notebooks/generative_ai/bq_dataframes_llm_code_generation.ipynb
index 68e10cb5ed..4f1329129e 100644
--- a/notebooks/generative_ai/bq_dataframes_llm_code_generation.ipynb
+++ b/notebooks/generative_ai/bq_dataframes_llm_code_generation.ipynb
@@ -1093,7 +1093,7 @@
     "import uuid\n",
     "BUCKET_ID = \"code-samples-\" + str(uuid.uuid1())\n",
     "\n",
-    "!gsutil mb gs://{BUCKET_ID}"
+    "!gcloud storage buckets create gs://{BUCKET_ID}"
    ]
   },
   {
@@ -1272,7 +1272,7 @@
    "outputs": [],
    "source": [
     "# # Delete the Google Cloud Storage bucket and files\n",
-    "# ! gsutil rm -r gs://{BUCKET_ID}\n",
+    "# ! gcloud storage rm gs://{BUCKET_ID} --recursive\n",
     "# print(f\"Deleted bucket '{BUCKET_ID}'.\")"
    ]
   }

From 23d6fb4fa06ad1e2707c941047bbfde2e77feeb3 Mon Sep 17 00:00:00 2001
From: Chelsea Lin
Date: Tue, 1 Jul 2025 12:39:24 -0700
Subject: [PATCH 03/13] refactor: add compile_join (#1851)

---
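Note (illustrative sketch, not part of the diff): compile_join wraps each input query in a CTE and joins the resulting CTE tables, mapping how='outer' to SQL's FULL OUTER JOIN. A minimal standalone demonstration with sqlglot; table, column, and CTE names are invented:

    import sqlglot as sg
    import sqlglot.expressions as sge

    left = sg.parse_one("SELECT a, k FROM t1", into=sge.Select)
    right = sg.parse_one("SELECT b, k FROM t2", into=sge.Select)

    joined = (
        sge.Select()
        .select(sge.Star())
        .from_("left_cte")
        .join("right_cte", on="left_cte.k = right_cte.k", join_type="full outer")
    )
    joined.set(
        "with",
        sge.With(
            expressions=[
                sge.CTE(this=left, alias=sge.TableAlias(this=sg.to_identifier("left_cte"))),
                sge.CTE(this=right, alias=sge.TableAlias(this=sg.to_identifier("right_cte"))),
            ]
        ),
    )
    # Prints roughly: WITH left_cte AS (...), right_cte AS (...)
    # SELECT * FROM left_cte FULL OUTER JOIN right_cte ON left_cte.k = right_cte.k
    print(joined.sql(dialect="bigquery"))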
 bigframes/core/compile/sqlglot/compiler.py    | 24 ++++++
 bigframes/core/compile/sqlglot/sqlglot_ir.py  | 73 +++++++++++++++----
 bigframes/dataframe.py                        |  2 -
 .../test_compile_join/out.sql                 | 31 ++++++++
 .../core/compile/sqlglot/test_compile_join.py | 51 +++++++++++++
 .../bigframes_vendored/pandas/core/frame.py   |  2 +-
 6 files changed, 164 insertions(+), 19 deletions(-)
 create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql
 create mode 100644 tests/unit/core/compile/sqlglot/test_compile_join.py

diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py
index 606fe41b5e..03d1a0a2de 100644
--- a/bigframes/core/compile/sqlglot/compiler.py
+++ b/bigframes/core/compile/sqlglot/compiler.py
@@ -22,6 +22,7 @@
 
 from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
 from bigframes.core.compile import configs
+from bigframes.core.compile.sqlglot.expressions import typed_expr
 import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
 import bigframes.core.compile.sqlglot.sqlglot_ir as ir
 import bigframes.core.ordering as bf_ordering
@@ -218,6 +219,29 @@ def compile_filter(
         condition = scalar_compiler.compile_scalar_expression(node.predicate)
         return child.filter(condition)
 
+    @_compile_node.register
+    def compile_join(
+        self, node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
+    ) -> ir.SQLGlotIR:
+        conditions = tuple(
+            (
+                typed_expr.TypedExpr(
+                    scalar_compiler.compile_scalar_expression(left), left.output_type
+                ),
+                typed_expr.TypedExpr(
+                    scalar_compiler.compile_scalar_expression(right), right.output_type
+                ),
+            )
+            for left, right in node.conditions
+        )
+
+        return left.join(
+            right,
+            join_type=node.type,
+            conditions=conditions,
+            joins_nulls=node.joins_nulls,
+        )
+
     @_compile_node.register
     def compile_concat(
         self, node: nodes.ConcatNode, *children: ir.SQLGlotIR
diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py
index 6bc2b55162..3b4d7ed0ce 100644
--- a/bigframes/core/compile/sqlglot/sqlglot_ir.py
+++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py
@@ -26,6 +26,7 @@
 
 from bigframes import dtypes
 from bigframes.core import guid
+from bigframes.core.compile.sqlglot.expressions import typed_expr
 import bigframes.core.compile.sqlglot.sqlglot_types as sgt
 import bigframes.core.local_data as local_data
 import bigframes.core.schema as bf_schema
@@ -212,7 +213,8 @@ def select(
             for id, expr in selected_cols
         ]
-        new_expr = self._encapsulate_as_cte().select(*selections, append=False)
+        new_expr, _ = self._encapsulate_as_cte()
+        new_expr = new_expr.select(*selections, append=False)
         return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
 
     def order_by(
@@ -247,7 +249,8 @@ def project(
             )
             for id, expr in projected_cols
         ]
-        new_expr = self._encapsulate_as_cte().select(*projected_cols_expr, append=True)
+        new_expr, _ = self._encapsulate_as_cte()
+        new_expr = new_expr.select(*projected_cols_expr, append=True)
         return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
 
     def filter(
@@ -255,11 +258,43 @@ def filter(
         condition: sge.Expression,
     ) -> SQLGlotIR:
         """Filters the query with the given condition."""
-        new_expr = self._encapsulate_as_cte()
+        new_expr, _ = self._encapsulate_as_cte()
         return SQLGlotIR(
             expr=new_expr.where(condition, append=False), uid_gen=self.uid_gen
         )
 
+    def join(
+        self,
+        right: SQLGlotIR,
+        join_type: typing.Literal["inner", "outer", "left", "right", "cross"],
+        conditions: tuple[tuple[typed_expr.TypedExpr, typed_expr.TypedExpr], ...],
+        *,
+        joins_nulls: bool = True,
+    ) -> SQLGlotIR:
+        """Joins the current query with another SQLGlotIR instance."""
+        left_select, left_table = self._encapsulate_as_cte()
+        right_select, right_table = right._encapsulate_as_cte()
+
+        left_ctes = left_select.args.pop("with", [])
+        right_ctes = right_select.args.pop("with", [])
+        merged_ctes = [*left_ctes, *right_ctes]
+
+        join_conditions = [
+            _join_condition(left, right, joins_nulls) for left, right in conditions
+        ]
+        join_on = sge.And(expressions=join_conditions) if join_conditions else None
+
+        join_type_str = join_type if join_type != "outer" else "full outer"
+        new_expr = (
+            sge.Select()
+            .select(sge.Star())
+            .from_(left_table)
+            .join(right_table, on=join_on, join_type=join_type_str)
+        )
+        new_expr.set("with", sge.With(expressions=merged_ctes))
+
+        return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
+
     def insert(
         self,
         destination: bigquery.TableReference,
@@ -320,12 +355,12 @@ def _explode_single_column(
             offset=offset,
         )
         selection = sge.Star(replace=[unnested_column_alias.as_(column)])
+        # TODO: "CROSS" if not keep_empty else "LEFT"
         # TODO: overlaps_with_parent to replace existing column.
-        new_expr = (
-            self._encapsulate_as_cte()
-            .select(selection, append=False)
-            .join(unnest_expr, join_type="CROSS")
+        new_expr, _ = self._encapsulate_as_cte()
+        new_expr = new_expr.select(selection, append=False).join(
+            unnest_expr, join_type="CROSS"
         )
         return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
 
@@ -373,16 +408,15 @@ def _explode_multiple_columns(
                 for column in columns
             ]
         )
-        new_expr = (
-            self._encapsulate_as_cte()
-            .select(selection, append=False)
-            .join(unnest_expr, join_type="CROSS")
+        new_expr, _ = self._encapsulate_as_cte()
+        new_expr = new_expr.select(selection, append=False).join(
+            unnest_expr, join_type="CROSS"
         )
         return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
 
     def _encapsulate_as_cte(
         self,
-    ) -> sge.Select:
+    ) -> typing.Tuple[sge.Select, sge.Table]:
         """Transforms a given sge.Select query by pushing its main SELECT statement
         into a new CTE and then generates a 'SELECT * FROM new_cte_name'
         for the new query."""
@@ -397,11 +431,10 @@ def _encapsulate_as_cte(
             alias=new_cte_name,
         )
         new_with_clause = sge.With(expressions=[*existing_ctes, new_cte])
-        new_select_expr = (
-            sge.Select().select(sge.Star()).from_(sge.Table(this=new_cte_name))
-        )
+        new_table_expr = sge.Table(this=new_cte_name)
+        new_select_expr = sge.Select().select(sge.Star()).from_(new_table_expr)
         new_select_expr.set("with", new_with_clause)
-        return new_select_expr
+        return new_select_expr, new_table_expr
 
 
 def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
@@ -451,3 +484,11 @@ def _table(table: bigquery.TableReference) -> sge.Table:
         db=sg.to_identifier(table.dataset_id, quoted=True),
         catalog=sg.to_identifier(table.project, quoted=True),
     )
+
+
+def _join_condition(
+    left: typed_expr.TypedExpr,
+    right: typed_expr.TypedExpr,
+    joins_nulls: bool,
+) -> typing.Union[sge.EQ, sge.And]:
+    return sge.EQ(this=left.expr, expression=right.expr)
diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py
index 1ca5b8b035..1ef287842e 100644
--- a/bigframes/dataframe.py
+++ b/bigframes/dataframe.py
@@ -3369,8 +3369,6 @@ def merge(
             "right",
             "cross",
         ] = "inner",
-        # TODO(garrettwu): Currently can take inner, outer, left and right. To support
-        # cross joins
         on: Union[blocks.Label, Sequence[blocks.Label], None] = None,
         *,
         left_on: Union[blocks.Label, Sequence[blocks.Label], None] = None,
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql
new file mode 100644
index 0000000000..aefaa28dfb
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql
@@ -0,0 +1,31 @@
+WITH `bfcte_1` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`,
+    `rowindex` AS `bfcol_1`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_2` AS (
+  SELECT
+    `bfcol_1` AS `bfcol_2`,
+    `bfcol_0` AS `bfcol_3`
+  FROM `bfcte_1`
+), `bfcte_0` AS (
+  SELECT
+    `int64_col` AS `bfcol_4`,
+    `int64_too` AS `bfcol_5`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_3` AS (
+  SELECT
+    `bfcol_4` AS `bfcol_6`,
+    `bfcol_5` AS `bfcol_7`
+  FROM `bfcte_0`
+), `bfcte_4` AS (
+  SELECT
+    *
+  FROM `bfcte_2`
+  LEFT JOIN `bfcte_3`
+    ON `bfcol_2` = `bfcol_6`
+)
+SELECT
+  `bfcol_3` AS `int64_col`,
+  `bfcol_7` AS `int64_too`
+FROM `bfcte_4`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/test_compile_join.py b/tests/unit/core/compile/sqlglot/test_compile_join.py
new file mode 100644
index 0000000000..a530ed4fc3
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/test_compile_join.py
@@ -0,0 +1,51 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+import bigframes.pandas as bpd
+
+pytest.importorskip("pytest_snapshot")
+
+
+def test_compile_join(scalar_types_df: bpd.DataFrame, snapshot):
+    left = scalar_types_df[["int64_col"]]
+    right = scalar_types_df.set_index("int64_col")[["int64_too"]]
+    join = left.join(right)
+    snapshot.assert_match(join.sql, "out.sql")
+
+
+def test_compile_join_w_how(scalar_types_df: bpd.DataFrame):
+    left = scalar_types_df[["int64_col"]]
+    right = scalar_types_df.set_index("int64_col")[["int64_too"]]
+
+    join_sql = left.join(right, how="left").sql
+    assert "LEFT JOIN" in join_sql
+    assert "ON" in join_sql
+
+    join_sql = left.join(right, how="right").sql
+    assert "RIGHT JOIN" in join_sql
+    assert "ON" in join_sql
+
+    join_sql = left.join(right, how="outer").sql
+    assert "FULL OUTER JOIN" in join_sql
+    assert "ON" in join_sql
+
+    join_sql = left.join(right, how="inner").sql
+    assert "INNER JOIN" in join_sql
+    assert "ON" in join_sql
+
+    join_sql = left.merge(right, how="cross").sql
+    assert "CROSS JOIN" in join_sql
+    assert "ON" not in join_sql
diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py
index 0606032d34..40ab5a7352 100644
--- a/third_party/bigframes_vendored/pandas/core/frame.py
+++ b/third_party/bigframes_vendored/pandas/core/frame.py
@@ -4748,7 +4748,7 @@ def merge(
             right: Object to merge with.
             how:
-                ``{'left', 'right', 'outer', 'inner'}, default 'inner'``
+                ``{'left', 'right', 'outer', 'inner', 'cross'}, default 'inner'``
                 Type of merge to be performed.
                 ``left``: use only keys from left frame, similar to a SQL left outer join;
                 preserve key order.

From 6454aff726dee791acbac98f893075ee5ee6d9a1 Mon Sep 17 00:00:00 2001
From: TrevorBergeron
Date: Tue, 1 Jul 2025 13:15:46 -0700
Subject: [PATCH 04/13] feat: Add filter pushdown to hybrid engine (#1871)

---
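Note (illustrative sketch, not part of the diff): adding nodes.FilterNode to the compatible node set lets the hybrid engine run filters locally with polars instead of falling back to BigQuery execution. A standalone translation of what a filter predicate looks like in polars; column names and data are invented, and the real executor compiles BigFrames expression trees rather than hand-written predicates:

    import polars as pl

    df = pl.DataFrame({"int64_col": [1, 2, 3], "float64_col": [0.5, 2.5, 2.0]})

    # Rough equivalent of FilterNode(predicate=gt_op(float64_col, int64_col)):
    filtered = df.filter(pl.col("float64_col") > pl.col("int64_col"))
    print(filtered)  # keeps only the rows where the predicate holds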
 bigframes/session/polars_executor.py         |  1 +
 tests/system/small/engines/test_filtering.py | 67 ++++++++++++++++++++
 tests/system/small/test_polars_execution.py  |  3 +-
 3 files changed, 69 insertions(+), 2 deletions(-)
 create mode 100644 tests/system/small/engines/test_filtering.py

diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py
index ec00e38606..e60bef1819 100644
--- a/bigframes/session/polars_executor.py
+++ b/bigframes/session/polars_executor.py
@@ -35,6 +35,7 @@
     nodes.ProjectionNode,
     nodes.SliceNode,
     nodes.AggregateNode,
+    nodes.FilterNode,
 )
 
 _COMPATIBLE_SCALAR_OPS = (
diff --git a/tests/system/small/engines/test_filtering.py b/tests/system/small/engines/test_filtering.py
new file mode 100644
index 0000000000..9b7cd034b4
--- /dev/null
+++ b/tests/system/small/engines/test_filtering.py
@@ -0,0 +1,67 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from bigframes.core import array_value, expression, nodes
+import bigframes.operations as ops
+from bigframes.session import polars_executor
+from bigframes.testing.engine_utils import assert_equivalence_execution
+
+pytest.importorskip("polars")
+
+# Polars used as reference as it's fast and local. Generally though, prefer gbq engine where they disagree.
+REFERENCE_ENGINE = polars_executor.PolarsExecutor()
+
+
+@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+def test_engines_filter_bool_col(
+    scalars_array_value: array_value.ArrayValue,
+    engine,
+):
+    node = nodes.FilterNode(
+        scalars_array_value.node, predicate=expression.deref("bool_col")
+    )
+    assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
+
+
+@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+def test_engines_filter_expr_cond(
+    scalars_array_value: array_value.ArrayValue,
+    engine,
+):
+    predicate = ops.gt_op.as_expr(
+        expression.deref("float64_col"), expression.deref("int64_col")
+    )
+    node = nodes.FilterNode(scalars_array_value.node, predicate=predicate)
+    assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
+
+
+@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+def test_engines_filter_true(
+    scalars_array_value: array_value.ArrayValue,
+    engine,
+):
+    predicate = expression.const(True)
+    node = nodes.FilterNode(scalars_array_value.node, predicate=predicate)
+    assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
+
+
+@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
+def test_engines_filter_false(
+    scalars_array_value: array_value.ArrayValue,
+    engine,
+):
+    predicate = expression.const(False)
+    node = nodes.FilterNode(scalars_array_value.node, predicate=predicate)
+    assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
diff --git a/tests/system/small/test_polars_execution.py b/tests/system/small/test_polars_execution.py
index 0aed693b80..1568a76ec9 100644
--- a/tests/system/small/test_polars_execution.py
+++ b/tests/system/small/test_polars_execution.py
@@ -53,8 +53,7 @@ def test_polar_execution_sorted_filtered(session_w_polars, scalars_pandas_df_ind
         .to_pandas()
     )
 
-    # Filter and isnull not supported by polar engine yet, so falls back to bq execution
-    assert session_w_polars._metrics.execution_count == (execution_count_before + 1)
+    assert session_w_polars._metrics.execution_count == execution_count_before
 
     assert_pandas_df_equal(bf_result, pd_result)
 

From f30f75053a6966abd1a6a644c23efb86b2ac568d Mon Sep 17 00:00:00 2001
From: TrevorBergeron
Date: Tue, 1 Jul 2025 13:58:33 -0700
Subject: [PATCH 05/13] fix: Fix issues where duration type returned as int (#1875)

---
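Note (illustrative sketch, not part of the diff): the fix re-casts each Arrow RecordBatch to the expected schema, so int64 microsecond values surface as duration[us] rather than plain integers. A self-contained illustration of what the cast helper does; the data is invented:

    import pyarrow as pa

    batch = pa.RecordBatch.from_pydict(
        {"duration_col": pa.array([4, -1000000, None], type=pa.int64())}
    )
    target = pa.schema([("duration_col", pa.duration("us"))])

    def cast_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch:
        if batch.schema == schema:
            return batch  # already correct; skip the copy
        # RecordBatch.cast exists only in pyarrow>=16, hence the per-column cast.
        return pa.record_batch(
            [arr.cast(type_) for arr, type_ in zip(batch.columns, schema.types)],
            schema=schema,
        )

    print(cast_batch(batch, target).schema)  # duration_col: duration[us]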
 bigframes/core/local_data.py                  | 13 ++----
 bigframes/core/pyarrow_utils.py               | 10 +++++
 bigframes/dtypes.py                           |  3 ++
 bigframes/session/executor.py                 |  1 +
 bigframes/testing/utils.py                    | 10 +++++
 tests/data/scalars.jsonl                      | 18 ++++----
 tests/data/scalars_schema.json                |  6 +++
 .../pandas/core/methods/test_describe.py      |  8 +++-
 tests/system/small/test_dataframe.py          | 12 +++--
 tests/system/small/test_dataframe_io.py       | 38 ++++++++++------
 tests/system/small/test_session.py            | 44 ++++++++++++++---
 .../compile/sqlglot/test_compile_readlocal.py |  2 +
 tests/unit/test_dataframe_polars.py           | 11 +++--
 13 files changed, 131 insertions(+), 45 deletions(-)

diff --git a/bigframes/core/local_data.py b/bigframes/core/local_data.py
index a99366ad4c..958113dda3 100644
--- a/bigframes/core/local_data.py
+++ b/bigframes/core/local_data.py
@@ -30,6 +30,7 @@
 import pyarrow as pa
 import pyarrow.parquet  # type: ignore
 
+from bigframes.core import pyarrow_utils
 import bigframes.core.schema as schemata
 import bigframes.dtypes
 
@@ -113,7 +114,9 @@ def to_arrow(
             schema = self.data.schema
             if duration_type == "int":
                 schema = _schema_durations_to_ints(schema)
-            batches = map(functools.partial(_cast_pa_batch, schema=schema), batches)
+            batches = map(
+                functools.partial(pyarrow_utils.cast_batch, schema=schema), batches
+            )
 
         if offsets_col is not None:
             return schema.append(pa.field(offsets_col, pa.int64())), _append_offsets(
@@ -468,14 +471,6 @@ def _schema_durations_to_ints(schema: pa.Schema) -> pa.Schema:
     )
 
 
-# TODO: Use RecordBatch.cast once min pyarrow>=16.0
-def _cast_pa_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch:
-    return pa.record_batch(
-        [arr.cast(type) for arr, type in zip(batch.columns, schema.types)],
-        schema=schema,
-    )
-
-
 def _pairwise(iterable):
     do_yield = False
     a = None
diff --git a/bigframes/core/pyarrow_utils.py b/bigframes/core/pyarrow_utils.py
index 4196e68304..b9dc2ea2b3 100644
--- a/bigframes/core/pyarrow_utils.py
+++ b/bigframes/core/pyarrow_utils.py
@@ -74,6 +74,16 @@ def chunk_by_row_count(
         yield buffer.take_as_batches(len(buffer))
 
 
+def cast_batch(batch: pa.RecordBatch, schema: pa.Schema) -> pa.RecordBatch:
+    if batch.schema == schema:
+        return batch
+    # TODO: Use RecordBatch.cast once min pyarrow>=16.0
+    return pa.record_batch(
+        [arr.cast(type) for arr, type in zip(batch.columns, schema.types)],
+        schema=schema,
+    )
+
+
 def truncate_pyarrow_iterable(
     batches: Iterable[pa.RecordBatch], max_results: int
 ) -> Iterator[pa.RecordBatch]:
diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py
index b0a31595e5..20f2f5ee12 100644
--- a/bigframes/dtypes.py
+++ b/bigframes/dtypes.py
@@ -247,6 +247,7 @@ class SimpleDtypeInfo:
     "decimal128(38, 9)[pyarrow]",
     "decimal256(76, 38)[pyarrow]",
     "binary[pyarrow]",
+    "duration[us][pyarrow]",
 ]
 DTYPE_STRINGS = typing.get_args(DtypeString)
 
@@ -421,6 +422,8 @@ def is_bool_coercable(type_: ExpressionType) -> bool:
 # special case - both "Int64" and "int64[pyarrow]" are accepted
 BIGFRAMES_STRING_TO_BIGFRAMES["int64[pyarrow]"] = INT_DTYPE
 
+BIGFRAMES_STRING_TO_BIGFRAMES["duration[us][pyarrow]"] = TIMEDELTA_DTYPE
+
 # For the purposes of dataframe.memory_usage
 DTYPE_BYTE_SIZES = {
     type_info.dtype: type_info.logical_bytes for type_info in SIMPLE_TYPES
diff --git a/bigframes/session/executor.py b/bigframes/session/executor.py
index c913f39791..cc8f086f9f 100644
--- a/bigframes/session/executor.py
+++ b/bigframes/session/executor.py
@@ -50,6 +50,7 @@ def arrow_batches(self) -> Iterator[pyarrow.RecordBatch]:
 
         result_rows = 0
         for batch in self._arrow_batches:
+            batch = pyarrow_utils.cast_batch(batch, self.schema.to_pyarrow())
            result_rows += batch.num_rows
 
             maximum_result_rows = bigframes.options.compute.maximum_result_rows
diff --git a/bigframes/testing/utils.py b/bigframes/testing/utils.py
index ecf9ae00f8..c3a8008465 100644
--- a/bigframes/testing/utils.py
+++ b/bigframes/testing/utils.py
@@ -185,6 +185,16 @@ def convert_pandas_dtypes(df: pd.DataFrame, bytes_col: bool):
             "timestamp_col"
         ]
 
+    if not isinstance(df["duration_col"].dtype, pd.ArrowDtype):
+        df["duration_col"] = df["duration_col"].astype(pd.Int64Dtype())
+        arrow_table = pa.Table.from_pandas(
+            pd.DataFrame(df, columns=["duration_col"]),
+            schema=pa.schema([("duration_col", pa.duration("us"))]),
+        )
+        df["duration_col"] = arrow_table.to_pandas(types_mapper=pd.ArrowDtype)[
+            "duration_col"
+        ]
+
     # Convert geography types columns.
     if "geography_col" in df.columns:
         df["geography_col"] = df["geography_col"].astype(
diff --git a/tests/data/scalars.jsonl b/tests/data/scalars.jsonl
index 2e5a1499b9..6e591cfa72 100644
--- a/tests/data/scalars.jsonl
+++ b/tests/data/scalars.jsonl
@@ -1,9 +1,9 @@
-{"bool_col": true, "bytes_col": "SGVsbG8sIFdvcmxkIQ==", "date_col": "2021-07-21", "datetime_col": "2021-07-21 11:39:45", "geography_col": "POINT(-122.0838511 37.3860517)", "int64_col": "123456789", "int64_too": "0", "numeric_col": "1.23456789", "float64_col": "1.25", "rowindex": 0, "rowindex_2": 0, "string_col": "Hello, World!", "time_col": "11:41:43.076160", "timestamp_col": "2021-07-21T17:43:43.945289Z"}
-{"bool_col": false, "bytes_col": "44GT44KT44Gr44Gh44Gv", "date_col": "1991-02-03", "datetime_col": "1991-01-02 03:45:06", "geography_col": "POINT(-71.104 42.315)", "int64_col": "-987654321", "int64_too": "1", "numeric_col": "1.23456789", "float64_col": "2.51", "rowindex": 1, "rowindex_2": 1, "string_col": "こんにちは", "time_col": "11:14:34.701606", "timestamp_col": "2021-07-21T17:43:43.945289Z"}
-{"bool_col": true, "bytes_col": "wqFIb2xhIE11bmRvIQ==", "date_col": "2023-03-01", "datetime_col": "2023-03-01 10:55:13", "geography_col": "POINT(-0.124474760143016 51.5007826749545)", "int64_col": "314159", "int64_too": "0", "numeric_col": "101.1010101", "float64_col": "2.5e10", "rowindex": 2, "rowindex_2": 2, "string_col": " ¡Hola Mundo! ", "time_col": "23:59:59.999999", "timestamp_col": "2023-03-01T10:55:13.250125Z"}
-{"bool_col": null, "bytes_col": null, "date_col": null, "datetime_col": null, "geography_col": null, "int64_col": null, "int64_too": "1", "numeric_col": null, "float64_col": null, "rowindex": 3, "rowindex_2": 3, "string_col": null, "time_col": null, "timestamp_col": null}
-{"bool_col": false, "bytes_col": "44GT44KT44Gr44Gh44Gv", "date_col": "2021-07-21", "datetime_col": null, "geography_col": null, "int64_col": "-234892", "int64_too": "-2345", "numeric_col": null, "float64_col": null, "rowindex": 4, "rowindex_2": 4, "string_col": "Hello, World!", "time_col": null, "timestamp_col": null}
-{"bool_col": false, "bytes_col": "R8O8dGVuIFRhZw==", "date_col": "1980-03-14", "datetime_col": "1980-03-14 15:16:17", "geography_col": null, "int64_col": "55555", "int64_too": "0", "numeric_col": "5.555555", "float64_col": "555.555", "rowindex": 5, "rowindex_2": 5, "string_col": "Güten Tag!", "time_col": "15:16:17.181921", "timestamp_col": "1980-03-14T15:16:17.181921Z"}
-{"bool_col": true, "bytes_col": "SGVsbG8JQmlnRnJhbWVzIQc=", "date_col": "2023-05-23", "datetime_col": "2023-05-23 11:37:01", "geography_col": "LINESTRING(-0.127959 51.507728, -0.127026 51.507473)", "int64_col": "101202303", "int64_too": "2", "numeric_col": "-10.090807", "float64_col": "-123.456", "rowindex": 6, "rowindex_2": 6, "string_col": "capitalize, This ", "time_col": "01:02:03.456789", "timestamp_col": "2023-05-23T11:42:55.000001Z"}
-{"bool_col": true, "bytes_col": null, "date_col": "2038-01-20", "datetime_col": "2038-01-19 03:14:08", "geography_col": null, "int64_col": "-214748367", "int64_too": "2", "numeric_col": "11111111.1", "float64_col": "42.42", "rowindex": 7, "rowindex_2": 7, "string_col": " سلام", "time_col": "12:00:00.000001", "timestamp_col": "2038-01-19T03:14:17.999999Z"}
-{"bool_col": false, "bytes_col": null, "date_col": null, "datetime_col": null, "geography_col": null, "int64_col": "2", "int64_too": "1", "numeric_col": null, "float64_col": "6.87", "rowindex": 8, "rowindex_2": 8, "string_col": "T", "time_col": null, "timestamp_col": null}
\ No newline at end of file
+{"bool_col": true, "bytes_col": "SGVsbG8sIFdvcmxkIQ==", "date_col": "2021-07-21", "datetime_col": "2021-07-21 11:39:45", "geography_col": "POINT(-122.0838511 37.3860517)", "int64_col": "123456789", "int64_too": "0", "numeric_col": "1.23456789", "float64_col": "1.25", "rowindex": 0, "rowindex_2": 0, "string_col": "Hello, World!", "time_col": "11:41:43.076160", "timestamp_col": "2021-07-21T17:43:43.945289Z", "duration_col": 4}
+{"bool_col": false, "bytes_col": "44GT44KT44Gr44Gh44Gv", "date_col": "1991-02-03", "datetime_col": "1991-01-02 03:45:06", "geography_col": "POINT(-71.104 42.315)", "int64_col": "-987654321", "int64_too": "1", "numeric_col": "1.23456789", "float64_col": "2.51", "rowindex": 1, "rowindex_2": 1, "string_col": "こんにちは", "time_col": "11:14:34.701606", "timestamp_col": "2021-07-21T17:43:43.945289Z", "duration_col": -1000000}
+{"bool_col": true, "bytes_col": "wqFIb2xhIE11bmRvIQ==", "date_col": "2023-03-01", "datetime_col": "2023-03-01 10:55:13", "geography_col": "POINT(-0.124474760143016 51.5007826749545)", "int64_col": "314159", "int64_too": "0", "numeric_col": "101.1010101", "float64_col": "2.5e10", "rowindex": 2, "rowindex_2": 2, "string_col": " ¡Hola Mundo! ", "time_col": "23:59:59.999999", "timestamp_col": "2023-03-01T10:55:13.250125Z", "duration_col": 0}
+{"bool_col": null, "bytes_col": null, "date_col": null, "datetime_col": null, "geography_col": null, "int64_col": null, "int64_too": "1", "numeric_col": null, "float64_col": null, "rowindex": 3, "rowindex_2": 3, "string_col": null, "time_col": null, "timestamp_col": null, "duration_col": null}
+{"bool_col": false, "bytes_col": "44GT44KT44Gr44Gh44Gv", "date_col": "2021-07-21", "datetime_col": null, "geography_col": null, "int64_col": "-234892", "int64_too": "-2345", "numeric_col": null, "float64_col": null, "rowindex": 4, "rowindex_2": 4, "string_col": "Hello, World!", "time_col": null, "timestamp_col": null, "duration_col": 31540000000000}
+{"bool_col": false, "bytes_col": "R8O8dGVuIFRhZw==", "date_col": "1980-03-14", "datetime_col": "1980-03-14 15:16:17", "geography_col": null, "int64_col": "55555", "int64_too": "0", "numeric_col": "5.555555", "float64_col": "555.555", "rowindex": 5, "rowindex_2": 5, "string_col": "Güten Tag!", "time_col": "15:16:17.181921", "timestamp_col": "1980-03-14T15:16:17.181921Z", "duration_col": 4}
+{"bool_col": true, "bytes_col": "SGVsbG8JQmlnRnJhbWVzIQc=", "date_col": "2023-05-23", "datetime_col": "2023-05-23 11:37:01", "geography_col": "LINESTRING(-0.127959 51.507728, -0.127026 51.507473)", "int64_col": "101202303", "int64_too": "2", "numeric_col": "-10.090807", "float64_col": "-123.456", "rowindex": 6, "rowindex_2": 6, "string_col": "capitalize, This ", "time_col": "01:02:03.456789", "timestamp_col": "2023-05-23T11:42:55.000001Z", "duration_col": null}
+{"bool_col": true, "bytes_col": null, "date_col": "2038-01-20", "datetime_col": "2038-01-19 03:14:08", "geography_col": null, "int64_col": "-214748367", "int64_too": "2", "numeric_col": "11111111.1", "float64_col": "42.42", "rowindex": 7, "rowindex_2": 7, "string_col": " سلام", "time_col": "12:00:00.000001", "timestamp_col": "2038-01-19T03:14:17.999999Z", "duration_col": 4}
+{"bool_col": false, "bytes_col": null, "date_col": null, "datetime_col": null, "geography_col": null, "int64_col": "2", "int64_too": "1", "numeric_col": null, "float64_col": "6.87", "rowindex": 8, "rowindex_2": 8, "string_col": "T", "time_col": null, "timestamp_col": null, "duration_col": 432000000000}
diff --git a/tests/data/scalars_schema.json b/tests/data/scalars_schema.json
index 1f5d8cdb65..8be4e95228 100644
--- a/tests/data/scalars_schema.json
+++ b/tests/data/scalars_schema.json
@@ -71,5 +71,11 @@
     "mode": "NULLABLE",
     "name": "timestamp_col",
     "type": "TIMESTAMP"
+  },
+  {
+    "mode": "NULLABLE",
+    "name": "duration_col",
+    "type": "INTEGER",
+    "description": "#microseconds"
   }
 ]
diff --git a/tests/system/small/pandas/core/methods/test_describe.py b/tests/system/small/pandas/core/methods/test_describe.py
index dfc7c3fb23..5971e47997 100644
--- a/tests/system/small/pandas/core/methods/test_describe.py
+++ b/tests/system/small/pandas/core/methods/test_describe.py
@@ -21,7 +21,13 @@ def test_df_describe_non_temporal(scalars_dfs):
     pytest.importorskip("pandas", minversion="2.0.0")
     scalars_df, scalars_pandas_df = scalars_dfs
     # excluding temporal columns here because BigFrames cannot perform percentiles operations on them
-    unsupported_columns = ["datetime_col", "timestamp_col", "time_col", "date_col"]
+    unsupported_columns = [
+        "datetime_col",
+        "timestamp_col",
+        "time_col",
+        "date_col",
+        "duration_col",
+    ]
     bf_result = scalars_df.drop(columns=unsupported_columns).describe().to_pandas()
 
     modified_pd_df = scalars_pandas_df.drop(columns=unsupported_columns)
diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py
index e8d156538f..5045e2268f 100644
--- a/tests/system/small/test_dataframe.py
+++ b/tests/system/small/test_dataframe.py
@@ -553,7 +553,7 @@ def test_df_info(scalars_dfs):
     expected = (
        "<class 'bigframes.dataframe.DataFrame'>\n"
         "Index: 9 entries, 0 to 8\n"
-        "Data columns (total 13 columns):\n"
+        "Data columns (total 14 columns):\n"
         " #   Column         Non-Null Count    Dtype\n"
         "---  -------------  ----------------  ------------------------------\n"
         " 0   bool_col       8 non-null        boolean\n"
@@ -569,8 +569,9 @@
         " 10  string_col     8 non-null        string\n"
         " 11  time_col       6 non-null        time64[us][pyarrow]\n"
         " 12  timestamp_col  6 non-null        timestamp[us, tz=UTC][pyarrow]\n"
-        "dtypes: Float64(1), Int64(3), binary[pyarrow](1), boolean(1), date32[day][pyarrow](1), decimal128(38, 9)[pyarrow](1), geometry(1), string(1), time64[us][pyarrow](1), timestamp[us, tz=UTC][pyarrow](1), timestamp[us][pyarrow](1)\n"
-        "memory usage: 1269 bytes\n"
+        " 13  duration_col   7 non-null        duration[us][pyarrow]\n"
+        "dtypes: Float64(1), Int64(3), binary[pyarrow](1), boolean(1), date32[day][pyarrow](1), decimal128(38, 9)[pyarrow](1), duration[us][pyarrow](1), geometry(1), string(1), time64[us][pyarrow](1), timestamp[us, tz=UTC][pyarrow](1), timestamp[us][pyarrow](1)\n"
+        "memory usage: 1341 bytes\n"
     )
 
     scalars_df, _ = scalars_dfs
@@ -1694,6 +1695,7 @@ def test_get_dtypes(scalars_df_default_index):
             "string_col": pd.StringDtype(storage="pyarrow"),
             "time_col": pd.ArrowDtype(pa.time64("us")),
             "timestamp_col": pd.ArrowDtype(pa.timestamp("us", tz="UTC")),
+            "duration_col": pd.ArrowDtype(pa.duration("us")),
         }
     pd.testing.assert_series_equal(
         dtypes,
@@ -4771,6 +4773,9 @@ def test_df_to_json_local_str(scalars_df_index, scalars_pandas_df_index):
 def test_df_to_json_local_file(scalars_df_index, scalars_pandas_df_index):
     # TODO: supply a reason why this isn't compatible with pandas 1.x
     pytest.importorskip("pandas", minversion="2.0.0")
+    # duration not fully supported at pandas level
+    scalars_df_index = scalars_df_index.drop(columns="duration_col")
+    scalars_pandas_df_index = scalars_pandas_df_index.drop(columns="duration_col")
     with tempfile.TemporaryFile() as bf_result_file, tempfile.TemporaryFile() as pd_result_file:
         scalars_df_index.to_json(bf_result_file, orient="table")
         # default_handler for arrow types that have no default conversion
@@ -4882,6 +4887,7 @@ def test_df_to_orc(scalars_df_index, scalars_pandas_df_index):
         "time_col",
         "timestamp_col",
         "geography_col",
+        "duration_col",
     ]
 
     bf_result_file = tempfile.TemporaryFile()
diff --git a/tests/system/small/test_dataframe_io.py b/tests/system/small/test_dataframe_io.py
index afe3b53d6d..ef6e25a95c 100644
--- a/tests/system/small/test_dataframe_io.py
+++ b/tests/system/small/test_dataframe_io.py
@@ -55,7 +55,7 @@ def test_sql_executes(scalars_df_default_index, bigquery_client):
     """
     # Do some operations to make for more complex SQL.
     df = (
-        scalars_df_default_index.drop(columns=["geography_col"])
+        scalars_df_default_index.drop(columns=["geography_col", "duration_col"])
         .groupby("string_col")
         .max()
     )
@@ -87,7 +87,7 @@ def test_sql_executes_and_includes_named_index(
     """
     # Do some operations to make for more complex SQL.
     df = (
-        scalars_df_default_index.drop(columns=["geography_col"])
+        scalars_df_default_index.drop(columns=["geography_col", "duration_col"])
         .groupby("string_col")
         .max()
     )
@@ -120,7 +120,7 @@ def test_sql_executes_and_includes_named_multiindex(
     """
     # Do some operations to make for more complex SQL.
     df = (
-        scalars_df_default_index.drop(columns=["geography_col"])
+        scalars_df_default_index.drop(columns=["geography_col", "duration_col"])
         .groupby(["string_col", "bool_col"])
         .max()
     )
@@ -999,14 +999,16 @@ def test_to_sql_query_unnamed_index_included(
     scalars_df_default_index: bpd.DataFrame,
     scalars_pandas_df_default_index: pd.DataFrame,
 ):
-    bf_df = scalars_df_default_index.reset_index(drop=True)
+    bf_df = scalars_df_default_index.reset_index(drop=True).drop(columns="duration_col")
     sql, idx_ids, idx_labels = bf_df._to_sql_query(include_index=True)
     assert len(idx_labels) == 1
     assert len(idx_ids) == 1
     assert idx_labels[0] is None
     assert idx_ids[0].startswith("bigframes")
 
-    pd_df = scalars_pandas_df_default_index.reset_index(drop=True)
+    pd_df = scalars_pandas_df_default_index.reset_index(drop=True).drop(
+        columns="duration_col"
+    )
     roundtrip = session.read_gbq(sql, index_col=idx_ids)
     roundtrip.index.names = [None]
     utils.assert_pandas_df_equal(roundtrip.to_pandas(), pd_df, check_index_type=False)
@@ -1017,14 +1019,18 @@ def test_to_sql_query_named_index_included(
     scalars_df_default_index: bpd.DataFrame,
     scalars_pandas_df_default_index: pd.DataFrame,
 ):
-    bf_df = scalars_df_default_index.set_index("rowindex_2", drop=True)
+    bf_df = scalars_df_default_index.set_index("rowindex_2", drop=True).drop(
+        columns="duration_col"
+    )
     sql, idx_ids, idx_labels = bf_df._to_sql_query(include_index=True)
     assert len(idx_labels) == 1
     assert len(idx_ids) == 1
     assert idx_labels[0] == "rowindex_2"
     assert idx_ids[0] == "rowindex_2"
 
-    pd_df = scalars_pandas_df_default_index.set_index("rowindex_2", drop=True)
+    pd_df = scalars_pandas_df_default_index.set_index("rowindex_2", drop=True).drop(
+        columns="duration_col"
+    )
     roundtrip = session.read_gbq(sql, index_col=idx_ids)
     utils.assert_pandas_df_equal(roundtrip.to_pandas(), pd_df)
 
@@ -1034,12 +1040,14 @@ def test_to_sql_query_unnamed_index_excluded(
     scalars_df_default_index: bpd.DataFrame,
     scalars_pandas_df_default_index: pd.DataFrame,
 ):
-    bf_df = scalars_df_default_index.reset_index(drop=True)
+    bf_df = scalars_df_default_index.reset_index(drop=True).drop(columns="duration_col")
     sql, idx_ids, idx_labels = bf_df._to_sql_query(include_index=False)
     assert len(idx_labels) == 0
     assert len(idx_ids) == 0
 
-    pd_df = scalars_pandas_df_default_index.reset_index(drop=True)
+    pd_df = scalars_pandas_df_default_index.reset_index(drop=True).drop(
+        columns="duration_col"
+    )
     roundtrip = session.read_gbq(sql)
     utils.assert_pandas_df_equal(
         roundtrip.to_pandas(), pd_df, check_index_type=False, ignore_order=True
@@ -1051,14 +1059,18 @@ def test_to_sql_query_named_index_excluded(
     scalars_df_default_index: bpd.DataFrame,
     scalars_pandas_df_default_index: pd.DataFrame,
 ):
-    bf_df = scalars_df_default_index.set_index("rowindex_2", drop=True)
+    bf_df = scalars_df_default_index.set_index("rowindex_2", drop=True).drop(
+        columns="duration_col"
+    )
     sql, idx_ids, idx_labels = bf_df._to_sql_query(include_index=False)
     assert len(idx_labels) == 0
     assert len(idx_ids) == 0
 
-    pd_df = scalars_pandas_df_default_index.set_index(
-        "rowindex_2", drop=True
-    ).reset_index(drop=True)
+    pd_df = (
+        scalars_pandas_df_default_index.set_index("rowindex_2", drop=True)
+        .reset_index(drop=True)
+        .drop(columns="duration_col")
+    )
     roundtrip = session.read_gbq(sql)
     utils.assert_pandas_df_equal(
         roundtrip.to_pandas(), pd_df, check_index_type=False, ignore_order=True
diff --git a/tests/system/small/test_session.py b/tests/system/small/test_session.py
index 809d08c6c1..4bb1c6589a 100644
--- a/tests/system/small/test_session.py
+++ b/tests/system/small/test_session.py
@@ -54,7 +54,13 @@ def df_and_local_csv(scalars_df_index):
     # The auto detects of BigQuery load job have restrictions to detect the bytes,
     # datetime, numeric and geometry types, so they're skipped here.
-    drop_columns = ["bytes_col", "datetime_col", "numeric_col", "geography_col"]
+    drop_columns = [
+        "bytes_col",
+        "datetime_col",
+        "numeric_col",
+        "geography_col",
+        "duration_col",
+    ]
     scalars_df_index = scalars_df_index.drop(columns=drop_columns)
 
     with tempfile.TemporaryDirectory() as dir:
@@ -68,7 +74,13 @@ def df_and_gcs_csv(scalars_df_index, gcs_folder):
     # The auto detects of BigQuery load job have restrictions to detect the bytes,
     # datetime, numeric and geometry types, so they're skipped here.
-    drop_columns = ["bytes_col", "datetime_col", "numeric_col", "geography_col"]
+    drop_columns = [
+        "bytes_col",
+        "datetime_col",
+        "numeric_col",
+        "geography_col",
+        "duration_col",
+    ]
     scalars_df_index = scalars_df_index.drop(columns=drop_columns)
 
     path = gcs_folder + "test_read_csv_w_gcs_csv*.csv"
@@ -1808,6 +1820,7 @@ def test_read_parquet_gcs(
     df_out = df_out.assign(
         datetime_col=df_out["datetime_col"].astype("timestamp[us][pyarrow]"),
         timestamp_col=df_out["timestamp_col"].astype("timestamp[us, tz=UTC][pyarrow]"),
+        duration_col=df_out["duration_col"].astype("duration[us][pyarrow]"),
     )
 
     # Make sure we actually have at least some values before comparing.
@@ -1856,7 +1869,8 @@ def test_read_parquet_gcs_compressed(
     # DATETIME gets loaded as TIMESTAMP in parquet. See:
     # https://cloud.google.com/bigquery/docs/exporting-data#parquet_export_details
     df_out = df_out.assign(
-        datetime_col=df_out["datetime_col"].astype("timestamp[us][pyarrow]")
+        datetime_col=df_out["datetime_col"].astype("timestamp[us][pyarrow]"),
+        duration_col=df_out["duration_col"].astype("duration[us][pyarrow]"),
     )
 
     # Make sure we actually have at least some values before comparing.
@@ -1914,9 +1928,23 @@ def test_read_json_gcs_bq_engine(session, scalars_dfs, gcs_folder):
 
     # The auto detects of BigQuery load job have restrictions to detect the bytes,
     # datetime, numeric and geometry types, so they're skipped here.
-    df = df.drop(columns=["bytes_col", "datetime_col", "numeric_col", "geography_col"])
+    df = df.drop(
+        columns=[
+            "bytes_col",
+            "datetime_col",
+            "numeric_col",
+            "geography_col",
+            "duration_col",
+        ]
+    )
     scalars_df = scalars_df.drop(
-        columns=["bytes_col", "datetime_col", "numeric_col", "geography_col"]
+        columns=[
+            "bytes_col",
+            "datetime_col",
+            "numeric_col",
+            "geography_col",
+            "duration_col",
+        ]
     )
     assert df.shape[0] == scalars_df.shape[0]
     pd.testing.assert_series_equal(
@@ -1949,8 +1977,10 @@ def test_read_json_gcs_default_engine(session, scalars_dfs, gcs_folder):
 
     # The auto detects of BigQuery load job have restrictions to detect the bytes,
     # numeric and geometry types, so they're skipped here.
-    df = df.drop(columns=["bytes_col", "numeric_col", "geography_col"])
-    scalars_df = scalars_df.drop(columns=["bytes_col", "numeric_col", "geography_col"])
+    df = df.drop(columns=["bytes_col", "numeric_col", "geography_col", "duration_col"])
+    scalars_df = scalars_df.drop(
+        columns=["bytes_col", "numeric_col", "geography_col", "duration_col"]
+    )
 
     # pandas read_json does not respect the dtype overrides for these columns
     df = df.drop(columns=["date_col", "datetime_col", "time_col"])
diff --git a/tests/unit/core/compile/sqlglot/test_compile_readlocal.py b/tests/unit/core/compile/sqlglot/test_compile_readlocal.py
index 7307fd9b4e..6f8a2050e5 100644
--- a/tests/unit/core/compile/sqlglot/test_compile_readlocal.py
+++ b/tests/unit/core/compile/sqlglot/test_compile_readlocal.py
@@ -24,6 +24,8 @@ def test_compile_readlocal(
     scalar_types_pandas_df: pd.DataFrame, compiler_session: bigframes.Session, snapshot
 ):
+    # Durations not yet supported
+    scalar_types_pandas_df = scalar_types_pandas_df.drop(["duration_col"], axis=1)
     bf_df = bpd.DataFrame(scalar_types_pandas_df, session=compiler_session)
     snapshot.assert_match(bf_df.sql, "out.sql")
 
diff --git a/tests/unit/test_dataframe_polars.py b/tests/unit/test_dataframe_polars.py
index f7f0cc80bb..467cf7ce3d 100644
--- a/tests/unit/test_dataframe_polars.py
+++ b/tests/unit/test_dataframe_polars.py
@@ -508,7 +508,7 @@ def test_df_info(scalars_dfs):
     expected = (
        "<class 'bigframes.dataframe.DataFrame'>\n"
         "Index: 9 entries, 0 to 8\n"
-        "Data columns (total 13 columns):\n"
+        "Data columns (total 14 columns):\n"
         " #   Column         Non-Null Count    Dtype\n"
         "---  -------------  ----------------  ------------------------------\n"
         " 0   bool_col       8 non-null        boolean\n"
@@ -524,8 +524,9 @@
         " 10  string_col     8 non-null        string\n"
         " 11  time_col       6 non-null        time64[us][pyarrow]\n"
         " 12  timestamp_col  6 non-null        timestamp[us, tz=UTC][pyarrow]\n"
-        "dtypes: Float64(1), Int64(3), binary[pyarrow](1), boolean(1), date32[day][pyarrow](1), decimal128(38, 9)[pyarrow](1), geometry(1), string(1), time64[us][pyarrow](1), timestamp[us, tz=UTC][pyarrow](1), timestamp[us][pyarrow](1)\n"
-        "memory usage: 1269 bytes\n"
+        " 13  duration_col   7 non-null        duration[us][pyarrow]\n"
+        "dtypes: Float64(1), Int64(3), binary[pyarrow](1), boolean(1), date32[day][pyarrow](1), decimal128(38, 9)[pyarrow](1), duration[us][pyarrow](1), geometry(1), string(1), time64[us][pyarrow](1), timestamp[us, tz=UTC][pyarrow](1), timestamp[us][pyarrow](1)\n"
+        "memory usage: 1341 bytes\n"
     )
 
     scalars_df, _ = scalars_dfs
@@ -4086,6 +4087,9 @@ def test_df_to_json_local_str(scalars_df_index, scalars_pandas_df_index):
 def test_df_to_json_local_file(scalars_df_index, scalars_pandas_df_index):
     # TODO: supply a reason why this isn't compatible with pandas 1.x
     pytest.importorskip("pandas", minversion="2.0.0")
+    # duration not fully supported at pandas level
+    scalars_df_index = scalars_df_index.drop(columns="duration_col")
+    scalars_pandas_df_index = scalars_pandas_df_index.drop(columns="duration_col")
     with tempfile.TemporaryFile() as bf_result_file, tempfile.TemporaryFile() as pd_result_file:
         scalars_df_index.to_json(bf_result_file, orient="table")
         # default_handler for arrow types that have no default conversion
@@ -4197,6 +4201,7 @@ def test_df_to_orc(scalars_df_index, scalars_pandas_df_index):
         "time_col",
         "timestamp_col",
         "geography_col",
+        "duration_col",
     ]
 
     bf_result_file = tempfile.TemporaryFile()

From 3e6dfe77a60ed8eaf1c4846dd28dd19824fe8d3c Mon Sep 17 00:00:00 2001
From: Chelsea Lin
Date: Wed, 2 Jul 2025 11:12:20 -0700
Subject: [PATCH 06/13] refactor: add _join_condition for all types (#1880)

Fixes internal issue 427501553
---
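Note (illustrative sketch, not part of the diff): the COALESCE pairs generated below implement a null-safe equality. A single COALESCE(x, 0) = COALESCE(y, 0) would wrongly match a NULL key against a real 0, so the condition requires equality under two different fallback constants, which only holds when both sides are NULL or both are genuinely equal. A toy Python evaluation of the generated predicate:

    from typing import Optional

    def coalesce(value: Optional[int], default: int) -> int:
        return default if value is None else value

    def join_keys_match(x: Optional[int], y: Optional[int]) -> bool:
        # Mirrors: COALESCE(x, 0) = COALESCE(y, 0) AND COALESCE(x, 1) = COALESCE(y, 1)
        return coalesce(x, 0) == coalesce(y, 0) and coalesce(x, 1) == coalesce(y, 1)

    assert join_keys_match(5, 5)
    assert join_keys_match(None, None)   # NULL keys match each other
    assert not join_keys_match(None, 0)  # the second sentinel rules this out
    assert not join_keys_match(0, 1)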
 .../core/compile/sqlglot/scalar_compiler.py   |  2 +-
 bigframes/core/compile/sqlglot/sqlglot_ir.py  | 87 ++++++++++++++++++-
 bigframes/dtypes.py                           |  5 +-
 .../test_compile_join/out.sql                 |  3 +-
 .../test_compile_join_w_on/bool_col/out.sql   | 33 +++++++
 .../float64_col/out.sql                       | 33 +++++++
 .../test_compile_join_w_on/int64_col/out.sql  | 33 +++++++
 .../numeric_col/out.sql                       | 33 +++++++
 .../test_compile_join_w_on/string_col/out.sql | 28 ++++++
 .../test_compile_join_w_on/time_col/out.sql   | 28 ++++++
 .../core/compile/sqlglot/test_compile_join.py | 10 +++
 11 files changed, 290 insertions(+), 5 deletions(-)
 create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql
 create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql
 create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql
 create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql
 create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql
 create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql

diff --git a/bigframes/core/compile/sqlglot/scalar_compiler.py b/bigframes/core/compile/sqlglot/scalar_compiler.py
index 0db507b0fa..683dd38c9a 100644
--- a/bigframes/core/compile/sqlglot/scalar_compiler.py
+++ b/bigframes/core/compile/sqlglot/scalar_compiler.py
@@ -39,7 +39,7 @@ def compile_scalar_expression(
 
 @compile_scalar_expression.register
 def compile_deref_expression(expr: expression.DerefOp) -> sge.Expression:
-    return sge.ColumnDef(this=sge.to_identifier(expr.id.sql, quoted=True))
+    return sge.Column(this=sge.to_identifier(expr.id.sql, quoted=True))
 
 
 @compile_scalar_expression.register
diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py
index 3b4d7ed0ce..d5902fa6fc 100644
--- a/bigframes/core/compile/sqlglot/sqlglot_ir.py
+++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py
@@ -491,4 +491,89 @@ def _join_condition(
     right: typed_expr.TypedExpr,
     joins_nulls: bool,
 ) -> typing.Union[sge.EQ, sge.And]:
-    return sge.EQ(this=left.expr, expression=right.expr)
+    """Generates a join condition to match pandas's null-handling logic.
+
+    Pandas treats null values as distinct from each other, leading to a
+    cross-join-like behavior for null keys. In contrast, BigQuery SQL treats
+    null values as equal, leading to an inner-join-like behavior.
+
+    This function generates the appropriate SQL condition to replicate the
+    desired pandas behavior in BigQuery.
+
+    Args:
+        left: The left-side join key.
+        right: The right-side join key.
+        joins_nulls: If True, generates complex logic to handle nulls/NaNs.
+            Otherwise, uses a simple equality check where appropriate.
+    """
+    is_floating_types = (
+        left.dtype == dtypes.FLOAT_DTYPE and right.dtype == dtypes.FLOAT_DTYPE
+    )
+    if not is_floating_types and not joins_nulls:
+        return sge.EQ(this=left.expr, expression=right.expr)
+
+    is_numeric_types = dtypes.is_numeric(
+        left.dtype, include_bool=False
+    ) and dtypes.is_numeric(right.dtype, include_bool=False)
+    if is_numeric_types:
+        return _join_condition_for_numeric(left, right)
+    else:
+        return _join_condition_for_others(left, right)
+
+
+def _join_condition_for_others(
+    left: typed_expr.TypedExpr,
+    right: typed_expr.TypedExpr,
+) -> sge.And:
+    """Generates a join condition for non-numeric types to match pandas's
+    null-handling logic.
+    """
+    left_str = _cast(left.expr, "STRING")
+    right_str = _cast(right.expr, "STRING")
+    left_0 = sge.func("COALESCE", left_str, _literal("0", dtypes.STRING_DTYPE))
+    left_1 = sge.func("COALESCE", left_str, _literal("1", dtypes.STRING_DTYPE))
+    right_0 = sge.func("COALESCE", right_str, _literal("0", dtypes.STRING_DTYPE))
+    right_1 = sge.func("COALESCE", right_str, _literal("1", dtypes.STRING_DTYPE))
+    return sge.And(
+        this=sge.EQ(this=left_0, expression=right_0),
+        expression=sge.EQ(this=left_1, expression=right_1),
+    )
+
+
+def _join_condition_for_numeric(
+    left: typed_expr.TypedExpr,
+    right: typed_expr.TypedExpr,
+) -> sge.And:
+    """Generates a join condition for numeric types to match pandas's
+    null-handling logic. For FLOAT types specifically, pandas treats NaN values
+    as unequal, so they must also be coalesced, using different constants.
+    """
+    is_floating_types = (
+        left.dtype == dtypes.FLOAT_DTYPE and right.dtype == dtypes.FLOAT_DTYPE
+    )
+    left_0 = sge.func("COALESCE", left.expr, _literal(0, left.dtype))
+    left_1 = sge.func("COALESCE", left.expr, _literal(1, left.dtype))
+    right_0 = sge.func("COALESCE", right.expr, _literal(0, right.dtype))
+    right_1 = sge.func("COALESCE", right.expr, _literal(1, right.dtype))
+    if not is_floating_types:
+        return sge.And(
+            this=sge.EQ(this=left_0, expression=right_0),
+            expression=sge.EQ(this=left_1, expression=right_1),
+        )
+
+    left_2 = sge.If(
+        this=sge.IsNan(this=left.expr), true=_literal(2, left.dtype), false=left_0
+    )
+    left_3 = sge.If(
+        this=sge.IsNan(this=left.expr), true=_literal(3, left.dtype), false=left_1
+    )
+    right_2 = sge.If(
+        this=sge.IsNan(this=right.expr), true=_literal(2, right.dtype), false=right_0
+    )
+    right_3 = sge.If(
+        this=sge.IsNan(this=right.expr), true=_literal(3, right.dtype), false=right_1
+    )
+    return sge.And(
+        this=sge.EQ(this=left_2, expression=right_2),
+        expression=sge.EQ(this=left_3, expression=right_3),
+    )
diff --git a/bigframes/dtypes.py b/bigframes/dtypes.py
index 20f2f5ee12..0be31505df 100644
--- a/bigframes/dtypes.py
+++ b/bigframes/dtypes.py
@@ -341,8 +341,9 @@ def is_json_encoding_type(type_: ExpressionType) -> bool:
     return type_ != GEO_DTYPE
 
 
-def is_numeric(type_: ExpressionType) -> bool:
-    return type_ in NUMERIC_BIGFRAMES_TYPES_PERMISSIVE
+def is_numeric(type_: ExpressionType, include_bool: bool = True) -> bool:
+    is_numeric = type_ in NUMERIC_BIGFRAMES_TYPES_PERMISSIVE
+    return is_numeric if include_bool else is_numeric and type_ != BOOL_DTYPE
 
 
 def is_iterable(type_: ExpressionType) -> bool:
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql
index aefaa28dfb..85eab4487a 100644
--- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql
@@ -23,7 +23,8 @@ WITH `bfcte_1` AS (
     *
   FROM `bfcte_2`
   LEFT JOIN `bfcte_3`
-    ON `bfcol_2` = `bfcol_6`
+    ON COALESCE(`bfcol_2`, 0) = COALESCE(`bfcol_6`, 0)
+    AND COALESCE(`bfcol_2`, 1) = COALESCE(`bfcol_6`, 1)
 )
 SELECT
   `bfcol_3` AS `int64_col`,
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql
new file mode 100644
index 0000000000..a073e35c69
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/bool_col/out.sql
@@ -0,0 +1,33 @@
+WITH `bfcte_1` AS (
+  SELECT
+    `bool_col` AS `bfcol_0`,
+    `rowindex` AS `bfcol_1`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_2` AS (
+  SELECT
+    `bfcol_1` AS `bfcol_2`,
+    `bfcol_0` AS `bfcol_3`
+  FROM `bfcte_1`
+), `bfcte_0` AS (
+  SELECT
+    `bool_col` AS `bfcol_4`,
+    `rowindex` AS `bfcol_5`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_3` AS (
+  SELECT
+    `bfcol_5` AS `bfcol_6`,
+    `bfcol_4` AS `bfcol_7`
+  FROM `bfcte_0`
+), `bfcte_4` AS (
+  SELECT
+    *
+  FROM `bfcte_2`
+  INNER JOIN `bfcte_3`
+    ON COALESCE(CAST(`bfcol_3` AS STRING), '0') = COALESCE(CAST(`bfcol_7` AS STRING), '0')
+    AND COALESCE(CAST(`bfcol_3` AS STRING), '1') = COALESCE(CAST(`bfcol_7` AS STRING), '1')
+)
+SELECT
+  `bfcol_2` AS `rowindex_x`,
+  `bfcol_3` AS `bool_col`,
+  `bfcol_6` AS `rowindex_y`
+FROM `bfcte_4`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql
new file mode 100644
index 0000000000..1d04343f31
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/float64_col/out.sql
@@ -0,0 +1,33 @@
+WITH `bfcte_1` AS (
+  SELECT
+    `float64_col` AS `bfcol_0`,
+    `rowindex` AS `bfcol_1`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_2` AS (
+  SELECT
+    `bfcol_1` AS `bfcol_2`,
+    `bfcol_0` AS `bfcol_3`
+  FROM `bfcte_1`
+), `bfcte_0` AS (
+  SELECT
+    `float64_col` AS `bfcol_4`,
+    `rowindex` AS `bfcol_5`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_3` AS (
+  SELECT
+    `bfcol_5` AS `bfcol_6`,
+    `bfcol_4` AS `bfcol_7`
+  FROM `bfcte_0`
+), `bfcte_4` AS (
+  SELECT
+    *
+  FROM `bfcte_2`
+  INNER JOIN `bfcte_3`
+    ON IF(IS_NAN(`bfcol_3`), 2, COALESCE(`bfcol_3`, 0)) = IF(IS_NAN(`bfcol_7`), 2, COALESCE(`bfcol_7`, 0))
+    AND IF(IS_NAN(`bfcol_3`), 3, COALESCE(`bfcol_3`, 1)) = IF(IS_NAN(`bfcol_7`), 3, COALESCE(`bfcol_7`, 1))
+)
+SELECT
+  `bfcol_2` AS `rowindex_x`,
+  `bfcol_3` AS `float64_col`,
+  `bfcol_6` AS `rowindex_y`
+FROM `bfcte_4`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql
new file mode 100644
index 0000000000..80ec5d19d1
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/int64_col/out.sql
@@ -0,0 +1,33 @@
+WITH `bfcte_1` AS (
+  SELECT
+    `int64_col` AS `bfcol_0`,
+    `rowindex` AS `bfcol_1`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_2` AS (
+  SELECT
+    `bfcol_1` AS `bfcol_2`,
+    `bfcol_0` AS `bfcol_3`
+  FROM `bfcte_1`
+), `bfcte_0` AS (
+  SELECT
+    `int64_col` AS `bfcol_4`,
+    `rowindex` AS `bfcol_5`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_3` AS (
+  SELECT
+    `bfcol_5` AS `bfcol_6`,
+    `bfcol_4` AS `bfcol_7`
+  FROM `bfcte_0`
+), `bfcte_4` AS (
+  SELECT
+    *
+  FROM `bfcte_2`
+  INNER JOIN `bfcte_3`
+    ON COALESCE(`bfcol_3`, 0) = COALESCE(`bfcol_7`, 0)
+    AND COALESCE(`bfcol_3`, 1) = COALESCE(`bfcol_7`, 1)
+)
+SELECT
+  `bfcol_2` AS `rowindex_x`,
+  `bfcol_3` AS `int64_col`,
+  `bfcol_6` AS `rowindex_y`
+FROM `bfcte_4`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql
new file mode 100644
index 0000000000..22ce6f5b29
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/numeric_col/out.sql
@@ -0,0 +1,33 @@
+WITH `bfcte_1` AS (
+  SELECT
+    `numeric_col` AS `bfcol_0`,
+    `rowindex` AS `bfcol_1`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_2` AS (
+  SELECT
+    `bfcol_1` AS `bfcol_2`,
+    `bfcol_0` AS `bfcol_3`
+  FROM `bfcte_1`
+), `bfcte_0` AS (
+  SELECT
+    `numeric_col` AS `bfcol_4`,
+    `rowindex` AS `bfcol_5`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_3` AS (
+  SELECT
+    `bfcol_5` AS `bfcol_6`,
+    `bfcol_4` AS `bfcol_7`
+  FROM `bfcte_0`
+), `bfcte_4` AS (
+  SELECT
+    *
+  FROM `bfcte_2`
+  INNER JOIN `bfcte_3`
+    ON COALESCE(`bfcol_3`, CAST(0 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(0 AS NUMERIC))
+    AND COALESCE(`bfcol_3`, CAST(1 AS NUMERIC)) = COALESCE(`bfcol_7`, CAST(1 AS NUMERIC))
+)
+SELECT
+  `bfcol_2` AS `rowindex_x`,
+  `bfcol_3` AS `numeric_col`,
+  `bfcol_6` AS `rowindex_y`
+FROM `bfcte_4`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql
new file mode 100644
index 0000000000..5e8d072d46
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/string_col/out.sql
@@ -0,0 +1,28 @@
+WITH `bfcte_1` AS (
+  SELECT
+    `rowindex` AS `bfcol_0`,
+    `string_col` AS `bfcol_1`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_0` AS (
+  SELECT
+    `rowindex` AS `bfcol_2`,
+    `string_col` AS `bfcol_3`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_2` AS (
+  SELECT
+    `bfcol_2` AS `bfcol_4`,
+    `bfcol_3` AS `bfcol_5`
+  FROM `bfcte_0`
+), `bfcte_3` AS (
+  SELECT
+    *
+  FROM `bfcte_1`
+  INNER JOIN `bfcte_2`
+    ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0')
+    AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1')
+)
+SELECT
+  `bfcol_0` AS `rowindex_x`,
+  `bfcol_1` AS `string_col`,
+  `bfcol_4` AS `rowindex_y`
+FROM `bfcte_3`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql
new file mode 100644
index 0000000000..b0df619f25
--- /dev/null
+++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join_w_on/time_col/out.sql
@@ -0,0 +1,28 @@
+WITH `bfcte_1` AS (
+  SELECT
+    `rowindex` AS `bfcol_0`,
+    `time_col` AS `bfcol_1`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_0` AS (
+  SELECT
+    `rowindex` AS `bfcol_2`,
+    `time_col` AS `bfcol_3`
+  FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
+), `bfcte_2` AS (
+  SELECT
+    `bfcol_2` AS `bfcol_4`,
+    `bfcol_3` AS `bfcol_5`
+  FROM `bfcte_0`
+), `bfcte_3` AS (
+  SELECT
+    *
+  FROM `bfcte_1`
+  INNER JOIN `bfcte_2`
+    ON COALESCE(CAST(`bfcol_1` AS STRING), '0') = COALESCE(CAST(`bfcol_5` AS STRING), '0')
+    AND COALESCE(CAST(`bfcol_1` AS STRING), '1') = COALESCE(CAST(`bfcol_5` AS STRING), '1')
+)
+SELECT
+  `bfcol_0` AS `rowindex_x`,
+  `bfcol_1` AS `time_col`,
+  `bfcol_4` AS `rowindex_y`
+FROM `bfcte_3`
\ No newline at end of file
diff --git a/tests/unit/core/compile/sqlglot/test_compile_join.py b/tests/unit/core/compile/sqlglot/test_compile_join.py
index a530ed4fc3..ac016eec02 100644
--- a/tests/unit/core/compile/sqlglot/test_compile_join.py
+++ b/tests/unit/core/compile/sqlglot/test_compile_join.py
@@ -49,3 +49,13 @@ def test_compile_join_w_how(scalar_types_df: bpd.DataFrame):
     join_sql = left.merge(right, how="cross").sql
     assert "CROSS JOIN" in join_sql
     assert "ON" not in join_sql
+
+
+@pytest.mark.parametrize(
+    ("on"),
+    ["bool_col", "int64_col", "float64_col", "string_col", "time_col", "numeric_col"],
+)
+def test_compile_join_w_on(scalar_types_df: bpd.DataFrame, on: str, snapshot):
+    df = scalar_types_df[["rowindex", on]]
+    merge = df.merge(df, left_on=on, right_on=on)
+    snapshot.assert_match(merge.sql, "out.sql")

From 7e8658b085e534dac8f4d63e3ea612e69cf45779 Mon Sep 17 00:00:00 2001
From: Chelsea Lin
Date: Wed, 2 Jul 2025 15:01:53 -0700
Subject: [PATCH 07/13] refactor: add compile_random_sample (#1884)

Fixes internal issue 429248387
---
 bigframes/core/compile/sqlglot/compiler.py    |   6 +
 bigframes/core/compile/sqlglot/sqlglot_ir.py  | 192 ++++++++++++------
 .../core/compile/sqlglot/sqlglot_types.py     |   2 +
 tests/unit/core/compile/sqlglot/conftest.py   |  11 +
 .../test_compile_random_sample/out.sql        | 184 +++++++++++++++++
 .../test_compile_readlocal/out.sql            |  16 +-
 .../sqlglot/test_compile_random_sample.py     |  35 ++++
 .../compile/sqlglot/test_compile_readlocal.py |   2 -
 8 files changed, 380 insertions(+), 68 deletions(-)
 create mode 100644 tests/unit/core/compile/sqlglot/snapshots/test_compile_random_sample/test_compile_random_sample/out.sql
 create mode 100644 tests/unit/core/compile/sqlglot/test_compile_random_sample.py

diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py
index 03d1a0a2de..93f072973c 100644
--- a/bigframes/core/compile/sqlglot/compiler.py
+++ b/bigframes/core/compile/sqlglot/compiler.py
@@ -261,6 +261,12 @@ def compile_explode(
         columns = tuple(ref.id.sql for ref in node.column_ids)
         return child.explode(columns, offsets_col)
 
+    @_compile_node.register
+    def compile_random_sample(
+        self, node: nodes.RandomSampleNode, child: ir.SQLGlotIR
+    ) -> ir.SQLGlotIR:
+        return child.sample(node.fraction)
+
 
 def _replace_unsupported_ops(node: nodes.BigFrameNode):
     node = nodes.bottom_up(node, rewrite.rewrite_slice)
diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py
index d5902fa6fc..c0bed4090c 100644
--- a/bigframes/core/compile/sqlglot/sqlglot_ir.py
+++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py
@@ -25,7 +25,7 @@
 import sqlglot.expressions as sge
 
 from 
bigframes import dtypes
-from bigframes.core import guid
+from bigframes.core import guid, utils
 from bigframes.core.compile.sqlglot.expressions import typed_expr
 import bigframes.core.compile.sqlglot.sqlglot_types as sgt
 import bigframes.core.local_data as local_data
@@ -71,7 +71,10 @@ def from_pyarrow(
         schema: bf_schema.ArraySchema,
         uid_gen: guid.SequentialUIDGenerator,
     ) -> SQLGlotIR:
-        """Builds SQLGlot expression from pyarrow table."""
+        """Builds SQLGlot expression from a pyarrow table.
+
+        This is used to represent in-memory data as a SQL query.
+        """
         dtype_expr = sge.DataType(
             this=sge.DataType.Type.STRUCT,
             expressions=[
@@ -117,6 +120,16 @@ def from_table(
         alias_names: typing.Sequence[str],
         uid_gen: guid.SequentialUIDGenerator,
     ) -> SQLGlotIR:
+        """Builds a SQLGlotIR expression from a BigQuery table.
+
+        Args:
+            project_id (str): The project ID of the BigQuery table.
+            dataset_id (str): The dataset ID of the BigQuery table.
+            table_id (str): The table ID of the BigQuery table.
+            col_names (typing.Sequence[str]): The names of the columns to select.
+            alias_names (typing.Sequence[str]): The aliases for the selected columns.
+            uid_gen (guid.SequentialUIDGenerator): A generator for unique identifiers.
+        """
         selections = [
             sge.Alias(
                 this=sge.to_identifier(col_name, quoted=cls.quoted),
@@ -137,7 +150,7 @@ def from_query_string(
         cls,
         query_string: str,
     ) -> SQLGlotIR:
-        """Builds SQLGlot expression from a query string"""
+        """Builds a SQLGlot expression from a query string."""
        uid_gen: guid.SequentialUIDGenerator = guid.SequentialUIDGenerator()
         cte_name = sge.to_identifier(
             next(uid_gen.get_uid_stream("bfcte_")), quoted=cls.quoted
@@ -157,7 +170,7 @@ def from_union(
         output_ids: typing.Sequence[str],
         uid_gen: guid.SequentialUIDGenerator,
     ) -> SQLGlotIR:
-        """Builds SQLGlot expression by union of multiple select expressions."""
+        """Builds a SQLGlot expression by unioning multiple select expressions."""
         assert (
             len(list(selects)) >= 2
         ), f"At least two select expressions must be provided, but got {selects}."
@@ -205,6 +218,7 @@ def select(
         self,
         selected_cols: tuple[tuple[str, sge.Expression], ...],
     ) -> SQLGlotIR:
+        """Replaces the current SELECT clause with the newly selected columns."""
         selections = [
             sge.Alias(
                 this=expr,
@@ -213,15 +227,41 @@
             for id, expr in selected_cols
         ]
 
-        new_expr, _ = self._encapsulate_as_cte()
+        new_expr = _select_to_cte(
+            self.expr,
+            sge.to_identifier(
+                next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+            ),
+        )
         new_expr = new_expr.select(*selections, append=False)
         return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
 
+    def project(
+        self,
+        projected_cols: tuple[tuple[str, sge.Expression], ...],
+    ) -> SQLGlotIR:
+        """Adds new columns to the SELECT clause."""
+        projected_cols_expr = [
+            sge.Alias(
+                this=expr,
+                alias=sge.to_identifier(id, quoted=self.quoted),
+            )
+            for id, expr in projected_cols
+        ]
+        new_expr = _select_to_cte(
+            self.expr,
+            sge.to_identifier(
+                next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+            ),
+        )
+        new_expr = new_expr.select(*projected_cols_expr, append=True)
+        return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
+
     def order_by(
         self,
         ordering: tuple[sge.Ordered, ...],
     ) -> SQLGlotIR:
-        """Adds ORDER BY clause to the query."""
+        """Adds an ORDER BY clause to the query."""
         if len(ordering) == 0:
             return SQLGlotIR(expr=self.expr.copy(), uid_gen=self.uid_gen)
         new_expr = self.expr.order_by(*ordering)
@@ -231,34 +271,24 @@ def limit(
         self,
         limit: int | None,
     ) -> SQLGlotIR:
-        """Adds LIMIT clause to the query."""
+        """Adds a LIMIT clause to the query."""
         if limit is not None:
             new_expr = self.expr.limit(limit)
         else:
             new_expr = self.expr.copy()
         return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
 
-    def project(
-        self,
-        projected_cols: tuple[tuple[str, sge.Expression], ...],
-    ) -> SQLGlotIR:
-        projected_cols_expr = [
-            sge.Alias(
-                this=expr,
-                alias=sge.to_identifier(id, quoted=self.quoted),
-            )
-            for id, expr in projected_cols
-        ]
-        new_expr, _ = self._encapsulate_as_cte()
-        new_expr = new_expr.select(*projected_cols_expr, append=True)
-        return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
-
     def filter(
         self,
         condition: sge.Expression,
     ) -> SQLGlotIR:
-        """Filters the query with the given condition."""
-        new_expr, _ = self._encapsulate_as_cte()
+        """Filters the query by adding a WHERE clause."""
+        new_expr = _select_to_cte(
+            self.expr,
+            sge.to_identifier(
+                next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+            ),
+        )
         return SQLGlotIR(
             expr=new_expr.where(condition, append=False), uid_gen=self.uid_gen
         )
@@ -272,8 +302,15 @@ def join(
         joins_nulls: bool = True,
     ) -> SQLGlotIR:
         """Joins the current query with another SQLGlotIR instance."""
-        left_select, left_table = self._encapsulate_as_cte()
-        right_select, right_table = right._encapsulate_as_cte()
+        left_cte_name = sge.to_identifier(
+            next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+        )
+        right_cte_name = sge.to_identifier(
+            next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+        )
+
+        left_select = _select_to_cte(self.expr, left_cte_name)
+        right_select = _select_to_cte(right.expr, right_cte_name)
 
         left_ctes = left_select.args.pop("with", [])
         right_ctes = right_select.args.pop("with", [])
@@ -288,17 +325,50 @@
         new_expr = (
             sge.Select()
             .select(sge.Star())
-            .from_(left_table)
-            .join(right_table, on=join_on, join_type=join_type_str)
+            .from_(sge.Table(this=left_cte_name))
+            .join(sge.Table(this=right_cte_name), on=join_on, join_type=join_type_str)
         )
         new_expr.set("with", sge.With(expressions=merged_ctes))
         return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
 
+    def explode(
+        self,
+        column_names: tuple[str, ...],
+        offsets_col: typing.Optional[str],
+    ) -> SQLGlotIR:
+        """Unnests one or more array columns."""
+        num_columns = len(list(column_names))
+        assert num_columns > 0, "At least one column must be provided for explode."
+        if num_columns == 1:
+            return self._explode_single_column(column_names[0], offsets_col)
+        else:
+            return self._explode_multiple_columns(column_names, offsets_col)
+
+    def sample(self, fraction: float) -> SQLGlotIR:
+        """Uniformly samples a fraction of the rows."""
+        uuid_col = sge.to_identifier(
+            next(self.uid_gen.get_uid_stream("bfcol_")), quoted=self.quoted
+        )
+        uuid_expr = sge.Alias(this=sge.func("RAND"), alias=uuid_col)
+        condition = sge.LT(
+            this=uuid_col,
+            expression=_literal(fraction, dtypes.FLOAT_DTYPE),
+        )
+
+        new_cte_name = sge.to_identifier(
+            next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted
+        )
+        new_expr = _select_to_cte(
+            self.expr.select(uuid_expr, append=True), new_cte_name
+        ).where(condition, append=False)
+        return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)
+
     def insert(
         self,
         destination: bigquery.TableReference,
     ) -> str:
+        """Generates an INSERT INTO SQL statement from the current SELECT clause."""
         return sge.insert(self.expr.subquery(), _table(destination)).sql(
             dialect=self.dialect, pretty=self.pretty
         )
@@ -307,6 +377,9 @@ def replace(
         self,
         destination: bigquery.TableReference,
     ) -> str:
+        """Generates a MERGE statement that replaces the destination table's
+        contents with the current SELECT clause.
+        """
         # Workaround for SQLGlot breaking change:
         # https://github.com/tobymao/sqlglot/pull/4495
         whens_expr = [
@@ -325,23 +398,10 @@
         ).sql(dialect=self.dialect, pretty=self.pretty)
         return f"{merge_str}\n{whens_str}"
 
-    def explode(
-        self,
-        column_names: tuple[str, ...],
-        offsets_col: typing.Optional[str],
-    ) -> SQLGlotIR:
-        num_columns = len(list(column_names))
-        assert num_columns > 0, "At least one column must be provided for explode."
-        if num_columns == 1:
-            return self._explode_single_column(column_names[0], offsets_col)
-        else:
-            return self._explode_multiple_columns(column_names, offsets_col)
-
     def _explode_single_column(
         self, column_name: str, offsets_col: typing.Optional[str]
     ) -> SQLGlotIR:
         """Helper method to handle the case of exploding a single column."""
-
         offset = (
             sge.to_identifier(offsets_col, quoted=self.quoted) if offsets_col else None
         )
@@ -358,7 +418,12 @@
         # TODO: "CROSS" if not keep_empty else "LEFT"
         # TODO: overlaps_with_parent to replace existing column.
- new_expr, _ = self._encapsulate_as_cte() + new_expr = _select_to_cte( + self.expr, + sge.to_identifier( + next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + ), + ) new_expr = new_expr.select(selection, append=False).join( unnest_expr, join_type="CROSS" ) @@ -408,33 +473,32 @@ def _explode_multiple_columns( for column in columns ] ) - new_expr, _ = self._encapsulate_as_cte() + new_expr = _select_to_cte( + self.expr, + sge.to_identifier( + next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted + ), + ) new_expr = new_expr.select(selection, append=False).join( unnest_expr, join_type="CROSS" ) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) - def _encapsulate_as_cte( - self, - ) -> typing.Tuple[sge.Select, sge.Table]: - """Transforms a given sge.Select query by pushing its main SELECT statement - into a new CTE and then generates a 'SELECT * FROM new_cte_name' - for the new query.""" - select_expr = self.expr.copy() - existing_ctes = select_expr.args.pop("with", []) - new_cte_name = sge.to_identifier( - next(self.uid_gen.get_uid_stream("bfcte_")), quoted=self.quoted - ) - new_cte = sge.CTE( - this=select_expr, - alias=new_cte_name, - ) - new_with_clause = sge.With(expressions=[*existing_ctes, new_cte]) - new_table_expr = sge.Table(this=new_cte_name) - new_select_expr = sge.Select().select(sge.Star()).from_(new_table_expr) - new_select_expr.set("with", new_with_clause) - return new_select_expr, new_table_expr +def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select: + """Transforms a given sge.Select query by pushing its main SELECT statement + into a new CTE and then generates a 'SELECT * FROM new_cte_name' + for the new query.""" + select_expr = expr.copy() + existing_ctes = select_expr.args.pop("with", []) + new_cte = sge.CTE( + this=select_expr, + alias=cte_name, + ) + new_with_clause = sge.With(expressions=[*existing_ctes, new_cte]) + new_select_expr = sge.Select().select(sge.Star()).from_(sge.Table(this=cte_name)) + new_select_expr.set("with", new_with_clause) + return new_select_expr def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: @@ -454,6 +518,8 @@ def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: return sge.func("ST_GEOGFROMTEXT", sge.convert(wkt)) elif dtype == dtypes.JSON_DTYPE: return sge.ParseJSON(this=sge.convert(str(value))) + elif dtype == dtypes.TIMEDELTA_DTYPE: + return sge.convert(utils.timedelta_to_micros(value)) elif dtypes.is_struct_like(dtype): items = [ _literal(value=value[field_name], dtype=field_dtype).as_( diff --git a/bigframes/core/compile/sqlglot/sqlglot_types.py b/bigframes/core/compile/sqlglot/sqlglot_types.py index 0cfeaae3e9..5b0f70077d 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_types.py +++ b/bigframes/core/compile/sqlglot/sqlglot_types.py @@ -59,6 +59,8 @@ def from_bigframes_dtype( return "JSON" elif bigframes_dtype == bigframes.dtypes.GEO_DTYPE: return "GEOGRAPHY" + elif bigframes_dtype == bigframes.dtypes.TIMEDELTA_DTYPE: + return "INT64" elif isinstance(bigframes_dtype, pd.ArrowDtype): if pa.types.is_list(bigframes_dtype.pyarrow_dtype): inner_bigframes_dtype = bigframes.dtypes.arrow_dtype_to_bigframes_dtype( diff --git a/tests/unit/core/compile/sqlglot/conftest.py b/tests/unit/core/compile/sqlglot/conftest.py index 645daddd46..754c19ac90 100644 --- a/tests/unit/core/compile/sqlglot/conftest.py +++ b/tests/unit/core/compile/sqlglot/conftest.py @@ -21,6 +21,7 @@ import pytest from bigframes import dtypes +import bigframes.core as core import 
bigframes.pandas as bpd import bigframes.testing.mocks as mocks import bigframes.testing.utils @@ -115,6 +116,16 @@ def scalar_types_pandas_df() -> pd.DataFrame: return df +@pytest.fixture(scope="module") +def scalar_types_array_value( + scalar_types_pandas_df: pd.DataFrame, compiler_session: bigframes.Session +) -> core.ArrayValue: + managed_data_source = core.local_data.ManagedArrowTable.from_pandas( + scalar_types_pandas_df + ) + return core.ArrayValue.from_managed(managed_data_source, compiler_session) + + @pytest.fixture(scope="session") def nested_structs_types_table_schema() -> typing.Sequence[bigquery.SchemaField]: return [ diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_random_sample/test_compile_random_sample/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_random_sample/test_compile_random_sample/out.sql new file mode 100644 index 0000000000..aae34716d8 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_random_sample/test_compile_random_sample/out.sql @@ -0,0 +1,184 @@ +WITH `bfcte_0` AS ( + SELECT + *, + RAND() AS `bfcol_16` + FROM UNNEST(ARRAY>[STRUCT( + TRUE, + CAST(b'Hello, World!' AS BYTES), + CAST('2021-07-21' AS DATE), + CAST('2021-07-21T11:39:45' AS DATETIME), + ST_GEOGFROMTEXT('POINT(-122.0838511 37.3860517)'), + 123456789, + 0, + CAST(1.234567890 AS NUMERIC), + 1.25, + 0, + 0, + 'Hello, World!', + CAST('11:41:43.076160' AS TIME), + CAST('2021-07-21T17:43:43.945289+00:00' AS TIMESTAMP), + 4, + 0 + ), STRUCT( + FALSE, + CAST(b'\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf' AS BYTES), + CAST('1991-02-03' AS DATE), + CAST('1991-01-02T03:45:06' AS DATETIME), + ST_GEOGFROMTEXT('POINT(-71.104 42.315)'), + -987654321, + 1, + CAST(1.234567890 AS NUMERIC), + 2.51, + 1, + 1, + 'こんにちは', + CAST('11:14:34.701606' AS TIME), + CAST('2021-07-21T17:43:43.945289+00:00' AS TIMESTAMP), + -1000000, + 1 + ), STRUCT( + TRUE, + CAST(b'\xc2\xa1Hola Mundo!' AS BYTES), + CAST('2023-03-01' AS DATE), + CAST('2023-03-01T10:55:13' AS DATETIME), + ST_GEOGFROMTEXT('POINT(-0.124474760143016 51.5007826749545)'), + 314159, + 0, + CAST(101.101010100 AS NUMERIC), + 25000000000.0, + 2, + 2, + ' ¡Hola Mundo! 
', + CAST('23:59:59.999999' AS TIME), + CAST('2023-03-01T10:55:13.250125+00:00' AS TIMESTAMP), + 0, + 2 + ), STRUCT( + CAST(NULL AS BOOLEAN), + CAST(NULL AS BYTES), + CAST(NULL AS DATE), + CAST(NULL AS DATETIME), + CAST(NULL AS GEOGRAPHY), + CAST(NULL AS INT64), + 1, + CAST(NULL AS NUMERIC), + CAST(NULL AS FLOAT64), + 3, + 3, + CAST(NULL AS STRING), + CAST(NULL AS TIME), + CAST(NULL AS TIMESTAMP), + CAST(NULL AS INT64), + 3 + ), STRUCT( + FALSE, + CAST(b'\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf' AS BYTES), + CAST('2021-07-21' AS DATE), + CAST(NULL AS DATETIME), + CAST(NULL AS GEOGRAPHY), + -234892, + -2345, + CAST(NULL AS NUMERIC), + CAST(NULL AS FLOAT64), + 4, + 4, + 'Hello, World!', + CAST(NULL AS TIME), + CAST(NULL AS TIMESTAMP), + 31540000000000, + 4 + ), STRUCT( + FALSE, + CAST(b'G\xc3\xbcten Tag' AS BYTES), + CAST('1980-03-14' AS DATE), + CAST('1980-03-14T15:16:17' AS DATETIME), + CAST(NULL AS GEOGRAPHY), + 55555, + 0, + CAST(5.555555000 AS NUMERIC), + 555.555, + 5, + 5, + 'Güten Tag!', + CAST('15:16:17.181921' AS TIME), + CAST('1980-03-14T15:16:17.181921+00:00' AS TIMESTAMP), + 4, + 5 + ), STRUCT( + TRUE, + CAST(b'Hello\tBigFrames!\x07' AS BYTES), + CAST('2023-05-23' AS DATE), + CAST('2023-05-23T11:37:01' AS DATETIME), + ST_GEOGFROMTEXT('LINESTRING(-0.127959 51.507728, -0.127026 51.507473)'), + 101202303, + 2, + CAST(-10.090807000 AS NUMERIC), + -123.456, + 6, + 6, + 'capitalize, This ', + CAST('01:02:03.456789' AS TIME), + CAST('2023-05-23T11:42:55.000001+00:00' AS TIMESTAMP), + CAST(NULL AS INT64), + 6 + ), STRUCT( + TRUE, + CAST(NULL AS BYTES), + CAST('2038-01-20' AS DATE), + CAST('2038-01-19T03:14:08' AS DATETIME), + CAST(NULL AS GEOGRAPHY), + -214748367, + 2, + CAST(11111111.100000000 AS NUMERIC), + 42.42, + 7, + 7, + ' سلام', + CAST('12:00:00.000001' AS TIME), + CAST('2038-01-19T03:14:17.999999+00:00' AS TIMESTAMP), + 4, + 7 + ), STRUCT( + FALSE, + CAST(NULL AS BYTES), + CAST(NULL AS DATE), + CAST(NULL AS DATETIME), + CAST(NULL AS GEOGRAPHY), + 2, + 1, + CAST(NULL AS NUMERIC), + 6.87, + 8, + 8, + 'T', + CAST(NULL AS TIME), + CAST(NULL AS TIMESTAMP), + 432000000000, + 8 + )]) +), `bfcte_1` AS ( + SELECT + * + FROM `bfcte_0` + WHERE + `bfcol_16` < 0.1 +) +SELECT + `bfcol_0` AS `bool_col`, + `bfcol_1` AS `bytes_col`, + `bfcol_2` AS `date_col`, + `bfcol_3` AS `datetime_col`, + `bfcol_4` AS `geography_col`, + `bfcol_5` AS `int64_col`, + `bfcol_6` AS `int64_too`, + `bfcol_7` AS `numeric_col`, + `bfcol_8` AS `float64_col`, + `bfcol_9` AS `rowindex`, + `bfcol_10` AS `rowindex_2`, + `bfcol_11` AS `string_col`, + `bfcol_12` AS `time_col`, + `bfcol_13` AS `timestamp_col`, + `bfcol_14` AS `duration_col` +FROM `bfcte_1` +ORDER BY + `bfcol_15` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql index 70d73db6a7..2b080b0b7c 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_readlocal/test_compile_readlocal/out.sql @@ -1,7 +1,7 @@ WITH `bfcte_0` AS ( SELECT * - FROM UNNEST(ARRAY>[STRUCT( + FROM UNNEST(ARRAY>[STRUCT( 0, TRUE, CAST(b'Hello, World!' 
AS BYTES), @@ -17,6 +17,7 @@ WITH `bfcte_0` AS ( 'Hello, World!', CAST('11:41:43.076160' AS TIME), CAST('2021-07-21T17:43:43.945289+00:00' AS TIMESTAMP), + 4, 0 ), STRUCT( 1, @@ -34,6 +35,7 @@ WITH `bfcte_0` AS ( 'こんにちは', CAST('11:14:34.701606' AS TIME), CAST('2021-07-21T17:43:43.945289+00:00' AS TIMESTAMP), + -1000000, 1 ), STRUCT( 2, @@ -51,6 +53,7 @@ WITH `bfcte_0` AS ( ' ¡Hola Mundo! ', CAST('23:59:59.999999' AS TIME), CAST('2023-03-01T10:55:13.250125+00:00' AS TIMESTAMP), + 0, 2 ), STRUCT( 3, @@ -68,6 +71,7 @@ WITH `bfcte_0` AS ( CAST(NULL AS STRING), CAST(NULL AS TIME), CAST(NULL AS TIMESTAMP), + CAST(NULL AS INT64), 3 ), STRUCT( 4, @@ -85,6 +89,7 @@ WITH `bfcte_0` AS ( 'Hello, World!', CAST(NULL AS TIME), CAST(NULL AS TIMESTAMP), + 31540000000000, 4 ), STRUCT( 5, @@ -102,6 +107,7 @@ WITH `bfcte_0` AS ( 'Güten Tag!', CAST('15:16:17.181921' AS TIME), CAST('1980-03-14T15:16:17.181921+00:00' AS TIMESTAMP), + 4, 5 ), STRUCT( 6, @@ -119,6 +125,7 @@ WITH `bfcte_0` AS ( 'capitalize, This ', CAST('01:02:03.456789' AS TIME), CAST('2023-05-23T11:42:55.000001+00:00' AS TIMESTAMP), + CAST(NULL AS INT64), 6 ), STRUCT( 7, @@ -136,6 +143,7 @@ WITH `bfcte_0` AS ( ' سلام', CAST('12:00:00.000001' AS TIME), CAST('2038-01-19T03:14:17.999999+00:00' AS TIMESTAMP), + 4, 7 ), STRUCT( 8, @@ -153,6 +161,7 @@ WITH `bfcte_0` AS ( 'T', CAST(NULL AS TIME), CAST(NULL AS TIMESTAMP), + 432000000000, 8 )]) ) @@ -171,7 +180,8 @@ SELECT `bfcol_11` AS `rowindex_2`, `bfcol_12` AS `string_col`, `bfcol_13` AS `time_col`, - `bfcol_14` AS `timestamp_col` + `bfcol_14` AS `timestamp_col`, + `bfcol_15` AS `duration_col` FROM `bfcte_0` ORDER BY - `bfcol_15` ASC NULLS LAST \ No newline at end of file + `bfcol_16` ASC NULLS LAST \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/test_compile_random_sample.py b/tests/unit/core/compile/sqlglot/test_compile_random_sample.py new file mode 100644 index 0000000000..6e333f0421 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/test_compile_random_sample.py @@ -0,0 +1,35 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from bigframes.core import nodes +import bigframes.core as core +import bigframes.core.compile.sqlglot as sqlglot + +pytest.importorskip("pytest_snapshot") + + +def test_compile_random_sample( + scalar_types_array_value: core.ArrayValue, + snapshot, +): + """This test verifies the SQL compilation of a RandomSampleNode. + + Because BigFrames doesn't expose a public API for creating a random sample + operation, this test constructs the node directly and then compiles it to SQL. 
+    """
+    node = nodes.RandomSampleNode(scalar_types_array_value.node, fraction=0.1)
+    sql = sqlglot.compiler.SQLGlotCompiler().compile(node)
+    snapshot.assert_match(sql, "out.sql")
diff --git a/tests/unit/core/compile/sqlglot/test_compile_readlocal.py b/tests/unit/core/compile/sqlglot/test_compile_readlocal.py
index 6f8a2050e5..7307fd9b4e 100644
--- a/tests/unit/core/compile/sqlglot/test_compile_readlocal.py
+++ b/tests/unit/core/compile/sqlglot/test_compile_readlocal.py
@@ -24,8 +24,6 @@ def test_compile_readlocal(
     scalar_types_pandas_df: pd.DataFrame, compiler_session: bigframes.Session, snapshot
 ):
-    # Durations not yet supported
-    scalar_types_pandas_df = scalar_types_pandas_df.drop(["duration_col"], axis=1)
     bf_df = bpd.DataFrame(scalar_types_pandas_df, session=compiler_session)
 
     snapshot.assert_match(bf_df.sql, "out.sql")

From e43d15d535d6d5fd73c33967271f3591c41dffb3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tim=20Swe=C3=B1a=20=28Swast=29?=
Date: Mon, 7 Jul 2025 13:59:37 -0500
Subject: [PATCH 08/13] feat: `df.to_pandas_batches()` returns one empty
 DataFrame if `df` is empty (#1878)

---
 bigframes/core/blocks.py                | 16 ++++++++++++++++
 tests/system/small/test_dataframe_io.py | 22 ++++++++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/bigframes/core/blocks.py b/bigframes/core/blocks.py
index 6d476cc795..dbbf9ee864 100644
--- a/bigframes/core/blocks.py
+++ b/bigframes/core/blocks.py
@@ -620,15 +620,31 @@ def to_pandas_batches(
             ordered=True,
             use_explicit_destination=allow_large_results,
         )
+
+        total_batches = 0
         for df in execute_result.to_pandas_batches(
             page_size=page_size, max_results=max_results
         ):
+            total_batches += 1
             self._copy_index_to_pandas(df)
             if squeeze:
                 yield df.squeeze(axis=1)
             else:
                 yield df
 
+        # To reduce the number of edge cases to consider when working with the
+        # results of this, always return at least one DataFrame. See:
+        # b/428918844.
+        if total_batches == 0:
+            df = pd.DataFrame(
+                {
+                    col: pd.Series([], dtype=self.expr.get_column_type(col))
+                    for col in itertools.chain(self.value_columns, self.index_columns)
+                }
+            )
+            self._copy_index_to_pandas(df)
+            yield df
+
     def _copy_index_to_pandas(self, df: pd.DataFrame):
         """Set the index on pandas DataFrame to match this block.
diff --git a/tests/system/small/test_dataframe_io.py b/tests/system/small/test_dataframe_io.py
index ef6e25a95c..1d6ae370c5 100644
--- a/tests/system/small/test_dataframe_io.py
+++ b/tests/system/small/test_dataframe_io.py
@@ -347,6 +347,28 @@ def test_to_pandas_batches_w_correct_dtypes(scalars_df_default_index):
         pd.testing.assert_series_equal(actual, expected)
 
 
+def test_to_pandas_batches_w_empty_dataframe(session):
+    """Verify the to_pandas_batches() API returns at least one DataFrame.
+
+    See b/428918844 for additional context.
+    """
+    empty = bpd.DataFrame(
+        {
+            "idx1": [],
+            "idx2": [],
+            "col1": pandas.Series([], dtype="string[pyarrow]"),
+            "col2": pandas.Series([], dtype="Int64"),
+        },
+        session=session,
+    ).set_index(["idx1", "idx2"], drop=True)
+
+    results = list(empty.to_pandas_batches())
+    assert len(results) == 1
+    assert list(results[0].index.names) == ["idx1", "idx2"]
+    assert list(results[0].columns) == ["col1", "col2"]
+    pandas.testing.assert_series_equal(results[0].dtypes, empty.dtypes)
+
+
 @pytest.mark.parametrize("allow_large_results", (True, False))
 def test_to_pandas_batches_w_page_size_and_max_results(session, allow_large_results):
     """Verify to_pandas_batches() APIs returns the expected page size.
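A minimal usage sketch of the contract the patch above establishes (illustrative only, not part of the patch series; it assumes a configured BigQuery session, since `to_pandas_batches()` executes a query):

    import pandas
    import bigframes.pandas as bpd

    # An empty frame still carries a concrete schema.
    df = bpd.DataFrame({"col": pandas.Series([], dtype="Int64")})

    # The iterator now always yields at least one (possibly empty) pandas
    # DataFrame with matching dtypes, so callers no longer need to
    # special-case an exhausted iterator.
    batches = list(df.to_pandas_batches())
    assert len(batches) == 1
    assert list(batches[0].columns) == ["col"]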
From 4185afe05733fba7afc349bfe4dd9227540bb34e Mon Sep 17 00:00:00 2001
From: jialuoo
Date: Mon, 7 Jul 2025 14:32:00 -0700
Subject: [PATCH 09/13] feat: support multi index for dataframe where (#1881)

* feat: support multi index for dataframe where

* fix test

* fix

* resolve the comments
---
 bigframes/dataframe.py                |   4 +-
 tests/system/small/test_dataframe.py  |  24 ++---
 tests/system/small/test_multiindex.py | 150 ++++++++++++++++++++++++++
 tests/unit/test_dataframe_polars.py   |   2 +-
 4 files changed, 165 insertions(+), 15 deletions(-)

diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py
index 1ef287842e..1884f0beff 100644
--- a/bigframes/dataframe.py
+++ b/bigframes/dataframe.py
@@ -2741,9 +2741,9 @@ def where(self, cond, other=None):
         if isinstance(other, bigframes.series.Series):
             raise ValueError("Seires is not a supported replacement type!")
 
-        if self.columns.nlevels > 1 or self.index.nlevels > 1:
+        if self.columns.nlevels > 1:
             raise NotImplementedError(
-                "The dataframe.where() method does not support multi-index and/or multi-column."
+                "The dataframe.where() method does not support multi-column."
             )
 
         aligned_block, (_, _) = self._block.join(cond._block, how="left")
diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py
index 5045e2268f..91a83dfd73 100644
--- a/tests/system/small/test_dataframe.py
+++ b/tests/system/small/test_dataframe.py
@@ -375,15 +375,6 @@ def test_insert(scalars_dfs, loc, column, value, allow_duplicates):
     pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df, check_dtype=False)
 
 
-def test_where_series_cond(scalars_df_index, scalars_pandas_df_index):
-    # Condition is dataframe, other is None (as default).
-    cond_bf = scalars_df_index["int64_col"] > 0
-    cond_pd = scalars_pandas_df_index["int64_col"] > 0
-    bf_result = scalars_df_index.where(cond_bf).to_pandas()
-    pd_result = scalars_pandas_df_index.where(cond_pd)
-    pandas.testing.assert_frame_equal(bf_result, pd_result)
-
-
 def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
     cond_bf = scalars_df_index["int64_col"] > 0
     cond_pd = scalars_pandas_df_index["int64_col"] > 0
@@ -395,8 +386,8 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
     pandas.testing.assert_frame_equal(bf_result, pd_result)
 
 
-def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
-    # Test when a dataframe has multi-index or multi-columns.
+def test_where_multi_column(scalars_df_index, scalars_pandas_df_index):
+    # Test when a dataframe has multi-level columns.
     columns = ["int64_col", "float64_col"]
     dataframe_bf = scalars_df_index[columns]
 
@@ -409,10 +400,19 @@
         dataframe_bf.where(cond_bf).to_pandas()
     assert (
         str(context.value)
-        == "The dataframe.where() method does not support multi-index and/or multi-column."
+        == "The dataframe.where() method does not support multi-column."
    )
 
 
+def test_where_series_cond(scalars_df_index, scalars_pandas_df_index):
+    # Condition is a series, other is None (as default).
+    cond_bf = scalars_df_index["int64_col"] > 0
+    cond_pd = scalars_pandas_df_index["int64_col"] > 0
+    bf_result = scalars_df_index.where(cond_bf).to_pandas()
+    pd_result = scalars_pandas_df_index.where(cond_pd)
+    pandas.testing.assert_frame_equal(bf_result, pd_result)
+
+
 def test_where_series_cond_const_other(scalars_df_index, scalars_pandas_df_index):
     # Condition is a series, other is a constant.
     columns = ["int64_col", "float64_col"]
diff --git a/tests/system/small/test_multiindex.py b/tests/system/small/test_multiindex.py
index b63468d311..e4852cc8fb 100644
--- a/tests/system/small/test_multiindex.py
+++ b/tests/system/small/test_multiindex.py
@@ -19,6 +19,22 @@
 import bigframes.pandas as bpd
 from bigframes.testing.utils import assert_pandas_df_equal
 
+# Sample MultiIndex for testing the DataFrame.where() method.
+_MULTI_INDEX = pandas.MultiIndex.from_tuples(
+    [
+        (0, "a"),
+        (1, "b"),
+        (2, "c"),
+        (0, "d"),
+        (1, "e"),
+        (2, "f"),
+        (0, "g"),
+        (1, "h"),
+        (2, "i"),
+    ],
+    names=["A", "B"],
+)
+
 
 def test_multi_index_from_arrays():
     bf_idx = bpd.MultiIndex.from_arrays(
@@ -541,6 +557,140 @@ def test_multi_index_dataframe_join_on(scalars_dfs, how):
     assert_pandas_df_equal(bf_result, pd_result, ignore_order=True)
 
 
+def test_multi_index_dataframe_where_series_cond_none_other(
+    scalars_df_index, scalars_pandas_df_index
+):
+    columns = ["int64_col", "float64_col"]
+
+    # Create multi-index dataframe.
+    dataframe_bf = bpd.DataFrame(
+        scalars_df_index[columns].values,
+        index=_MULTI_INDEX,
+        columns=scalars_df_index[columns].columns,
+    )
+    dataframe_pd = pandas.DataFrame(
+        scalars_pandas_df_index[columns].values,
+        index=_MULTI_INDEX,
+        columns=scalars_pandas_df_index[columns].columns,
+    )
+    dataframe_bf.columns.name = "test_name"
+    dataframe_pd.columns.name = "test_name"
+
+    # When condition is a series and other is None.
+    series_cond_bf = dataframe_bf["int64_col"] > 0
+    series_cond_pd = dataframe_pd["int64_col"] > 0
+
+    bf_result = dataframe_bf.where(series_cond_bf).to_pandas()
+    pd_result = dataframe_pd.where(series_cond_pd)
+    pandas.testing.assert_frame_equal(
+        bf_result,
+        pd_result,
+        check_index_type=False,
+        check_dtype=False,
+    )
+    # Assert the index is still a MultiIndex after the operation.
+    assert isinstance(bf_result.index, pandas.MultiIndex), "Expected a MultiIndex"
+    assert isinstance(pd_result.index, pandas.MultiIndex), "Expected a MultiIndex"
+
+
+def test_multi_index_dataframe_where_series_cond_dataframe_other(
+    scalars_df_index, scalars_pandas_df_index
+):
+    columns = ["int64_col", "int64_too"]
+
+    # Create multi-index dataframe.
+    dataframe_bf = bpd.DataFrame(
+        scalars_df_index[columns].values,
+        index=_MULTI_INDEX,
+        columns=scalars_df_index[columns].columns,
+    )
+    dataframe_pd = pandas.DataFrame(
+        scalars_pandas_df_index[columns].values,
+        index=_MULTI_INDEX,
+        columns=scalars_pandas_df_index[columns].columns,
+    )
+
+    # When condition is a series and other is a dataframe.
+    series_cond_bf = dataframe_bf["int64_col"] > 1000.0
+    series_cond_pd = dataframe_pd["int64_col"] > 1000.0
+    dataframe_other_bf = dataframe_bf * 100.0
+    dataframe_other_pd = dataframe_pd * 100.0
+
+    bf_result = dataframe_bf.where(series_cond_bf, dataframe_other_bf).to_pandas()
+    pd_result = dataframe_pd.where(series_cond_pd, dataframe_other_pd)
+    pandas.testing.assert_frame_equal(
+        bf_result,
+        pd_result,
+        check_index_type=False,
+        check_dtype=False,
+    )
+
+
+def test_multi_index_dataframe_where_dataframe_cond_constant_other(
+    scalars_df_index, scalars_pandas_df_index
+):
+    columns = ["int64_col", "float64_col"]
+
+    # Create multi-index dataframe.
+    dataframe_bf = bpd.DataFrame(
+        scalars_df_index[columns].values,
+        index=_MULTI_INDEX,
+        columns=scalars_df_index[columns].columns,
+    )
+    dataframe_pd = pandas.DataFrame(
+        scalars_pandas_df_index[columns].values,
+        index=_MULTI_INDEX,
+        columns=scalars_pandas_df_index[columns].columns,
+    )
+
+    # When condition is a dataframe and other is a constant.
+ dataframe_cond_bf = dataframe_bf > 0 + dataframe_cond_pd = dataframe_pd > 0 + other = 0 + + bf_result = dataframe_bf.where(dataframe_cond_bf, other).to_pandas() + pd_result = dataframe_pd.where(dataframe_cond_pd, other) + pandas.testing.assert_frame_equal( + bf_result, + pd_result, + check_index_type=False, + check_dtype=False, + ) + + +def test_multi_index_dataframe_where_dataframe_cond_dataframe_other( + scalars_df_index, scalars_pandas_df_index +): + columns = ["int64_col", "int64_too", "float64_col"] + + # Create multi-index dataframe. + dataframe_bf = bpd.DataFrame( + scalars_df_index[columns].values, + index=_MULTI_INDEX, + columns=scalars_df_index[columns].columns, + ) + dataframe_pd = pandas.DataFrame( + scalars_pandas_df_index[columns].values, + index=_MULTI_INDEX, + columns=scalars_pandas_df_index[columns].columns, + ) + + # When condition is dataframe and other is dataframe. + dataframe_cond_bf = dataframe_bf < 1000.0 + dataframe_cond_pd = dataframe_pd < 1000.0 + dataframe_other_bf = dataframe_bf * -1.0 + dataframe_other_pd = dataframe_pd * -1.0 + + bf_result = dataframe_bf.where(dataframe_cond_bf, dataframe_other_bf).to_pandas() + pd_result = dataframe_pd.where(dataframe_cond_pd, dataframe_other_pd) + pandas.testing.assert_frame_equal( + bf_result, + pd_result, + check_index_type=False, + check_dtype=False, + ) + + @pytest.mark.parametrize( ("level",), [ diff --git a/tests/unit/test_dataframe_polars.py b/tests/unit/test_dataframe_polars.py index 467cf7ce3d..eae800d409 100644 --- a/tests/unit/test_dataframe_polars.py +++ b/tests/unit/test_dataframe_polars.py @@ -364,7 +364,7 @@ def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index): dataframe_bf.where(cond_bf).to_pandas() assert ( str(context.value) - == "The dataframe.where() method does not support multi-index and/or multi-column." + == "The dataframe.where() method does not support multi-column." ) From dba2a6e0f7ee820ae8c1fc369b739ff4b23ad375 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20Swe=C3=B1a=20=28Swast=29?= Date: Mon, 7 Jul 2025 16:36:25 -0500 Subject: [PATCH 10/13] chore: round earlier in TPC-H q15 to try and reduce non-determinism due to aggregating twice (#1877) * chore: round earlier in TPC-H q15 to try and reduce non-determinism due to aggregating twice * remove unnecessary code --- third_party/bigframes_vendored/tpch/queries/q15.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/third_party/bigframes_vendored/tpch/queries/q15.py b/third_party/bigframes_vendored/tpch/queries/q15.py index 1cba0ca4bc..0e3460189d 100644 --- a/third_party/bigframes_vendored/tpch/queries/q15.py +++ b/third_party/bigframes_vendored/tpch/queries/q15.py @@ -31,6 +31,11 @@ def q(project_id: str, dataset_id: str, session: bigframes.Session): .agg(TOTAL_REVENUE=bpd.NamedAgg(column="REVENUE", aggfunc="sum")) .rename(columns={"L_SUPPKEY": "SUPPLIER_NO"}) ) + # Round earlier to prevent non-determinism in the later join due to + # differences in distributed floating point operation sort order. 
+ grouped_revenue = grouped_revenue.assign( + TOTAL_REVENUE=grouped_revenue["TOTAL_REVENUE"].round(2) + ) joined_data = bpd.merge( supplier, grouped_revenue, left_on="S_SUPPKEY", right_on="SUPPLIER_NO" @@ -43,10 +48,6 @@ def q(project_id: str, dataset_id: str, session: bigframes.Session): max_revenue_suppliers = joined_data[ joined_data["TOTAL_REVENUE"] == joined_data["MAX_REVENUE"] ] - - max_revenue_suppliers["TOTAL_REVENUE"] = max_revenue_suppliers[ - "TOTAL_REVENUE" - ].round(2) q_final = max_revenue_suppliers[ ["S_SUPPKEY", "S_NAME", "S_ADDRESS", "S_PHONE", "TOTAL_REVENUE"] ].sort_values("S_SUPPKEY") From 8715105239216bffe899ddcbb15805f2e3063af4 Mon Sep 17 00:00:00 2001 From: TrevorBergeron Date: Tue, 8 Jul 2025 09:11:06 -0700 Subject: [PATCH 11/13] feat: Add simple stats support to hybrid local pushdown (#1873) --- bigframes/session/polars_executor.py | 10 +++++- bigframes/testing/engine_utils.py | 11 +++--- .../system/small/engines/test_aggregation.py | 36 +++++++++++++++++++ 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index e60bef1819..28ab421905 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -47,7 +47,15 @@ bigframes.operations.ge_op, bigframes.operations.le_op, ) -_COMPATIBLE_AGG_OPS = (agg_ops.SizeOp, agg_ops.SizeUnaryOp) +_COMPATIBLE_AGG_OPS = ( + agg_ops.SizeOp, + agg_ops.SizeUnaryOp, + agg_ops.MinOp, + agg_ops.MaxOp, + agg_ops.SumOp, + agg_ops.MeanOp, + agg_ops.CountOp, +) def _get_expr_ops(expr: expression.Expression) -> set[bigframes.operations.ScalarOp]: diff --git a/bigframes/testing/engine_utils.py b/bigframes/testing/engine_utils.py index f58e5951a1..8aa52cf51a 100644 --- a/bigframes/testing/engine_utils.py +++ b/bigframes/testing/engine_utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pandas.testing + from bigframes.core import nodes from bigframes.session import semi_executor @@ -25,7 +27,8 @@ def assert_equivalence_execution( e2_result = engine2.execute(node, ordered=True) assert e1_result is not None assert e2_result is not None - # Schemas might have extra nullity markers, normalize to node expected schema, which should be looser - e1_table = e1_result.to_arrow_table().cast(node.schema.to_pyarrow()) - e2_table = e2_result.to_arrow_table().cast(node.schema.to_pyarrow()) - assert e1_table.equals(e2_table), f"{e1_table} is not equal to {e2_table}" + # Convert to pandas, as pandas has better comparison utils than arrow + assert e1_result.schema == e2_result.schema + e1_table = e1_result.to_pandas() + e2_table = e2_result.to_pandas() + pandas.testing.assert_frame_equal(e1_table, e2_table, rtol=1e-10) diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py index 2c323a5f28..8530a6fefa 100644 --- a/tests/system/small/engines/test_aggregation.py +++ b/tests/system/small/engines/test_aggregation.py @@ -25,6 +25,28 @@ REFERENCE_ENGINE = polars_executor.PolarsExecutor() +def apply_agg_to_all_valid( + array: array_value.ArrayValue, op: agg_ops.UnaryAggregateOp, excluded_cols=[] +) -> array_value.ArrayValue: + """ + Apply the aggregation to every column in the array that has a compatible datatype. 
+ """ + exprs_by_name = [] + for arg in array.column_ids: + if arg in excluded_cols: + continue + try: + _ = op.output_type(array.get_column_type(arg)) + expr = expression.UnaryAggregation(op, expression.deref(arg)) + name = f"{arg}-{op.name}" + exprs_by_name.append((expr, name)) + except TypeError: + continue + assert len(exprs_by_name) > 0 + new_arr = array.aggregate(exprs_by_name) + return new_arr + + @pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) def test_engines_aggregate_size( scalars_array_value: array_value.ArrayValue, @@ -48,6 +70,20 @@ def test_engines_aggregate_size( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize( + "op", + [agg_ops.min_op, agg_ops.max_op, agg_ops.mean_op, agg_ops.sum_op, agg_ops.count_op], +) +def test_engines_unary_aggregates( + scalars_array_value: array_value.ArrayValue, + engine, + op, +): + node = apply_agg_to_all_valid(scalars_array_value, op).node + assert_equivalence_execution(node, REFERENCE_ENGINE, engine) + + @pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) @pytest.mark.parametrize( "grouping_cols", From 0581a2ab5ddffb24e4b65202979905f9c868ff84 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 8 Jul 2025 10:05:24 -0700 Subject: [PATCH 12/13] refactor: add json operators to SQLGlot compiler (#1887) --- .../sqlglot/expressions/binary_compiler.py | 5 ++ .../sqlglot/expressions/unary_compiler.py | 46 ++++++++++++++++++ .../test_json_set/out.sql | 20 ++++++++ .../test_json_extract/out.sql | 15 ++++++ .../test_parse_json/out.sql | 15 ++++++ .../expressions/test_binary_compiler.py | 6 +++ .../expressions/test_unary_compiler.py | 47 ++++++++++++++++++- 7 files changed, 152 insertions(+), 2 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_json_set/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_json_extract/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_parse_json/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py index ec75d3a3a4..a6eb7182e9 100644 --- a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py @@ -42,3 +42,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: @BINARY_OP_REGISTRATION.register(ops.ge_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.GTE(this=left.expr, expression=right.expr) + + +@BINARY_OP_REGISTRATION.register(ops.JSONSet) +def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: + return sge.func("JSON_SET", left.expr, sge.convert(op.json_path), right.expr) diff --git a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py index 716917b455..9cca15f352 100644 --- a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py @@ -70,3 +70,49 @@ def _(op: ops.ArraySliceOp, expr: TypedExpr) -> sge.Expression: ) return sge.array(selected_elements) + + +# JSON Ops +@UNARY_OP_REGISTRATION.register(ops.JSONExtract) +def _(op: ops.JSONExtract, expr: TypedExpr) -> sge.Expression: + return sge.func("JSON_EXTRACT", expr.expr, sge.convert(op.json_path)) + + 
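+# Note: the JSON handlers in this module all follow the same lowering
+# pattern: each BigFrames op maps to the BigQuery function of the same
+# name, and path-taking ops pass `op.json_path` through as a SQL string
+# literal via sge.convert.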
+@UNARY_OP_REGISTRATION.register(ops.JSONExtractArray) +def _(op: ops.JSONExtractArray, expr: TypedExpr) -> sge.Expression: + return sge.func("JSON_EXTRACT_ARRAY", expr.expr, sge.convert(op.json_path)) + + +@UNARY_OP_REGISTRATION.register(ops.JSONExtractStringArray) +def _(op: ops.JSONExtractStringArray, expr: TypedExpr) -> sge.Expression: + return sge.func("JSON_EXTRACT_STRING_ARRAY", expr.expr, sge.convert(op.json_path)) + + +@UNARY_OP_REGISTRATION.register(ops.JSONQuery) +def _(op: ops.JSONQuery, expr: TypedExpr) -> sge.Expression: + return sge.func("JSON_QUERY", expr.expr, sge.convert(op.json_path)) + + +@UNARY_OP_REGISTRATION.register(ops.JSONQueryArray) +def _(op: ops.JSONQueryArray, expr: TypedExpr) -> sge.Expression: + return sge.func("JSON_QUERY_ARRAY", expr.expr, sge.convert(op.json_path)) + + +@UNARY_OP_REGISTRATION.register(ops.JSONValue) +def _(op: ops.JSONValue, expr: TypedExpr) -> sge.Expression: + return sge.func("JSON_VALUE", expr.expr, sge.convert(op.json_path)) + + +@UNARY_OP_REGISTRATION.register(ops.JSONValueArray) +def _(op: ops.JSONValueArray, expr: TypedExpr) -> sge.Expression: + return sge.func("JSON_VALUE_ARRAY", expr.expr, sge.convert(op.json_path)) + + +@UNARY_OP_REGISTRATION.register(ops.ParseJSON) +def _(op: ops.ParseJSON, expr: TypedExpr) -> sge.Expression: + return sge.func("PARSE_JSON", expr.expr) + + +@UNARY_OP_REGISTRATION.register(ops.ToJSONString) +def _(op: ops.ToJSONString, expr: TypedExpr) -> sge.Expression: + return sge.func("TO_JSON_STRING", expr.expr) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_json_set/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_json_set/out.sql new file mode 100644 index 0000000000..f501dd3b86 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_json_set/out.sql @@ -0,0 +1,20 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex` AS `bfcol_0`, + `json_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`json_types` +), `bfcte_1` AS ( + SELECT + *, + JSON_SET(`bfcol_1`, '$.a', 100) AS `bfcol_4` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + JSON_SET(`bfcol_4`, '$.b', 'hi') AS `bfcol_7` + FROM `bfcte_1` +) +SELECT + `bfcol_0` AS `rowindex`, + `bfcol_7` AS `json_col` +FROM `bfcte_2` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_json_extract/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_json_extract/out.sql new file mode 100644 index 0000000000..2ffb0174a8 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_json_extract/out.sql @@ -0,0 +1,15 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex` AS `bfcol_0`, + `json_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`json_types` +), `bfcte_1` AS ( + SELECT + *, + JSON_EXTRACT(`bfcol_1`, '$') AS `bfcol_4` + FROM `bfcte_0` +) +SELECT + `bfcol_0` AS `rowindex`, + `bfcol_4` AS `json_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_parse_json/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_parse_json/out.sql new file mode 100644 index 0000000000..d965ea8f1b --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_parse_json/out.sql @@ -0,0 +1,15 @@ +WITH `bfcte_0` AS ( + SELECT + `rowindex` AS `bfcol_0`, + `string_col` 
AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + JSON_VALUE(`bfcol_1`, '$') AS `bfcol_4` + FROM `bfcte_0` +) +SELECT + `bfcol_0` AS `rowindex`, + `bfcol_4` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py index f3c96e9253..9daff51c9f 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py @@ -14,6 +14,7 @@ import pytest +import bigframes.bigquery as bbq import bigframes.pandas as bpd pytest.importorskip("pytest_snapshot") @@ -41,3 +42,8 @@ def test_add_string(scalar_types_df: bpd.DataFrame, snapshot): bf_df["string_col"] = bf_df["string_col"] + "a" snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_json_set(json_types_df: bpd.DataFrame, snapshot): + result = bbq.json_set(json_types_df["json_col"], [("$.a", 100), ("$.b", "hi")]) + snapshot.assert_match(result.to_frame().sql, "out.sql") diff --git a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py index 317c2f891b..6d9101aff0 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py @@ -14,14 +14,14 @@ import pytest -from bigframes import bigquery +import bigframes.bigquery as bbq import bigframes.pandas as bpd pytest.importorskip("pytest_snapshot") def test_array_to_string(repeated_types_df: bpd.DataFrame, snapshot): - result = bigquery.array_to_string(repeated_types_df["string_list_col"], ".") + result = bbq.array_to_string(repeated_types_df["string_list_col"], ".") snapshot.assert_match(result.to_frame().sql, "out.sql") @@ -42,3 +42,46 @@ def test_array_slice_with_start_and_stop(repeated_types_df: bpd.DataFrame, snaps result = repeated_types_df["string_list_col"].list[1:5] snapshot.assert_match(result.to_frame().sql, "out.sql") + + +# JSON Ops +def test_json_extract(json_types_df: bpd.DataFrame, snapshot): + result = bbq.json_extract(json_types_df["json_col"], "$") + expected_sql = "JSON_EXTRACT(`bfcol_1`, '$') AS `bfcol_4`" + assert expected_sql in result.to_frame().sql + snapshot.assert_match(result.to_frame().sql, "out.sql") + + +def test_json_extract_array(json_types_df: bpd.DataFrame): + result = bbq.json_extract_array(json_types_df["json_col"], "$") + expected_sql = "JSON_EXTRACT_ARRAY(`bfcol_1`, '$') AS `bfcol_4`" + assert expected_sql in result.to_frame().sql + + +def test_json_extract_string_array(json_types_df: bpd.DataFrame): + result = bbq.json_extract_string_array(json_types_df["json_col"], "$") + expected_sql = "JSON_EXTRACT_STRING_ARRAY(`bfcol_1`, '$') AS `bfcol_4`" + assert expected_sql in result.to_frame().sql + + +def test_json_query(json_types_df: bpd.DataFrame): + result = bbq.json_query(json_types_df["json_col"], "$") + expected_sql = "JSON_QUERY(`bfcol_1`, '$') AS `bfcol_4`" + assert expected_sql in result.to_frame().sql + + +def test_json_query_array(json_types_df: bpd.DataFrame): + result = bbq.json_query_array(json_types_df["json_col"], "$") + expected_sql = "JSON_QUERY_ARRAY(`bfcol_1`, '$') AS `bfcol_4`" + assert expected_sql in result.to_frame().sql + + +def test_json_value(json_types_df: bpd.DataFrame): + result = bbq.json_value(json_types_df["json_col"], "$") + expected_sql = "JSON_VALUE(`bfcol_1`, '$') AS `bfcol_4`" 
+ assert expected_sql in result.to_frame().sql + + +def test_parse_json(scalar_types_df: bpd.DataFrame, snapshot): + result = bbq.json_value(scalar_types_df["string_col"], "$") + snapshot.assert_match(result.to_frame().sql, "out.sql") From f63caf2db3eb8ddb6937b4ac4375b5ed8f0e8bb9 Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Tue, 8 Jul 2025 13:49:22 -0500 Subject: [PATCH 13/13] chore(main): release 2.10.0 (#1879) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> --- CHANGELOG.md | 19 +++++++++++++++++++ bigframes/version.py | 4 ++-- third_party/bigframes_vendored/version.py | 4 ++-- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 313064241d..8bf0d2a4d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,25 @@ [1]: https://pypi.org/project/bigframes/#history +## [2.10.0](https://github.com/googleapis/python-bigquery-dataframes/compare/v2.9.0...v2.10.0) (2025-07-08) + + +### Features + +* `df.to_pandas_batches()` returns one empty DataFrame if `df` is empty ([#1878](https://github.com/googleapis/python-bigquery-dataframes/issues/1878)) ([e43d15d](https://github.com/googleapis/python-bigquery-dataframes/commit/e43d15d535d6d5fd73c33967271f3591c41dffb3)) +* Add filter pushdown to hybrid engine ([#1871](https://github.com/googleapis/python-bigquery-dataframes/issues/1871)) ([6454aff](https://github.com/googleapis/python-bigquery-dataframes/commit/6454aff726dee791acbac98f893075ee5ee6d9a1)) +* Add simple stats support to hybrid local pushdown ([#1873](https://github.com/googleapis/python-bigquery-dataframes/issues/1873)) ([8715105](https://github.com/googleapis/python-bigquery-dataframes/commit/8715105239216bffe899ddcbb15805f2e3063af4)) + + +### Bug Fixes + +* Fix issues where duration type returned as int ([#1875](https://github.com/googleapis/python-bigquery-dataframes/issues/1875)) ([f30f750](https://github.com/googleapis/python-bigquery-dataframes/commit/f30f75053a6966abd1a6a644c23efb86b2ac568d)) + + +### Documentation + +* Update gsutil commands to gcloud commands ([#1876](https://github.com/googleapis/python-bigquery-dataframes/issues/1876)) ([c289f70](https://github.com/googleapis/python-bigquery-dataframes/commit/c289f7061320ec6d9de099cab2416cc9f289baac)) + ## [2.9.0](https://github.com/googleapis/python-bigquery-dataframes/compare/v2.8.0...v2.9.0) (2025-06-30) diff --git a/bigframes/version.py b/bigframes/version.py index 4f3c9a5124..4d26fb9b8c 100644 --- a/bigframes/version.py +++ b/bigframes/version.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.9.0" +__version__ = "2.10.0" # {x-release-please-start-date} -__release_date__ = "2025-06-30" +__release_date__ = "2025-07-08" # {x-release-please-end} diff --git a/third_party/bigframes_vendored/version.py b/third_party/bigframes_vendored/version.py index 4f3c9a5124..4d26fb9b8c 100644 --- a/third_party/bigframes_vendored/version.py +++ b/third_party/bigframes_vendored/version.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.9.0" +__version__ = "2.10.0" # {x-release-please-start-date} -__release_date__ = "2025-06-30" +__release_date__ = "2025-07-08" # {x-release-please-end}
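As a closing illustration of the null-aware join conditions compiled earlier in this series (the `_join_condition_for_*` helpers): the generated SQL compares each key under two different sentinel substitutions and treats keys as equal only when both comparisons agree. A standalone sketch of the idea, not part of any patch, with Python's None and math.nan standing in for SQL NULL and NaN:

    import math

    def pandas_like_key_eq(x, y):
        # Mirrors the compiled COALESCE / IF(IS_NAN(...)) logic: NULL maps to
        # sentinels 0 and 1, NaN maps to 2 and 3, and real values map to
        # themselves. A real value agrees with itself under both substitutions,
        # while NULL or NaN can only agree with another NULL or NaN.
        def sentinel(v, null_s, nan_s):
            if v is None:
                return null_s
            if isinstance(v, float) and math.isnan(v):
                return nan_s
            return v

        return sentinel(x, 0, 2) == sentinel(y, 0, 2) and sentinel(
            x, 1, 3
        ) == sentinel(y, 1, 3)

    assert pandas_like_key_eq(None, None)          # NULL keys join, as in pandas
    assert pandas_like_key_eq(math.nan, math.nan)  # NaN keys join too
    assert not pandas_like_key_eq(None, 0)         # NULL never matches a real 0
    assert not pandas_like_key_eq(math.nan, 2.0)   # NaN never matches a real 2
    assert pandas_like_key_eq(5, 5)

This is also why the generated conditions use two COALESCE pairs rather than one: with a single sentinel, a NULL key would collide with a genuine 0.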