diff --git a/bigframes/bigquery/__init__.py b/bigframes/bigquery/__init__.py index 303120b88a..28a818e709 100644 --- a/bigframes/bigquery/__init__.py +++ b/bigframes/bigquery/__init__.py @@ -272,6 +272,46 @@ def json_extract_array( return series._apply_unary_op(ops.JSONExtractArray(json_path=json_path)) +# Approximate aggrgate functions defined from +# https://cloud.google.com/bigquery/docs/reference/standard-sql/approximate_aggregate_functions + + +def approx_top_count( + series: series.Series, + number: int, +) -> series.Series: + """Returns the approximate top elements of `expression` as an array of STRUCTs. + The number parameter specifies the number of elements returned. + + Each `STRUCT` contains two fields. The first field (named `value`) contains an input + value. The second field (named `count`) contains an `INT64` specifying the number + of times the value was returned. + + Returns `NULL` if there are zero input rows. + + **Examples:** + + >>> import bigframes.pandas as bpd + >>> import bigframes.bigquery as bbq + >>> bpd.options.display.progress_bar = None + >>> s = bpd.Series(["apple", "apple", "pear", "pear", "pear", "banana"]) + >>> bbq.approx_top_count(s, number=2) + [{'value': 'pear', 'count': 3}, {'value': 'apple', 'count': 2}] + + Args: + series (bigframes.series.Series): + The Series with any data type that the `GROUP BY` clause supports. + number (int): + An integer specifying the number of times the value was returned. + + Returns: + bigframes.series.Series: A new Series with the result data. + """ + if number < 1: + raise ValueError("The number of approx_top_count must be at least 1") + return series._apply_aggregation(agg_ops.ApproxTopCountOp(number=number)) + + def struct(value: dataframe.DataFrame) -> series.Series: """Takes a DataFrame and converts it into a Series of structs with each struct entry corresponding to a DataFrame row and each struct field diff --git a/bigframes/core/compile/aggregate_compiler.py b/bigframes/core/compile/aggregate_compiler.py index 91a3045efb..b65953934d 100644 --- a/bigframes/core/compile/aggregate_compiler.py +++ b/bigframes/core/compile/aggregate_compiler.py @@ -11,6 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from __future__ import annotations + import functools import typing from typing import cast, List, Optional @@ -19,6 +22,7 @@ import bigframes_vendored.ibis.expr.operations as vendored_ibis_ops import ibis import ibis.expr.datatypes as ibis_dtypes +import ibis.expr.operations as ibis_ops import ibis.expr.types as ibis_types import pandas as pd @@ -196,6 +200,34 @@ def _( return cast(ibis_types.NumericValue, value) +@compile_unary_agg.register +def _( + op: agg_ops.ApproxTopCountOp, + column: ibis_types.Column, + window=None, +) -> ibis_types.ArrayColumn: + # APPROX_TOP_COUNT has very few allowed windows. + if window is not None: + raise NotImplementedError( + f"Approx top count with windowing is not supported. {constants.FEEDBACK_LINK}" + ) + + # Define a user-defined function (UDF) that approximates the top counts of an expression. + # The type of value is dynamically matching the input column. + def approx_top_count(expression, number: ibis_dtypes.int64): # type: ignore + ... + + return_type = ibis_dtypes.Array( + ibis_dtypes.Struct.from_tuples( + [("value", column.type()), ("count", ibis_dtypes.int64)] + ) + ) + approx_top_count.__annotations__["return"] = return_type + udf_op = ibis_ops.udf.agg.builtin(approx_top_count) + + return udf_op(expression=column, number=op.number) + + @compile_unary_agg.register @numeric_op def _( diff --git a/bigframes/operations/aggregations.py b/bigframes/operations/aggregations.py index d071889ac4..faba7465d9 100644 --- a/bigframes/operations/aggregations.py +++ b/bigframes/operations/aggregations.py @@ -184,6 +184,23 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT return input_types[0] +@dataclasses.dataclass(frozen=True) +class ApproxTopCountOp(UnaryAggregateOp): + name: typing.ClassVar[str] = "approx_top_count" + number: int + + def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType: + if not dtypes.is_orderable(input_types[0]): + raise TypeError(f"Type {input_types[0]} is not orderable") + + input_type = input_types[0] + fields = [ + pa.field("value", dtypes.bigframes_dtype_to_arrow_dtype(input_type)), + pa.field("count", pa.int64()), + ] + return pd.ArrowDtype(pa.list_(pa.struct(fields))) + + @dataclasses.dataclass(frozen=True) class MeanOp(UnaryAggregateOp): name: ClassVar[str] = "mean" diff --git a/tests/system/small/bigquery/test_approx_agg.py b/tests/system/small/bigquery/test_approx_agg.py new file mode 100644 index 0000000000..c88f5850f8 --- /dev/null +++ b/tests/system/small/bigquery/test_approx_agg.py @@ -0,0 +1,76 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import bigframes.bigquery as bbq +import bigframes.pandas as bpd + + +@pytest.mark.parametrize( + ("data", "expected"), + [ + pytest.param( + [1, 2, 3, 3, 2], [{"value": 3, "count": 2}, {"value": 2, "count": 2}] + ), + pytest.param( + ["apple", "apple", "pear", "pear", "pear", "banana"], + [{"value": "pear", "count": 3}, {"value": "apple", "count": 2}], + ), + pytest.param( + [True, False, True, False, True], + [{"value": True, "count": 3}, {"value": False, "count": 2}], + ), + pytest.param( + [], + [], + ), + pytest.param( + [[1, 2], [1], [1, 2]], + [], + marks=pytest.mark.xfail(raises=TypeError), + ), + ], + ids=["int64", "string", "bool", "null", "array"], +) +def test_approx_top_count_w_dtypes(data, expected): + s = bpd.Series(data) + result = bbq.approx_top_count(s, number=2) + assert result == expected + + +@pytest.mark.parametrize( + ("number", "expected"), + [ + pytest.param( + 0, + [], + marks=pytest.mark.xfail(raises=ValueError), + ), + pytest.param(1, [{"value": 3, "count": 2}]), + pytest.param( + 4, + [ + {"value": 3, "count": 2}, + {"value": 2, "count": 2}, + {"value": 1, "count": 1}, + ], + ), + ], + ids=["zero", "one", "full"], +) +def test_approx_top_count_w_numbers(number, expected): + s = bpd.Series([1, 2, 3, 3, 2]) + result = bbq.approx_top_count(s, number=number) + assert result == expected