24 changes: 24 additions & 0 deletions bigframes/core/compile/sqlglot/compiler.py
@@ -22,6 +22,7 @@

from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
from bigframes.core.compile import configs
from bigframes.core.compile.sqlglot.expressions import typed_expr
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
import bigframes.core.ordering as bf_ordering
@@ -218,6 +219,29 @@ def compile_filter(
condition = scalar_compiler.compile_scalar_expression(node.predicate)
return child.filter(condition)

@_compile_node.register
def compile_join(
self, node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
) -> ir.SQLGlotIR:
conditions = tuple(
(
typed_expr.TypedExpr(
scalar_compiler.compile_scalar_expression(left), left.output_type
),
typed_expr.TypedExpr(
scalar_compiler.compile_scalar_expression(right), right.output_type
),
)
for left, right in node.conditions
)

return left.join(
right,
join_type=node.type,
conditions=conditions,
joins_nulls=node.joins_nulls,
)

@_compile_node.register
def compile_concat(
self, node: nodes.ConcatNode, *children: ir.SQLGlotIR
73 changes: 57 additions & 16 deletions bigframes/core/compile/sqlglot/sqlglot_ir.py
@@ -26,6 +26,7 @@

from bigframes import dtypes
from bigframes.core import guid
from bigframes.core.compile.sqlglot.expressions import typed_expr
import bigframes.core.compile.sqlglot.sqlglot_types as sgt
import bigframes.core.local_data as local_data
import bigframes.core.schema as bf_schema
@@ -212,7 +213,8 @@ def select(
for id, expr in selected_cols
]

new_expr = self._encapsulate_as_cte().select(*selections, append=False)
new_expr, _ = self._encapsulate_as_cte()
new_expr = new_expr.select(*selections, append=False)
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def order_by(
@@ -247,19 +249,52 @@ def project(
)
for id, expr in projected_cols
]
new_expr = self._encapsulate_as_cte().select(*projected_cols_expr, append=True)
new_expr, _ = self._encapsulate_as_cte()
new_expr = new_expr.select(*projected_cols_expr, append=True)
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def filter(
self,
condition: sge.Expression,
) -> SQLGlotIR:
"""Filters the query with the given condition."""
new_expr = self._encapsulate_as_cte()
new_expr, _ = self._encapsulate_as_cte()
return SQLGlotIR(
expr=new_expr.where(condition, append=False), uid_gen=self.uid_gen
)

def join(
self,
right: SQLGlotIR,
join_type: typing.Literal["inner", "outer", "left", "right", "cross"],
conditions: tuple[tuple[typed_expr.TypedExpr, typed_expr.TypedExpr], ...],
*,
joins_nulls: bool = True,
) -> SQLGlotIR:
"""Joins the current query with another SQLGlotIR instance."""
left_select, left_table = self._encapsulate_as_cte()
right_select, right_table = right._encapsulate_as_cte()

left_ctes = left_select.args.pop("with", [])
right_ctes = right_select.args.pop("with", [])
merged_ctes = [*left_ctes, *right_ctes]

join_conditions = [
_join_condition(left, right, joins_nulls) for left, right in conditions
]
join_on = sge.And(expressions=join_conditions) if join_conditions else None

join_type_str = join_type if join_type != "outer" else "full outer"
new_expr = (
sge.Select()
.select(sge.Star())
.from_(left_table)
.join(right_table, on=join_on, join_type=join_type_str)
)
new_expr.set("with", sge.With(expressions=merged_ctes))

return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def insert(
self,
destination: bigquery.TableReference,
@@ -320,12 +355,12 @@ def _explode_single_column(
offset=offset,
)
selection = sge.Star(replace=[unnested_column_alias.as_(column)])

# TODO: "CROSS" if not keep_empty else "LEFT"
# TODO: overlaps_with_parent to replace existing column.
new_expr = (
self._encapsulate_as_cte()
.select(selection, append=False)
.join(unnest_expr, join_type="CROSS")
new_expr, _ = self._encapsulate_as_cte()
new_expr = new_expr.select(selection, append=False).join(
unnest_expr, join_type="CROSS"
)
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

@@ -373,16 +408,15 @@ def _explode_multiple_columns(
for column in columns
]
)
new_expr = (
self._encapsulate_as_cte()
.select(selection, append=False)
.join(unnest_expr, join_type="CROSS")
new_expr, _ = self._encapsulate_as_cte()
new_expr = new_expr.select(selection, append=False).join(
unnest_expr, join_type="CROSS"
)
return SQLGlotIR(expr=new_expr, uid_gen=self.uid_gen)

def _encapsulate_as_cte(
self,
) -> sge.Select:
) -> typing.Tuple[sge.Select, sge.Table]:
"""Transforms a given sge.Select query by pushing its main SELECT statement
into a new CTE and then generates a 'SELECT * FROM new_cte_name'
for the new query."""
@@ -397,11 +431,10 @@ def _encapsulate_as_cte(
alias=new_cte_name,
)
new_with_clause = sge.With(expressions=[*existing_ctes, new_cte])
new_select_expr = (
sge.Select().select(sge.Star()).from_(sge.Table(this=new_cte_name))
)
new_table_expr = sge.Table(this=new_cte_name)
new_select_expr = sge.Select().select(sge.Star()).from_(new_table_expr)
new_select_expr.set("with", new_with_clause)
return new_select_expr
return new_select_expr, new_table_expr


def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
@@ -451,3 +484,11 @@ def _table(table: bigquery.TableReference) -> sge.Table:
db=sg.to_identifier(table.dataset_id, quoted=True),
catalog=sg.to_identifier(table.project, quoted=True),
)


def _join_condition(
left: typed_expr.TypedExpr,
right: typed_expr.TypedExpr,
joins_nulls: bool,
) -> typing.Union[sge.EQ, sge.And]:
return sge.EQ(this=left.expr, expression=right.expr)
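
Note that `_join_condition` accepts `joins_nulls` but currently always emits a plain equality; the `typing.Union[sge.EQ, sge.And]` return type leaves room for a null-matching variant. Below is a rough sketch, not part of this change, of how a null-matching condition could be built from the same `TypedExpr` pair, assuming only that `TypedExpr.expr` holds the compiled `sge.Expression` as it does above:

from bigframes.core.compile.sqlglot.expressions import typed_expr
import sqlglot.expressions as sge

def _join_condition_nullsafe(
    left: typed_expr.TypedExpr,
    right: typed_expr.TypedExpr,
) -> sge.Expression:
    # Plain equality, as _join_condition emits today.
    eq = sge.EQ(this=left.expr, expression=right.expr)
    # Additionally treat NULL on both sides as a match, which is one reading
    # of what joins_nulls=True asks for (pandas-style null key matching).
    both_null = sge.And(
        this=sge.Is(this=left.expr.copy(), expression=sge.Null()),
        expression=sge.Is(this=right.expr.copy(), expression=sge.Null()),
    )
    return sge.Or(this=eq, expression=both_null)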
2 changes: 0 additions & 2 deletions bigframes/dataframe.py
@@ -3369,8 +3369,6 @@ def merge(
"right",
"cross",
] = "inner",
# TODO(garrettwu): Currently can take inner, outer, left and right. To support
# cross joins
on: Union[blocks.Label, Sequence[blocks.Label], None] = None,
*,
left_on: Union[blocks.Label, Sequence[blocks.Label], None] = None,
@@ -0,0 +1,31 @@
WITH `bfcte_1` AS (
SELECT
`int64_col` AS `bfcol_0`,
`rowindex` AS `bfcol_1`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_2` AS (
SELECT
`bfcol_1` AS `bfcol_2`,
`bfcol_0` AS `bfcol_3`
FROM `bfcte_1`
), `bfcte_0` AS (
SELECT
`int64_col` AS `bfcol_4`,
`int64_too` AS `bfcol_5`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_3` AS (
SELECT
`bfcol_4` AS `bfcol_6`,
`bfcol_5` AS `bfcol_7`
FROM `bfcte_0`
), `bfcte_4` AS (
SELECT
*
FROM `bfcte_2`
LEFT JOIN `bfcte_3`
ON `bfcol_2` = `bfcol_6`
)
SELECT
`bfcol_3` AS `int64_col`,
`bfcol_7` AS `int64_too`
FROM `bfcte_4`
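
The snapshot above (the `out.sql` expected by `test_compile_join` below) shows the shape `SQLGlotIR.join` produces: each side is encapsulated into its own CTE chain, both WITH lists are hoisted onto a single top-level clause, and the join itself is a `SELECT *` over the two CTE tables, followed by the final renaming projection. A minimal standalone sketch of that WITH-merging step, with illustrative names and using only stock sqlglot APIs (not code from this PR):

import sqlglot
import sqlglot.expressions as sge

left = sqlglot.parse_one("WITH l AS (SELECT 1 AS a) SELECT * FROM l")
right = sqlglot.parse_one("WITH r AS (SELECT 2 AS b) SELECT * FROM r")

# Pop the WITH clauses from both sides so they can be re-attached at the top
# level of the combined query, mirroring left_ctes/right_ctes in join().
left_ctes = left.args.pop("with").expressions
right_ctes = right.args.pop("with").expressions

joined = (
    sge.Select()
    .select(sge.Star())
    .from_("l")
    .join("r", on="l.a = r.b", join_type="left")
)
joined.set("with", sge.With(expressions=[*left_ctes, *right_ctes]))

print(joined.sql(dialect="bigquery"))
# Renders roughly as:
# WITH l AS (SELECT 1 AS a), r AS (SELECT 2 AS b)
# SELECT * FROM l LEFT JOIN r ON l.a = r.b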
51 changes: 51 additions & 0 deletions tests/unit/core/compile/sqlglot/test_compile_join.py
@@ -0,0 +1,51 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import bigframes.pandas as bpd

pytest.importorskip("pytest_snapshot")


def test_compile_join(scalar_types_df: bpd.DataFrame, snapshot):
left = scalar_types_df[["int64_col"]]
right = scalar_types_df.set_index("int64_col")[["int64_too"]]
join = left.join(right)
snapshot.assert_match(join.sql, "out.sql")


def test_compile_join_w_how(scalar_types_df: bpd.DataFrame):
left = scalar_types_df[["int64_col"]]
right = scalar_types_df.set_index("int64_col")[["int64_too"]]

join_sql = left.join(right, how="left").sql
assert "LEFT JOIN" in join_sql
assert "ON" in join_sql

join_sql = left.join(right, how="right").sql
assert "RIGHT JOIN" in join_sql
assert "ON" in join_sql

join_sql = left.join(right, how="outer").sql
assert "FULL OUTER JOIN" in join_sql
assert "ON" in join_sql

join_sql = left.join(right, how="inner").sql
assert "INNER JOIN" in join_sql
assert "ON" in join_sql

join_sql = left.merge(right, how="cross").sql
assert "CROSS JOIN" in join_sql
assert "ON" not in join_sql
2 changes: 1 addition & 1 deletion third_party/bigframes_vendored/pandas/core/frame.py
@@ -4748,7 +4748,7 @@ def merge(
right:
Object to merge with.
how:
``{'left', 'right', 'outer', 'inner'}, default 'inner'``
``{'left', 'right', 'outer', 'inner', 'cross'}, default 'inner'``
Type of merge to be performed.
``left``: use only keys from left frame, similar to a SQL left outer join;
preserve key order.
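
With ``'cross'`` now listed, here is a minimal usage sketch (illustrative only; it assumes a configured BigQuery session and the same test table referenced in the snapshot above):

import bigframes.pandas as bpd

df = bpd.read_gbq("bigframes-dev.sqlglot_test.scalar_types")
left = df[["int64_col"]]
right = df.set_index("int64_col")[["int64_too"]]

# A cross merge pairs every left row with every right row; no key columns
# are used, so the compiled SQL contains a CROSS JOIN and no ON clause.
crossed = left.merge(right, how="cross")
print(crossed.sql)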