Skip to content

feat: add on parameter in dataframe.rolling() and dataframe.groupby.rolling() #1556

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bigframes/core/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,7 @@ def apply_nary_op(
def multi_apply_window_op(
self,
columns: typing.Sequence[str],
op: agg_ops.WindowOp,
op: agg_ops.UnaryWindowOp,
window_spec: windows.WindowSpec,
*,
skip_null_groups: bool = False,
Expand Down Expand Up @@ -1058,7 +1058,7 @@ def project_exprs(
def apply_window_op(
self,
column: str,
op: agg_ops.WindowOp,
op: agg_ops.UnaryWindowOp,
window_spec: windows.WindowSpec,
*,
result_label: Label = None,
Expand Down
12 changes: 10 additions & 2 deletions bigframes/core/groupby/dataframe_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ def rolling(
self,
window: int,
min_periods=None,
on: str | None = None,
closed: Literal["right", "left", "both", "neither"] = "right",
) -> windows.Window:
window_spec = window_specs.WindowSpec(
Expand All @@ -320,8 +321,15 @@ def rolling(
block = self._block.order_by(
[order.ascending_over(col) for col in self._by_col_ids],
)
skip_agg_col_id = (
None if on is None else self._block.resolve_label_exact_or_error(on)
)
return windows.Window(
block, window_spec, self._selected_cols, drop_null_groups=self._dropna
block,
window_spec,
self._selected_cols,
drop_null_groups=self._dropna,
skip_agg_column_id=skip_agg_col_id,
)

@validations.requires_ordering()
Expand Down Expand Up @@ -511,7 +519,7 @@ def _aggregate_all(

def _apply_window_op(
self,
op: agg_ops.WindowOp,
op: agg_ops.UnaryWindowOp,
window: typing.Optional[window_specs.WindowSpec] = None,
numeric_only: bool = False,
):
Expand Down
2 changes: 1 addition & 1 deletion bigframes/core/groupby/series_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def _aggregate(self, aggregate_op: agg_ops.UnaryAggregateOp) -> series.Series:

def _apply_window_op(
self,
op: agg_ops.WindowOp,
op: agg_ops.UnaryWindowOp,
discard_name=False,
window: typing.Optional[window_specs.WindowSpec] = None,
never_skip_nulls: bool = False,
Expand Down
51 changes: 37 additions & 14 deletions bigframes/core/window/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,16 @@ def __init__(
value_column_ids: typing.Sequence[str],
drop_null_groups: bool = True,
is_series: bool = False,
skip_agg_column_id: str | None = None,
):
self._block = block
self._window_spec = window_spec
self._value_column_ids = value_column_ids
self._drop_null_groups = drop_null_groups
self._is_series = is_series
# The column ID that won't be aggregated on.
# This is equivalent to pandas `on` parameter in rolling()
self._skip_agg_column_id = skip_agg_column_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will eventually be a set in order to handle multiple columns for on?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would be unlikely. Both pandas and SQL allow only one column to be used for on (or its equivalent)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, maybe not for on, but for grouping columns. We diverge a bit right now I think for some grouping cases. Anyways, can always change later if need be.


def count(self):
    # Rolling count of non-null values, routed through the shared
    # window-aggregation path used by all the other rolling reducers.
    return self._apply_aggregate(agg_ops.count_op)
Expand All @@ -66,10 +70,37 @@ def _apply_aggregate(
self,
op: agg_ops.UnaryAggregateOp,
):
block = self._block
labels = [block.col_id_to_label[col] for col in self._value_column_ids]
block, result_ids = block.multi_apply_window_op(
self._value_column_ids,
agg_col_ids = [
col_id
for col_id in self._value_column_ids
if col_id != self._skip_agg_column_id
]
agg_block = self._aggregate_block(op, agg_col_ids)

if self._skip_agg_column_id is not None:
# Concat the skipped column to the result.
agg_block, _ = agg_block.join(
self._block.select_column(self._skip_agg_column_id), how="outer"
)

if self._is_series:
from bigframes.series import Series

return Series(agg_block)
else:
from bigframes.dataframe import DataFrame

# Preserve column order.
column_labels = [
self._block.col_id_to_label[col_id] for col_id in self._value_column_ids
]
return DataFrame(agg_block)._reindex_columns(column_labels)

def _aggregate_block(
self, op: agg_ops.UnaryAggregateOp, agg_col_ids: typing.List[str]
) -> blocks.Block:
block, result_ids = self._block.multi_apply_window_op(
agg_col_ids,
op,
self._window_spec,
skip_null_groups=self._drop_null_groups,
Expand All @@ -85,13 +116,5 @@ def _apply_aggregate(
)
block = block.set_index(col_ids=index_ids)

if self._is_series:
from bigframes.series import Series

return Series(block.select_columns(result_ids).with_column_labels(labels))
else:
from bigframes.dataframe import DataFrame

return DataFrame(
block.select_columns(result_ids).with_column_labels(labels)
)
labels = [self._block.col_id_to_label[col] for col in agg_col_ids]
return block.select_columns(result_ids).with_column_labels(labels)
10 changes: 6 additions & 4 deletions bigframes/core/window_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,12 @@ def __post_init__(self):
class WindowSpec:
"""
Specifies a window over which aggregate and analytic function may be applied.
grouping_keys: set of column ids to group on
preceding: Number of preceding rows in the window
following: Number of following rows in the window
ordering: List of columns ids and ordering direction to override base ordering

Attributes:
grouping_keys: A set of column ids to group on
bounds: The window boundaries
ordering: A list of columns ids and ordering direction to override base ordering
min_periods: The minimum number of observations in window required to have a value
"""

grouping_keys: Tuple[ex.DerefOp, ...] = tuple()
Expand Down
11 changes: 9 additions & 2 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3312,14 +3312,21 @@ def rolling(
self,
window: int,
min_periods=None,
on: str | None = None,
closed: Literal["right", "left", "both", "neither"] = "right",
) -> bigframes.core.window.Window:
window_def = windows.WindowSpec(
bounds=windows.RowsWindowBounds.from_window_size(window, closed),
min_periods=min_periods if min_periods is not None else window,
)
skip_agg_col_id = (
None if on is None else self._block.resolve_label_exact_or_error(on)
)
return bigframes.core.window.Window(
self._block, window_def, self._block.value_columns
self._block,
window_def,
self._block.value_columns,
skip_agg_column_id=skip_agg_col_id,
)

@validations.requires_ordering()
Expand Down Expand Up @@ -3483,7 +3490,7 @@ def pct_change(self, periods: int = 1) -> DataFrame:

def _apply_window_op(
self,
op: agg_ops.WindowOp,
op: agg_ops.UnaryWindowOp,
window_spec: windows.WindowSpec,
):
block, result_ids = self._block.multi_apply_window_op(
Expand Down
4 changes: 3 additions & 1 deletion bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,7 +1378,9 @@ def _apply_aggregation(
) -> Any:
return self._block.get_stat(self._value_column, op)

def _apply_window_op(self, op: agg_ops.WindowOp, window_spec: windows.WindowSpec):
def _apply_window_op(
self, op: agg_ops.UnaryWindowOp, window_spec: windows.WindowSpec
):
block = self._block
block, result_id = block.apply_window_op(
self._value_column, op, window_spec=window_spec, result_label=self.name
Expand Down
87 changes: 52 additions & 35 deletions tests/system/small/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@
def rolling_dfs(scalars_dfs):
bf_df, pd_df = scalars_dfs

target_cols = ["int64_too", "float64_col", "bool_col"]
target_cols = ["int64_too", "float64_col", "int64_col"]

bf_df = bf_df[target_cols].set_index("bool_col")
pd_df = pd_df[target_cols].set_index("bool_col")

return bf_df, pd_df
return bf_df[target_cols], pd_df[target_cols]


@pytest.fixture(scope="module")
Expand All @@ -49,31 +46,65 @@ def test_dataframe_rolling_closed_param(rolling_dfs, closed):
@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
def test_dataframe_groupby_rolling_closed_param(rolling_dfs, closed):
bf_df, pd_df = rolling_dfs
# Need to specify column subset for comparison due to b/406841327
check_columns = ["float64_col", "int64_col"]

actual_result = (
bf_df.groupby(level=0).rolling(window=3, closed=closed).sum().to_pandas()
bf_df.groupby(bf_df["int64_too"] % 2)
.rolling(window=3, closed=closed)
.sum()
.to_pandas()
)

expected_result = pd_df.groupby(level=0).rolling(window=3, closed=closed).sum()
pd.testing.assert_frame_equal(actual_result, expected_result, check_dtype=False)
expected_result = (
pd_df.groupby(pd_df["int64_too"] % 2).rolling(window=3, closed=closed).sum()
)
pd.testing.assert_frame_equal(
actual_result[check_columns], expected_result, check_dtype=False
)


def test_dataframe_rolling_default_closed_param(rolling_dfs):
def test_dataframe_rolling_on(rolling_dfs):
bf_df, pd_df = rolling_dfs

actual_result = bf_df.rolling(window=3).sum().to_pandas()
actual_result = bf_df.rolling(window=3, on="int64_too").sum().to_pandas()

expected_result = pd_df.rolling(window=3).sum()
expected_result = pd_df.rolling(window=3, on="int64_too").sum()
pd.testing.assert_frame_equal(actual_result, expected_result, check_dtype=False)


def test_dataframe_groupby_rolling_default_closed_param(rolling_dfs):
def test_dataframe_rolling_on_invalid_column_raise_error(rolling_dfs):
    """rolling(on=...) must raise ValueError when the label matches no column."""
    bf_df, _ = rolling_dfs

    # "whatever" is not a column of the fixture frame, so label resolution
    # for the `on` parameter is expected to fail with ValueError.
    with pytest.raises(ValueError):
        bf_df.rolling(window=3, on="whatever").sum()


def test_dataframe_groupby_rolling_on(rolling_dfs):
bf_df, pd_df = rolling_dfs
# Need to specify column subset for comparison due to b/406841327
check_columns = ["float64_col", "int64_col"]

actual_result = bf_df.groupby(level=0).rolling(window=3).sum().to_pandas()
actual_result = (
bf_df.groupby(bf_df["int64_too"] % 2)
.rolling(window=3, on="float64_col")
.sum()
.to_pandas()
)

expected_result = pd_df.groupby(level=0).rolling(window=3).sum()
pd.testing.assert_frame_equal(actual_result, expected_result, check_dtype=False)
expected_result = (
pd_df.groupby(pd_df["int64_too"] % 2).rolling(window=3, on="float64_col").sum()
)
pd.testing.assert_frame_equal(
actual_result[check_columns], expected_result, check_dtype=False
)


def test_dataframe_groupby_rolling_on_invalid_column_raise_error(rolling_dfs):
    """groupby(...).rolling(on=...) must raise ValueError for an unknown label."""
    bf_df, _ = rolling_dfs

    # Same contract as the non-grouped case: an `on` label that matches no
    # column is expected to surface as ValueError.
    with pytest.raises(ValueError):
        bf_df.groupby(level=0).rolling(window=3, on="whatever").sum()


@pytest.mark.parametrize("closed", ["left", "right", "both", "neither"])
Expand Down Expand Up @@ -103,24 +134,6 @@ def test_series_groupby_rolling_closed_param(rolling_series, closed):
pd.testing.assert_series_equal(actual_result, expected_result, check_dtype=False)


def test_series_rolling_default_closed_param(rolling_series):
bf_series, df_series = rolling_series

actual_result = bf_series.rolling(window=3).sum().to_pandas()

expected_result = df_series.rolling(window=3).sum()
pd.testing.assert_series_equal(actual_result, expected_result, check_dtype=False)


def test_series_groupby_rolling_default_closed_param(rolling_series):
bf_series, df_series = rolling_series

actual_result = bf_series.groupby(bf_series % 2).rolling(window=3).sum().to_pandas()

expected_result = df_series.groupby(df_series % 2).rolling(window=3).sum()
pd.testing.assert_series_equal(actual_result, expected_result, check_dtype=False)


@pytest.mark.parametrize(
("windowing"),
[
Expand Down Expand Up @@ -181,8 +194,12 @@ def test_series_window_agg_ops(rolling_series, windowing, agg_op):
pytest.param(lambda x: x.var(), id="var"),
],
)
def test_dataframe_window_agg_ops(rolling_dfs, windowing, agg_op):
bf_df, pd_df = rolling_dfs
def test_dataframe_window_agg_ops(scalars_dfs, windowing, agg_op):
bf_df, pd_df = scalars_dfs
target_columns = ["int64_too", "float64_col", "bool_col"]
index_column = "bool_col"
bf_df = bf_df[target_columns].set_index(index_column)
pd_df = pd_df[target_columns].set_index(index_column)

bf_result = agg_op(windowing(bf_df)).to_pandas()

Expand Down
Loading