Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions bigframes/core/compile/sqlglot/expressions/unary_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import bigframes.core.compile.sqlglot.expressions.constants as constants
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
import bigframes.dtypes as dtypes

UNARY_OP_REGISTRATION = OpRegistration()

Expand Down Expand Up @@ -420,7 +421,28 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:

@UNARY_OP_REGISTRATION.register(ops.IsInOp)
def _(op: ops.IsInOp, expr: TypedExpr) -> sge.Expression:
    """Compiles IsInOp into a SQL ``IN`` expression.

    Only values type-compatible with the column are kept in the IN list
    (numeric values are comparable across numeric column dtypes); values
    that can never match are dropped. Unless ``match_nulls`` applies, the
    result is wrapped in COALESCE so a NULL input yields FALSE, not NULL.
    """
    is_numeric_expr = dtypes.is_numeric(expr.dtype)
    values = []
    for value in op.values:
        if value is None:
            continue
        dtype = dtypes.bigframes_type(type(value))
        # Keep exact dtype matches; for numeric columns any numeric value
        # is still comparable (e.g. INT64 column vs FLOAT64 literal).
        if expr.dtype == dtype or (is_numeric_expr and dtypes.is_numeric(dtype)):
            values.append(sge.convert(value))

    if op.match_nulls and any(_is_null(value) for value in op.values):
        is_null = sge.Is(this=expr.expr, expression=sge.Null())
        if not values:
            # All candidate values were null; avoid emitting an invalid
            # empty `IN ()` list and just test for NULL.
            return is_null
        return is_null | sge.In(this=expr.expr, expressions=values)

    if len(values) == 0:
        # Nothing can match; IN () is invalid SQL, so emit a constant.
        return sge.convert(False)

    return sge.func(
        "COALESCE", sge.In(this=expr.expr, expressions=values), sge.convert(False)
    )


@UNARY_OP_REGISTRATION.register(ops.isalnum_op)
Expand Down Expand Up @@ -767,7 +789,7 @@ def _(op: ops.ToTimedeltaOp, expr: TypedExpr) -> sge.Expression:
factor = UNIT_TO_US_CONVERSION_FACTORS[op.unit]
if factor != 1:
value = sge.Mul(this=value, expression=sge.convert(factor))
return sge.Interval(this=value, unit=sge.Identifier(this="MICROSECOND"))
return value


@UNARY_OP_REGISTRATION.register(ops.UnixMicros)
Expand Down Expand Up @@ -866,3 +888,9 @@ def _(op: ops.ZfillOp, expr: TypedExpr) -> sge.Expression:
],
default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")),
)


# Helpers
def _is_null(value) -> bool:
    """Return True only for 'true' null values (None, pd.NA, NaT, ...).

    Float NaN/inf are deliberately excluded: they are treated as distinct
    from genuine nulls for IN-list null matching.
    """
    if isinstance(value, float):
        return False
    return bool(pd.isna(value))
2 changes: 1 addition & 1 deletion tests/system/small/engines/test_generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine):
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)


@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_isin_op(scalars_array_value: array_value.ArrayValue, engine):
arr, col_ids = scalars_array_value.compute_values(
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ WITH `bfcte_0` AS (
`bfcol_1` AS `bfcol_8`,
`bfcol_2` AS `bfcol_9`,
`bfcol_0` AS `bfcol_10`,
INTERVAL `bfcol_3` MICROSECOND AS `bfcol_11`
`bfcol_3` AS `bfcol_11`
FROM `bfcte_0`
), `bfcte_2` AS (
SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ WITH `bfcte_0` AS (
`bfcol_1` AS `bfcol_8`,
`bfcol_2` AS `bfcol_9`,
`bfcol_0` AS `bfcol_10`,
INTERVAL `bfcol_3` MICROSECOND AS `bfcol_11`
`bfcol_3` AS `bfcol_11`
FROM `bfcte_0`
), `bfcte_2` AS (
SELECT
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,32 @@
WITH `bfcte_0` AS (
SELECT
`int64_col` AS `bfcol_0`
`int64_col` AS `bfcol_0`,
`float64_col` AS `bfcol_1`
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
), `bfcte_1` AS (
SELECT
*,
`bfcol_0` IN (1, 2, 3) AS `bfcol_1`
COALESCE(`bfcol_0` IN (1, 2, 3), FALSE) AS `bfcol_2`,
(
`bfcol_0` IS NULL
) OR `bfcol_0` IN (123456) AS `bfcol_3`,
COALESCE(`bfcol_0` IN (1.0, 2.0, 3.0), FALSE) AS `bfcol_4`,
FALSE AS `bfcol_5`,
COALESCE(`bfcol_0` IN (2.5, 3), FALSE) AS `bfcol_6`,
FALSE AS `bfcol_7`,
COALESCE(`bfcol_0` IN (123456), FALSE) AS `bfcol_8`,
(
`bfcol_1` IS NULL
) OR `bfcol_1` IN (1, 2, 3) AS `bfcol_9`
FROM `bfcte_0`
)
SELECT
`bfcol_1` AS `int64_col`
`bfcol_2` AS `ints`,
`bfcol_3` AS `ints_w_null`,
`bfcol_4` AS `floats`,
`bfcol_5` AS `strings`,
`bfcol_6` AS `mixed`,
`bfcol_7` AS `empty`,
`bfcol_8` AS `ints_wo_match_nulls`,
`bfcol_9` AS `float_in_ints`
FROM `bfcte_1`
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ WITH `bfcte_0` AS (
*,
`bfcol_1` AS `bfcol_4`,
`bfcol_0` AS `bfcol_5`,
INTERVAL `bfcol_0` MICROSECOND AS `bfcol_6`
`bfcol_0` AS `bfcol_6`
FROM `bfcte_0`
), `bfcte_2` AS (
SELECT
*,
`bfcol_4` AS `bfcol_10`,
`bfcol_5` AS `bfcol_11`,
`bfcol_6` AS `bfcol_12`,
INTERVAL (`bfcol_5` * 1000000) MICROSECOND AS `bfcol_13`
`bfcol_5` * 1000000 AS `bfcol_13`
FROM `bfcte_1`
), `bfcte_3` AS (
SELECT
Expand All @@ -25,7 +25,7 @@ WITH `bfcte_0` AS (
`bfcol_11` AS `bfcol_19`,
`bfcol_12` AS `bfcol_20`,
`bfcol_13` AS `bfcol_21`,
INTERVAL (`bfcol_11` * 604800000000) MICROSECOND AS `bfcol_22`
`bfcol_11` * 604800000000 AS `bfcol_22`
FROM `bfcte_2`
)
SELECT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -370,12 +370,25 @@ def test_invert(scalar_types_df: bpd.DataFrame, snapshot):


def test_is_in(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot-tests the SQL compiled for IsInOp across value/dtype combos."""
    int_col, float_col = "int64_col", "float64_col"
    frame = scalar_types_df[[int_col, float_col]]
    # Ordered (output name, op expression) pairs; order matters for the snapshot.
    cases = [
        ("ints", ops.IsInOp(values=(1, 2, 3)).as_expr(int_col)),
        ("ints_w_null", ops.IsInOp(values=(None, 123456)).as_expr(int_col)),
        (
            "floats",
            ops.IsInOp(values=(1.0, 2.0, 3.0), match_nulls=False).as_expr(int_col),
        ),
        ("strings", ops.IsInOp(values=("1.0", "2.0")).as_expr(int_col)),
        ("mixed", ops.IsInOp(values=("1.0", 2.5, 3)).as_expr(int_col)),
        ("empty", ops.IsInOp(values=()).as_expr(int_col)),
        (
            "ints_wo_match_nulls",
            ops.IsInOp(values=(None, 123456), match_nulls=False).as_expr(int_col),
        ),
        ("float_in_ints", ops.IsInOp(values=(1, 2, 3, None)).as_expr(float_col)),
    ]

    names = [name for name, _ in cases]
    exprs = [expr for _, expr in cases]
    sql = _apply_unary_ops(frame, exprs, names)
    snapshot.assert_match(sql, "out.sql")


Expand Down