diff --git a/bigframes/core/indexers.py b/bigframes/core/indexers.py index f4c4f9011e..4f5a9471b9 100644 --- a/bigframes/core/indexers.py +++ b/bigframes/core/indexers.py @@ -15,7 +15,7 @@ from __future__ import annotations import typing -from typing import Tuple, Union +from typing import List, Tuple, Union import ibis import pandas as pd @@ -271,40 +271,59 @@ def _loc_getitem_series_or_dataframe( if isinstance(key, bigframes.series.Series) and key.dtype == "boolean": return series_or_dataframe[key] elif isinstance(key, bigframes.series.Series): - # TODO(henryjsolberg): support MultiIndex temp_name = guid.generate_guid(prefix="temp_series_name_") + if len(series_or_dataframe.index.names) > 1: + temp_name = series_or_dataframe.index.names[0] key = key.rename(temp_name) keys_df = key.to_frame() keys_df = keys_df.set_index(temp_name, drop=True) return _perform_loc_list_join(series_or_dataframe, keys_df) elif isinstance(key, bigframes.core.indexes.Index): - # TODO(henryjsolberg): support MultiIndex block = key._data._get_block() block = block.select_columns(()) keys_df = bigframes.dataframe.DataFrame(block) return _perform_loc_list_join(series_or_dataframe, keys_df) elif pd.api.types.is_list_like(key): - # TODO(henryjsolberg): support MultiIndex - if len(key) == 0: # type: ignore + key = typing.cast(List, key) + if len(key) == 0: return typing.cast( Union[bigframes.dataframe.DataFrame, bigframes.series.Series], series_or_dataframe.iloc[0:0], ) - - # We can't upload a DataFrame with None as the column name, so set it - # an arbitrary string. - index_name = series_or_dataframe.index.name - index_name_is_none = index_name is None - if index_name_is_none: - index_name = "unnamed_col" - - keys_df = bigframes.dataframe.DataFrame( - {index_name: key}, session=series_or_dataframe._get_block().expr._session - ) - keys_df = keys_df.set_index(index_name, drop=True) - - if index_name_is_none: - keys_df.index.name = None + if pd.api.types.is_list_like(key[0]): + original_index_names = series_or_dataframe.index.names + num_index_cols = len(original_index_names) + + entry_col_count_correct = [len(entry) == num_index_cols for entry in key] + if not all(entry_col_count_correct): + # pandas usually throws TypeError in these cases- tuple causes IndexError, but that + # seems like unintended behavior + raise TypeError( + "All entries must be of equal length when indexing by list of listlikes" + ) + temporary_index_names = [ + guid.generate_guid(prefix="temp_loc_index_") + for _ in range(len(original_index_names)) + ] + index_cols_dict = {} + for i in range(num_index_cols): + index_name = temporary_index_names[i] + values = [entry[i] for entry in key] + index_cols_dict[index_name] = values + keys_df = bigframes.dataframe.DataFrame(index_cols_dict) + keys_df = keys_df.set_index(temporary_index_names, drop=True) + keys_df = keys_df.rename_axis(original_index_names) + else: + # We can't upload a DataFrame with None as the column name, so set it + # an arbitrary string. + index_name = series_or_dataframe.index.name + index_name_is_none = index_name is None + if index_name_is_none: + index_name = "unnamed_col" + keys_df = bigframes.dataframe.DataFrame({index_name: key}) + keys_df = keys_df.set_index(index_name, drop=True) + if index_name_is_none: + keys_df.index.name = None return _perform_loc_list_join(series_or_dataframe, keys_df) elif isinstance(key, slice): if (key.start is None) and (key.stop is None) and (key.step is None): diff --git a/tests/system/small/test_dataframe.py b/tests/system/small/test_dataframe.py index ecafb7c1bf..309e8df4f0 100644 --- a/tests/system/small/test_dataframe.py +++ b/tests/system/small/test_dataframe.py @@ -2787,6 +2787,22 @@ def test_loc_list_integer_index(scalars_df_index, scalars_pandas_df_index): ) +def test_loc_list_multiindex(scalars_df_index, scalars_pandas_df_index): + scalars_df_multiindex = scalars_df_index.set_index(["string_col", "int64_col"]) + scalars_pandas_df_multiindex = scalars_pandas_df_index.set_index( + ["string_col", "int64_col"] + ) + index_list = [("Hello, World!", -234892), ("Hello, World!", 123456789)] + + bf_result = scalars_df_multiindex.loc[index_list] + pd_result = scalars_pandas_df_multiindex.loc[index_list] + + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + def test_iloc_list(scalars_df_index, scalars_pandas_df_index): index_list = [0, 0, 0, 5, 4, 7] @@ -2863,6 +2879,24 @@ def test_loc_bf_series_string_index(scalars_df_index, scalars_pandas_df_index): ) +def test_loc_bf_series_multiindex(scalars_df_index, scalars_pandas_df_index): + pd_string_series = scalars_pandas_df_index.string_col.iloc[[0, 5, 1, 1, 5]] + bf_string_series = scalars_df_index.string_col.iloc[[0, 5, 1, 1, 5]] + + scalars_df_multiindex = scalars_df_index.set_index(["string_col", "int64_col"]) + scalars_pandas_df_multiindex = scalars_pandas_df_index.set_index( + ["string_col", "int64_col"] + ) + + bf_result = scalars_df_multiindex.loc[bf_string_series] + pd_result = scalars_pandas_df_multiindex.loc[pd_string_series] + + pd.testing.assert_frame_equal( + bf_result.to_pandas(), + pd_result, + ) + + def test_loc_bf_index_integer_index(scalars_df_index, scalars_pandas_df_index): pd_index = scalars_pandas_df_index.iloc[[0, 5, 1, 1, 5]].index bf_index = scalars_df_index.iloc[[0, 5, 1, 1, 5]].index diff --git a/tests/system/small/test_series.py b/tests/system/small/test_series.py index f2ced841da..bd9edbb1ca 100644 --- a/tests/system/small/test_series.py +++ b/tests/system/small/test_series.py @@ -2624,6 +2624,22 @@ def test_loc_list_integer_index(scalars_df_index, scalars_pandas_df_index): ) +def test_loc_list_multiindex(scalars_df_index, scalars_pandas_df_index): + scalars_df_multiindex = scalars_df_index.set_index(["string_col", "int64_col"]) + scalars_pandas_df_multiindex = scalars_pandas_df_index.set_index( + ["string_col", "int64_col"] + ) + index_list = [("Hello, World!", -234892), ("Hello, World!", 123456789)] + + bf_result = scalars_df_multiindex.int64_too.loc[index_list] + pd_result = scalars_pandas_df_multiindex.int64_too.loc[index_list] + + pd.testing.assert_series_equal( + bf_result.to_pandas(), + pd_result, + ) + + def test_iloc_list(scalars_df_index, scalars_pandas_df_index): index_list = [0, 0, 0, 5, 4, 7] @@ -2681,6 +2697,24 @@ def test_loc_bf_series_string_index(scalars_df_index, scalars_pandas_df_index): ) +def test_loc_bf_series_multiindex(scalars_df_index, scalars_pandas_df_index): + pd_string_series = scalars_pandas_df_index.string_col.iloc[[0, 5, 1, 1, 5]] + bf_string_series = scalars_df_index.string_col.iloc[[0, 5, 1, 1, 5]] + + scalars_df_multiindex = scalars_df_index.set_index(["string_col", "int64_col"]) + scalars_pandas_df_multiindex = scalars_pandas_df_index.set_index( + ["string_col", "int64_col"] + ) + + bf_result = scalars_df_multiindex.int64_too.loc[bf_string_series] + pd_result = scalars_pandas_df_multiindex.int64_too.loc[pd_string_series] + + pd.testing.assert_series_equal( + bf_result.to_pandas(), + pd_result, + ) + + def test_loc_bf_index_integer_index(scalars_df_index, scalars_pandas_df_index): pd_index = scalars_pandas_df_index.iloc[[0, 5, 1, 1, 5]].index bf_index = scalars_df_index.iloc[[0, 5, 1, 1, 5]].index