diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py index 22c66719f7..7de4bdbc91 100644 --- a/bigframes/dataframe.py +++ b/bigframes/dataframe.py @@ -3434,15 +3434,9 @@ def merge( ) return DataFrame(result_block) - if on is None: - if left_on is None or right_on is None: - raise ValueError("Must specify `on` or `left_on` + `right_on`.") - else: - if left_on is not None or right_on is not None: - raise ValueError( - "Can not pass both `on` and `left_on` + `right_on` params." - ) - left_on, right_on = on, on + left_on, right_on = self._validate_left_right_on( + right, on, left_on=left_on, right_on=right_on + ) if utils.is_list_like(left_on): left_on = list(left_on) # type: ignore @@ -3479,6 +3473,41 @@ def merge( ) return DataFrame(block) + def _validate_left_right_on( + self, + right: DataFrame, + on: Union[blocks.Label, Sequence[blocks.Label], None] = None, + *, + left_on: Union[blocks.Label, Sequence[blocks.Label], None] = None, + right_on: Union[blocks.Label, Sequence[blocks.Label], None] = None, + ): + if on is not None: + if left_on is not None or right_on is not None: + raise ValueError( + "Can not pass both `on` and `left_on` + `right_on` params." + ) + return on, on + + if left_on is not None and right_on is not None: + return left_on, right_on + + left_cols = self.columns + right_cols = right.columns + common_cols = left_cols.intersection(right_cols) + if len(common_cols) == 0: + raise ValueError( + "No common columns to perform merge on." + f"Merge options: left_on={left_on}, " + f"right_on={right_on}, " + ) + if ( + not left_cols.join(common_cols, how="inner").is_unique + or not right_cols.join(common_cols, how="inner").is_unique + ): + raise ValueError(f"Data columns not unique: {repr(common_cols)}") + + return common_cols, common_cols + def join( self, other: Union[DataFrame, bigframes.series.Series], diff --git a/tests/system/small/test_pandas.py b/tests/system/small/test_pandas.py index 4e8d3d20f7..550a75e1bb 100644 --- a/tests/system/small/test_pandas.py +++ b/tests/system/small/test_pandas.py @@ -13,6 +13,7 @@ # limitations under the License. from datetime import datetime +import re import typing import pandas as pd @@ -343,7 +344,7 @@ def test_merge_left_on_right_on(scalars_dfs, merge_how): assert_pandas_df_equal(bf_result, pd_result, ignore_order=True) -def test_pd_merge_cross(scalars_dfs): +def test_merge_cross(scalars_dfs): scalars_df, scalars_pandas_df = scalars_dfs left_columns = ["int64_col", "float64_col", "int64_too"] right_columns = ["int64_col", "bool_col", "string_col", "rowindex_2"] @@ -398,6 +399,61 @@ def test_merge_series(scalars_dfs, merge_how): assert_pandas_df_equal(bf_result, pd_result, ignore_order=True) +def test_merge_w_common_columns(scalars_dfs): + scalars_df, scalars_pandas_df = scalars_dfs + left_columns = ["int64_col", "int64_too"] + right_columns = ["int64_col", "bool_col"] + + df = bpd.merge( + scalars_df[left_columns], scalars_df[right_columns], "inner", sort=True + ) + + pd_result = pd.merge( + scalars_pandas_df[left_columns], + scalars_pandas_df[right_columns], + "inner", + sort=True, + ) + assert_pandas_df_equal(df.to_pandas(), pd_result, ignore_order=True) + + +def test_merge_raises_error_when_no_common_columns(scalars_dfs): + scalars_df, _ = scalars_dfs + left_columns = ["float64_col", "int64_too"] + right_columns = ["int64_col", "bool_col"] + + left = scalars_df[left_columns] + right = scalars_df[right_columns] + + with pytest.raises( + ValueError, + match="No common columns to perform merge on.", + ): + bpd.merge(left, right, "inner") + + +def test_merge_raises_error_when_left_right_on_set(scalars_dfs): + scalars_df, _ = scalars_dfs + left_columns = ["int64_col", "int64_too"] + right_columns = ["int64_col", "bool_col"] + + left = scalars_df[left_columns] + right = scalars_df[right_columns] + + with pytest.raises( + ValueError, + match=re.escape("Can not pass both `on` and `left_on` + `right_on` params."), + ): + bpd.merge( + left, + right, + "inner", + left_on="int64_too", + right_on="int64_col", + on="int64_col", + ) + + def _convert_pandas_category(pd_s: pd.Series): """ Transforms a pandas Series with Categorical dtype into a bigframes-compatible