diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 606fe41b5e..03d1a0a2de 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -22,6 +22,7 @@ from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite from bigframes.core.compile import configs +from bigframes.core.compile.sqlglot.expressions import typed_expr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler import bigframes.core.compile.sqlglot.sqlglot_ir as ir import bigframes.core.ordering as bf_ordering @@ -218,6 +219,29 @@ def compile_filter( condition = scalar_compiler.compile_scalar_expression(node.predicate) return child.filter(condition) + @_compile_node.register + def compile_join( + self, node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR + ) -> ir.SQLGlotIR: + conditions = tuple( + ( + typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(left), left.output_type + ), + typed_expr.TypedExpr( + scalar_compiler.compile_scalar_expression(right), right.output_type + ), + ) + for left, right in node.conditions + ) + + return left.join( + right, + join_type=node.type, + conditions=conditions, + joins_nulls=node.joins_nulls, + ) + @_compile_node.register def compile_concat( self, node: nodes.ConcatNode, *children: ir.SQLGlotIR diff --git a/bigframes/core/compile/sqlglot/sqlglot_ir.py b/bigframes/core/compile/sqlglot/sqlglot_ir.py index 6bc2b55162..3b4d7ed0ce 100644 --- a/bigframes/core/compile/sqlglot/sqlglot_ir.py +++ b/bigframes/core/compile/sqlglot/sqlglot_ir.py @@ -26,6 +26,7 @@ from bigframes import dtypes from bigframes.core import guid +from bigframes.core.compile.sqlglot.expressions import typed_expr import bigframes.core.compile.sqlglot.sqlglot_types as sgt import bigframes.core.local_data as local_data import bigframes.core.schema as bf_schema @@ -212,7 +213,8 @@ def select( for id, expr in selected_cols ] - new_expr = self._encapsulate_as_cte().select(*selections, append=False) + new_expr, _ = self._encapsulate_as_cte() + new_expr = new_expr.select(*selections, append=False) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) def order_by( @@ -247,7 +249,8 @@ def project( ) for id, expr in projected_cols ] - new_expr = self._encapsulate_as_cte().select(*projected_cols_expr, append=True) + new_expr, _ = self._encapsulate_as_cte() + new_expr = new_expr.select(*projected_cols_expr, append=True) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) def filter( @@ -255,11 +258,43 @@ def filter( condition: sge.Expression, ) -> SQLGlotIR: """Filters the query with the given condition.""" - new_expr = self._encapsulate_as_cte() + new_expr, _ = self._encapsulate_as_cte() return SQLGlotIR( expr=new_expr.where(condition, append=False), uid_gen=self.uid_gen ) + def join( + self, + right: SQLGlotIR, + join_type: typing.Literal["inner", "outer", "left", "right", "cross"], + conditions: tuple[tuple[typed_expr.TypedExpr, typed_expr.TypedExpr], ...], + *, + joins_nulls: bool = True, + ) -> SQLGlotIR: + """Joins the current query with another SQLGlotIR instance.""" + left_select, left_table = self._encapsulate_as_cte() + right_select, right_table = right._encapsulate_as_cte() + + left_ctes = left_select.args.pop("with", []) + right_ctes = right_select.args.pop("with", []) + merged_ctes = [*left_ctes, *right_ctes] + + join_conditions = [ + _join_condition(left, right, joins_nulls) for left, right in conditions + ] + join_on = sge.And(expressions=join_conditions) if join_conditions else None + + join_type_str = join_type if join_type != "outer" else "full outer" + new_expr = ( + sge.Select() + .select(sge.Star()) + .from_(left_table) + .join(right_table, on=join_on, join_type=join_type_str) + ) + new_expr.set("with", sge.With(expressions=merged_ctes)) + + return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) + def insert( self, destination: bigquery.TableReference, @@ -320,12 +355,12 @@ def _explode_single_column( offset=offset, ) selection = sge.Star(replace=[unnested_column_alias.as_(column)]) + # TODO: "CROSS" if not keep_empty else "LEFT" # TODO: overlaps_with_parent to replace existing column. - new_expr = ( - self._encapsulate_as_cte() - .select(selection, append=False) - .join(unnest_expr, join_type="CROSS") + new_expr, _ = self._encapsulate_as_cte() + new_expr = new_expr.select(selection, append=False).join( + unnest_expr, join_type="CROSS" ) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) @@ -373,16 +408,15 @@ def _explode_multiple_columns( for column in columns ] ) - new_expr = ( - self._encapsulate_as_cte() - .select(selection, append=False) - .join(unnest_expr, join_type="CROSS") + new_expr, _ = self._encapsulate_as_cte() + new_expr = new_expr.select(selection, append=False).join( + unnest_expr, join_type="CROSS" ) return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen) def _encapsulate_as_cte( self, - ) -> sge.Select: + ) -> typing.Tuple[sge.Select, sge.Table]: """Transforms a given sge.Select query by pushing its main SELECT statement into a new CTE and then generates a 'SELECT * FROM new_cte_name' for the new query.""" @@ -397,11 +431,10 @@ def _encapsulate_as_cte( alias=new_cte_name, ) new_with_clause = sge.With(expressions=[*existing_ctes, new_cte]) - new_select_expr = ( - sge.Select().select(sge.Star()).from_(sge.Table(this=new_cte_name)) - ) + new_table_expr = sge.Table(this=new_cte_name) + new_select_expr = sge.Select().select(sge.Star()).from_(new_table_expr) new_select_expr.set("with", new_with_clause) - return new_select_expr + return new_select_expr, new_table_expr def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression: @@ -451,3 +484,11 @@ def _table(table: bigquery.TableReference) -> sge.Table: db=sg.to_identifier(table.dataset_id, quoted=True), catalog=sg.to_identifier(table.project, quoted=True), ) + + +def _join_condition( + left: typed_expr.TypedExpr, + right: typed_expr.TypedExpr, + joins_nulls: bool, +) -> typing.Union[sge.EQ, sge.And]: + return sge.EQ(this=left.expr, expression=right.expr) diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 1ca5b8b035..1ef287842e 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -3369,8 +3369,6 @@ def merge( "right", "cross", ] = "inner", - # TODO(garrettwu): Currently can take inner, outer, left and right. To support - # cross joins on: Union[blocks.Label, Sequence[blocks.Label], None] = None, *, left_on: Union[blocks.Label, Sequence[blocks.Label], None] = None, diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql new file mode 100644 index 0000000000..aefaa28dfb --- /dev/null +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_join/test_compile_join/out.sql @@ -0,0 +1,31 @@ +WITH `bfcte_1` AS ( + SELECT + `int64_col` AS `bfcol_0`, + `rowindex` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_2` AS ( + SELECT + `bfcol_1` AS `bfcol_2`, + `bfcol_0` AS `bfcol_3` + FROM `bfcte_1` +), `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_4`, + `int64_too` AS `bfcol_5` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_3` AS ( + SELECT + `bfcol_4` AS `bfcol_6`, + `bfcol_5` AS `bfcol_7` + FROM `bfcte_0` +), `bfcte_4` AS ( + SELECT + * + FROM `bfcte_2` + LEFT JOIN `bfcte_3` + ON `bfcol_2` = `bfcol_6` +) +SELECT + `bfcol_3` AS `int64_col`, + `bfcol_7` AS `int64_too` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/test_compile_join.py b/tests/unit/core/compile/sqlglot/test_compile_join.py new file mode 100644 index 0000000000..a530ed4fc3 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/test_compile_join.py @@ -0,0 +1,51 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import bigframes.pandas as bpd + +pytest.importorskip("pytest_snapshot") + + +def test_compile_join(scalar_types_df: bpd.DataFrame, snapshot): + left = scalar_types_df[["int64_col"]] + right = scalar_types_df.set_index("int64_col")[["int64_too"]] + join = left.join(right) + snapshot.assert_match(join.sql, "out.sql") + + +def test_compile_join_w_how(scalar_types_df: bpd.DataFrame): + left = scalar_types_df[["int64_col"]] + right = scalar_types_df.set_index("int64_col")[["int64_too"]] + + join_sql = left.join(right, how="left").sql + assert "LEFT JOIN" in join_sql + assert "ON" in join_sql + + join_sql = left.join(right, how="right").sql + assert "RIGHT JOIN" in join_sql + assert "ON" in join_sql + + join_sql = left.join(right, how="outer").sql + assert "FULL OUTER JOIN" in join_sql + assert "ON" in join_sql + + join_sql = left.join(right, how="inner").sql + assert "INNER JOIN" in join_sql + assert "ON" in join_sql + + join_sql = left.merge(right, how="cross").sql + assert "CROSS JOIN" in join_sql + assert "ON" not in join_sql diff --git a/third_party/bigframes_vendored/pandas/core/frame.py b/third_party/bigframes_vendored/pandas/core/frame.py index 0606032d34..40ab5a7352 100644 --- a/third_party/bigframes_vendored/pandas/core/frame.py +++ b/third_party/bigframes_vendored/pandas/core/frame.py @@ -4748,7 +4748,7 @@ def merge( right: Object to merge with. how: - ``{'left', 'right', 'outer', 'inner'}, default 'inner'`` + ``{'left', 'right', 'outer', 'inner', 'cross'}, default 'inner'`` Type of merge to be performed. ``left``: use only keys from left frame, similar to a SQL left outer join; preserve key order.