From e8b8c9a2ead528ae7417d72d2e975e3b5c657bf4 Mon Sep 17 00:00:00 2001
From: Shobhit Singh
Date: Tue, 24 Oct 2023 18:57:29 +0000
Subject: [PATCH] feat: Implement operator `@` for `DataFrame.dot`

---
Reviewer notes (text here is outside the commit message and is skipped
when the patch is applied):
`__matmul__ = dot` binds the `@` operator to the same function object as
`DataFrame.dot`, so `df @ other` runs exactly the code path of
`df.dot(other)`. The added tests exercise `DataFrame @ DataFrame`,
`DataFrame @ Series`, and assert that the operator raises the same
NotImplementedError messages as `.dot` for multi-index rows and
multi-level columns.

 bigframes/dataframe.py                |  2 ++
 tests/system/small/test_dataframe.py  | 33 +++++++++++++++++++++++++++
 tests/system/small/test_multiindex.py | 16 +++++++++++++
 3 files changed, 51 insertions(+)

diff --git a/bigframes/dataframe.py b/bigframes/dataframe.py
index 32a2908a42..70995ce1c3 100644
--- a/bigframes/dataframe.py
+++ b/bigframes/dataframe.py
@@ -2678,3 +2678,5 @@ def get_right_id(id):
             result = result[other.name].rename()
 
         return result
+
+    __matmul__ = dot
diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py
index 19e50eb06d..536a046bcc 100644
--- a/tests/system/small/test_dataframe.py
+++ b/tests/system/small/test_dataframe.py
@@ -3207,6 +3207,23 @@ def test_df_dot(
     )
 
 
+def test_df_dot_operator(
+    matrix_2by3_df, matrix_2by3_pandas_df, matrix_3by4_df, matrix_3by4_pandas_df
+):
+    bf_result = (matrix_2by3_df @ matrix_3by4_df).to_pandas()
+    pd_result = matrix_2by3_pandas_df @ matrix_3by4_pandas_df
+
+    # Patch pandas dtypes for testing parity
+    # Pandas result is object instead of Int64 (nullable) dtype.
+    for name in pd_result.columns:
+        pd_result[name] = pd_result[name].astype(pd.Int64Dtype())
+
+    pd.testing.assert_frame_equal(
+        bf_result,
+        pd_result,
+    )
+
+
 def test_df_dot_series(
     matrix_2by3_df, matrix_2by3_pandas_df, matrix_3by4_df, matrix_3by4_pandas_df
 ):
@@ -3221,3 +3238,19 @@ def test_df_dot_series(
         bf_result,
         pd_result,
     )
+
+
+def test_df_dot_operator_series(
+    matrix_2by3_df, matrix_2by3_pandas_df, matrix_3by4_df, matrix_3by4_pandas_df
+):
+    bf_result = (matrix_2by3_df @ matrix_3by4_df["x"]).to_pandas()
+    pd_result = matrix_2by3_pandas_df @ matrix_3by4_pandas_df["x"]
+
+    # Patch pandas dtypes for testing parity
+    # Pandas result is object instead of Int64 (nullable) dtype.
+    pd_result = pd_result.astype(pd.Int64Dtype())
+
+    pd.testing.assert_series_equal(
+        bf_result,
+        pd_result,
+    )
diff --git a/tests/system/small/test_multiindex.py b/tests/system/small/test_multiindex.py
index b5c78de69c..3f2d707f7c 100644
--- a/tests/system/small/test_multiindex.py
+++ b/tests/system/small/test_multiindex.py
@@ -947,6 +947,9 @@ def test_df_multi_index_dot_not_supported():
     with pytest.raises(NotImplementedError, match="Multi-index input is not supported"):
         bf1.dot(bf2)
 
+    with pytest.raises(NotImplementedError, match="Multi-index input is not supported"):
+        bf1 @ bf2
+
     # right multi-index
     right_index = pandas.MultiIndex.from_tuples([("a", "aa"), ("a", "ab"), ("b", "bb")])
     bf1 = bpd.DataFrame(left_matrix)
@@ -954,6 +957,9 @@ def test_df_multi_index_dot_not_supported():
     with pytest.raises(NotImplementedError, match="Multi-index input is not supported"):
         bf1.dot(bf2)
 
+    with pytest.raises(NotImplementedError, match="Multi-index input is not supported"):
+        bf1 @ bf2
+
 
 def test_column_multi_index_dot_not_supported():
     left_matrix = [[1, 2, 3], [2, 5, 7]]
@@ -971,6 +977,11 @@ def test_column_multi_index_dot_not_supported():
     ):
         bf1.dot(bf2)
 
+    with pytest.raises(
+        NotImplementedError, match="Multi-level column input is not supported"
+    ):
+        bf1 @ bf2
+
     # right multi-columns
     bf1 = bpd.DataFrame(left_matrix)
     bf2 = bpd.DataFrame(right_matrix, columns=multi_level_columns)
@@ -978,3 +989,8 @@ def test_column_multi_index_dot_not_supported():
         NotImplementedError, match="Multi-level column input is not supported"
     ):
         bf1.dot(bf2)
+
+    with pytest.raises(
+        NotImplementedError, match="Multi-level column input is not supported"
+    ):
+        bf1 @ bf2