diff --git a/bigframes/core/compile/polars/compiler.py b/bigframes/core/compile/polars/compiler.py index dfa2ebc818..c31c122078 100644 --- a/bigframes/core/compile/polars/compiler.py +++ b/bigframes/core/compile/polars/compiler.py @@ -487,8 +487,14 @@ def compile_offsets(self, node: nodes.PromoteOffsetsNode): def compile_join(self, node: nodes.JoinNode): left = self.compile_node(node.left_child) right = self.compile_node(node.right_child) - left_on = [l_name.id.sql for l_name, _ in node.conditions] - right_on = [r_name.id.sql for _, r_name in node.conditions] + + left_on = [] + right_on = [] + for left_ex, right_ex in node.conditions: + left_ex, right_ex = lowering._coerce_comparables(left_ex, right_ex) + left_on.append(self.expr_compiler.compile_expression(left_ex)) + right_on.append(self.expr_compiler.compile_expression(right_ex)) + if node.type == "right": return self._ordered_join( right, left, "left", right_on, left_on, node.joins_nulls @@ -502,8 +508,8 @@ def _ordered_join( left_frame: pl.LazyFrame, right_frame: pl.LazyFrame, how: Literal["inner", "outer", "left", "cross"], - left_on: Sequence[str], - right_on: Sequence[str], + left_on: Sequence[pl.Expr], + right_on: Sequence[pl.Expr], join_nulls: bool, ): if how == "right": diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index 8f669901a4..3c23e4c200 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -37,6 +37,7 @@ nodes.AggregateNode, nodes.FilterNode, nodes.ConcatNode, + nodes.JoinNode, ) _COMPATIBLE_SCALAR_OPS = ( diff --git a/tests/system/small/engines/test_join.py b/tests/system/small/engines/test_join.py new file mode 100644 index 0000000000..e1f9fe6070 --- /dev/null +++ b/tests/system/small/engines/test_join.py @@ -0,0 +1,90 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Literal + +import pytest + +from bigframes import operations as ops +from bigframes.core import array_value, expression, ordering +from bigframes.session import polars_executor +from bigframes.testing.engine_utils import assert_equivalence_execution + +pytest.importorskip("polars") + +# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree. +REFERENCE_ENGINE = polars_executor.PolarsExecutor() + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"]) +def test_engines_join_on_key( + scalars_array_value: array_value.ArrayValue, + engine, + join_type: Literal["inner", "outer", "left", "right"], +): + result, _ = scalars_array_value.relational_join( + scalars_array_value, conditions=(("int64_col", "int64_col"),), type=join_type + ) + + assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"]) +def test_engines_join_on_coerced_key( + scalars_array_value: array_value.ArrayValue, + engine, + join_type: Literal["inner", "outer", "left", "right"], +): + result, _ = scalars_array_value.relational_join( + scalars_array_value, conditions=(("int64_col", "float64_col"),), type=join_type + ) + + assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("join_type", ["left", "inner", "right", "outer"]) +def test_engines_join_multi_key( + scalars_array_value: array_value.ArrayValue, + engine, + join_type: Literal["inner", "outer", "left", "right"], +): + l_input = scalars_array_value.order_by([ordering.ascending_over("float64_col")]) + l_input, l_join_cols = scalars_array_value.compute_values( + [ + ops.mod_op.as_expr("int64_col", expression.const(2)), + ops.invert_op.as_expr("bool_col"), + ] + ) + r_input, r_join_cols = scalars_array_value.compute_values( + [ops.mod_op.as_expr("int64_col", expression.const(3)), expression.const(True)] + ) + + conditions = tuple((l_col, r_col) for l_col, r_col in zip(l_join_cols, r_join_cols)) + + result, _ = l_input.relational_join(r_input, conditions=conditions, type=join_type) + + assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine) + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_cross_join( + scalars_array_value: array_value.ArrayValue, + engine, +): + result, _ = scalars_array_value.relational_join(scalars_array_value, type="cross") + + assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)