From 85936d3b2d23b25215a5fc5f53ccd43d9cb7c370 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 20 Aug 2025 00:07:30 +0000 Subject: [PATCH 1/2] chore: implement eq, eq_null_match, ne compilers --- .../sqlglot/expressions/binary_compiler.py | 67 +++++++++++++------ .../test_eq_null_match/out.sql | 14 ++++ .../test_eq_numeric/out.sql | 54 +++++++++++++++ .../test_ne_numeric/out.sql | 54 +++++++++++++++ .../expressions/test_binary_compiler.py | 32 ++++++++- 5 files changed, 198 insertions(+), 23 deletions(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_null_match/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_numeric/out.sql create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_ne_numeric/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py index 61f1eba607..a3a0fffcbc 100644 --- a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py @@ -38,12 +38,7 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.Concat(expressions=[left.expr, right.expr]) if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): - left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE: - left_expr = sge.Cast(this=left_expr, to="INT64") - right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE: - right_expr = sge.Cast(this=right_expr, to="INT64") + left_expr, right_expr = _coerce_bools(left, right) return sge.Add(this=left_expr, expression=right_expr) if ( @@ -74,15 +69,36 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: ) -@BINARY_OP_REGISTRATION.register(ops.div_op) +@BINARY_OP_REGISTRATION.register(ops.eq_op) +def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr, right_expr = _coerce_bools(left, right) + return sge.EQ(this=left_expr, expression=right_expr) + + +@BINARY_OP_REGISTRATION.register(ops.eq_null_match_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE: + if left.dtype == dtypes.BOOL_DTYPE and right.dtype != dtypes.BOOL_DTYPE: left_expr = sge.Cast(this=left_expr, to="INT64") + right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE: + if right.dtype == dtypes.BOOL_DTYPE and left.dtype != dtypes.BOOL_DTYPE: right_expr = sge.Cast(this=right_expr, to="INT64") + sentinel = sge.convert("$NULL_SENTINEL$") + left_coalesce = sge.Coalesce( + this=sge.Cast(this=left_expr, to="STRING"), expressions=[sentinel] + ) + right_coalesce = sge.Coalesce( + this=sge.Cast(this=right_expr, to="STRING"), expressions=[sentinel] + ) + return sge.EQ(this=left_coalesce, expression=right_coalesce) + + +@BINARY_OP_REGISTRATION.register(ops.div_op) +def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr, right_expr = _coerce_bools(left, right) + result = sge.func("IEEE_DIVIDE", left_expr, right_expr) if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype): return sge.Cast(this=sge.Floor(this=result), to="INT64") @@ -139,12 +155,7 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: @BINARY_OP_REGISTRATION.register(ops.mul_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE: - left_expr = sge.Cast(this=left_expr, to="INT64") - right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE: - right_expr = sge.Cast(this=right_expr, to="INT64") + left_expr, right_expr = _coerce_bools(left, right) result = sge.Mul(this=left_expr, expression=right_expr) @@ -156,15 +167,16 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return result +@BINARY_OP_REGISTRATION.register(ops.ne_op) +def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr, right_expr = _coerce_bools(left, right) + return sge.NEQ(this=left_expr, expression=right_expr) + + @BINARY_OP_REGISTRATION.register(ops.sub_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): - left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE: - left_expr = sge.Cast(this=left_expr, to="INT64") - right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE: - right_expr = sge.Cast(this=right_expr, to="INT64") + left_expr, right_expr = _coerce_bools(left, right) return sge.Sub(this=left_expr, expression=right_expr) if ( @@ -201,3 +213,16 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: @BINARY_OP_REGISTRATION.register(ops.obj_make_ref_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.func("OBJ.MAKE_REF", left.expr, right.expr) + + +def _coerce_bools( + left: TypedExpr, right: TypedExpr +) -> tuple[sge.Expression, sge.Expression]: + """Coerce boolean expressions to INT64 for binary operations.""" + left_expr = left.expr + if left.dtype == dtypes.BOOL_DTYPE: + left_expr = sge.Cast(this=left_expr, to="INT64") + right_expr = right.expr + if right.dtype == dtypes.BOOL_DTYPE: + right_expr = sge.Cast(this=right_expr, to="INT64") + return left_expr, right_expr diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_null_match/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_null_match/out.sql new file mode 100644 index 0000000000..90cbcfe5c7 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_null_match/out.sql @@ -0,0 +1,14 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + COALESCE(CAST(`bfcol_1` AS STRING), '$NULL_SENTINEL$') = COALESCE(CAST(CAST(`bfcol_0` AS INT64) AS STRING), '$NULL_SENTINEL$') AS `bfcol_4` + FROM `bfcte_0` +) +SELECT + `bfcol_4` AS `int64_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_numeric/out.sql new file mode 100644 index 0000000000..8e3c52310d --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_numeric/out.sql @@ -0,0 +1,54 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `rowindex` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_2` AS `bfcol_6`, + `bfcol_1` AS `bfcol_7`, + `bfcol_0` AS `bfcol_8`, + `bfcol_1` = `bfcol_1` AS `bfcol_9` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + `bfcol_6` AS `bfcol_14`, + `bfcol_7` AS `bfcol_15`, + `bfcol_8` AS `bfcol_16`, + `bfcol_9` AS `bfcol_17`, + `bfcol_7` = 1 AS `bfcol_18` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + *, + `bfcol_14` AS `bfcol_24`, + `bfcol_15` AS `bfcol_25`, + `bfcol_16` AS `bfcol_26`, + `bfcol_17` AS `bfcol_27`, + `bfcol_18` AS `bfcol_28`, + `bfcol_15` = CAST(`bfcol_16` AS INT64) AS `bfcol_29` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + *, + `bfcol_24` AS `bfcol_36`, + `bfcol_25` AS `bfcol_37`, + `bfcol_26` AS `bfcol_38`, + `bfcol_27` AS `bfcol_39`, + `bfcol_28` AS `bfcol_40`, + `bfcol_29` AS `bfcol_41`, + CAST(`bfcol_26` AS INT64) = `bfcol_25` AS `bfcol_42` + FROM `bfcte_3` +) +SELECT + `bfcol_36` AS `rowindex`, + `bfcol_37` AS `int64_col`, + `bfcol_38` AS `bool_col`, + `bfcol_39` AS `int_ne_int`, + `bfcol_40` AS `int_ne_1`, + `bfcol_41` AS `int_ne_bool`, + `bfcol_42` AS `bool_ne_int` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_ne_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_ne_numeric/out.sql new file mode 100644 index 0000000000..6fba4b960f --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_ne_numeric/out.sql @@ -0,0 +1,54 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `rowindex` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_2` AS `bfcol_6`, + `bfcol_1` AS `bfcol_7`, + `bfcol_0` AS `bfcol_8`, + `bfcol_1` <> `bfcol_1` AS `bfcol_9` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + `bfcol_6` AS `bfcol_14`, + `bfcol_7` AS `bfcol_15`, + `bfcol_8` AS `bfcol_16`, + `bfcol_9` AS `bfcol_17`, + `bfcol_7` <> 1 AS `bfcol_18` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + *, + `bfcol_14` AS `bfcol_24`, + `bfcol_15` AS `bfcol_25`, + `bfcol_16` AS `bfcol_26`, + `bfcol_17` AS `bfcol_27`, + `bfcol_18` AS `bfcol_28`, + `bfcol_15` <> CAST(`bfcol_16` AS INT64) AS `bfcol_29` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + *, + `bfcol_24` AS `bfcol_36`, + `bfcol_25` AS `bfcol_37`, + `bfcol_26` AS `bfcol_38`, + `bfcol_27` AS `bfcol_39`, + `bfcol_28` AS `bfcol_40`, + `bfcol_29` AS `bfcol_41`, + CAST(`bfcol_26` AS INT64) <> `bfcol_25` AS `bfcol_42` + FROM `bfcte_3` +) +SELECT + `bfcol_36` AS `rowindex`, + `bfcol_37` AS `int64_col`, + `bfcol_38` AS `bool_col`, + `bfcol_39` AS `int_ne_int`, + `bfcol_40` AS `int_ne_1`, + `bfcol_41` AS `int_ne_bool`, + `bfcol_42` AS `bool_ne_int` +FROM `bfcte_4` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py index 49426fe6c3..11586cad02 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py @@ -107,6 +107,24 @@ def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(bf_df.sql, "out.sql") +def test_eq_null_match(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + sql = _apply_binary_op(bf_df, ops.eq_null_match_op, "int64_col", "bool_col") + snapshot.assert_match(sql, "out.sql") + + +def test_eq_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_ne_int"] = bf_df["int64_col"] == bf_df["int64_col"] + bf_df["int_ne_1"] = bf_df["int64_col"] == 1 + + bf_df["int_ne_bool"] = bf_df["int64_col"] == bf_df["bool_col"] + bf_df["bool_ne_int"] = bf_df["bool_col"] == bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col"]] @@ -121,8 +139,6 @@ def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot): bf_df["int_div_bool"] = bf_df["int64_col"] // bf_df["bool_col"] bf_df["bool_div_int"] = bf_df["bool_col"] // bf_df["int64_col"] - snapshot.assert_match(bf_df.sql, "out.sql") - def test_floordiv_timedelta(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["timestamp_col", "date_col"]] @@ -200,3 +216,15 @@ def test_mul_timedelta(scalar_types_df: bpd.DataFrame, snapshot): def test_obj_make_ref(scalar_types_df: bpd.DataFrame, snapshot): blob_df = scalar_types_df["string_col"].str.to_blob() snapshot.assert_match(blob_df.to_frame().sql, "out.sql") + + +def test_ne_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_ne_int"] = bf_df["int64_col"] != bf_df["int64_col"] + bf_df["int_ne_1"] = bf_df["int64_col"] != 1 + + bf_df["int_ne_bool"] = bf_df["int64_col"] != bf_df["bool_col"] + bf_df["bool_ne_int"] = bf_df["bool_col"] != bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") From 0950b015cef519fac2f0d5cdcae0067b4b352afb Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Thu, 21 Aug 2025 00:03:22 +0000 Subject: [PATCH 2/2] address comments --- .../sqlglot/expressions/binary_compiler.py | 77 +++++++++---------- 1 file changed, 35 insertions(+), 42 deletions(-) diff --git a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py index a3a0fffcbc..84e783bb66 100644 --- a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py @@ -38,16 +38,15 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.Concat(expressions=[left.expr, right.expr]) if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): - left_expr, right_expr = _coerce_bools(left, right) + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) return sge.Add(this=left_expr, expression=right_expr) if ( dtypes.is_time_or_date_like(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE ): - left_expr = left.expr - if left.dtype == dtypes.DATE_DTYPE: - left_expr = sge.Cast(this=left_expr, to="DATETIME") + left_expr = _coerce_date_to_datetime(left) return sge.TimestampAdd( this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND") ) @@ -55,9 +54,7 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: dtypes.is_time_or_date_like(right.dtype) and left.dtype == dtypes.TIMEDELTA_DTYPE ): - right_expr = right.expr - if right.dtype == dtypes.DATE_DTYPE: - right_expr = sge.Cast(this=right_expr, to="DATETIME") + right_expr = _coerce_date_to_datetime(right) return sge.TimestampAdd( this=right_expr, expression=left.expr, unit=sge.Var(this="MICROSECOND") ) @@ -71,19 +68,20 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: @BINARY_OP_REGISTRATION.register(ops.eq_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr, right_expr = _coerce_bools(left, right) + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) return sge.EQ(this=left_expr, expression=right_expr) @BINARY_OP_REGISTRATION.register(ops.eq_null_match_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE and right.dtype != dtypes.BOOL_DTYPE: - left_expr = sge.Cast(this=left_expr, to="INT64") + if right.dtype != dtypes.BOOL_DTYPE: + left_expr = _coerce_bool_to_int(left) right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE and left.dtype != dtypes.BOOL_DTYPE: - right_expr = sge.Cast(this=right_expr, to="INT64") + if left.dtype != dtypes.BOOL_DTYPE: + right_expr = _coerce_bool_to_int(right) sentinel = sge.convert("$NULL_SENTINEL$") left_coalesce = sge.Coalesce( @@ -97,7 +95,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: @BINARY_OP_REGISTRATION.register(ops.div_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr, right_expr = _coerce_bools(left, right) + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) result = sge.func("IEEE_DIVIDE", left_expr, right_expr) if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype): @@ -108,12 +107,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: @BINARY_OP_REGISTRATION.register(ops.floordiv_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE: - left_expr = sge.Cast(this=left_expr, to="INT64") - right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE: - right_expr = sge.Cast(this=right_expr, to="INT64") + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) result: sge.Expression = sge.Cast( this=sge.Floor(this=sge.func("IEEE_DIVIDE", left_expr, right_expr)), to="INT64" @@ -155,7 +150,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: @BINARY_OP_REGISTRATION.register(ops.mul_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr, right_expr = _coerce_bools(left, right) + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) result = sge.Mul(this=left_expr, expression=right_expr) @@ -169,35 +165,31 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: @BINARY_OP_REGISTRATION.register(ops.ne_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr, right_expr = _coerce_bools(left, right) + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) return sge.NEQ(this=left_expr, expression=right_expr) @BINARY_OP_REGISTRATION.register(ops.sub_op) def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): - left_expr, right_expr = _coerce_bools(left, right) + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) return sge.Sub(this=left_expr, expression=right_expr) if ( dtypes.is_time_or_date_like(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE ): - left_expr = left.expr - if left.dtype == dtypes.DATE_DTYPE: - left_expr = sge.Cast(this=left_expr, to="DATETIME") + left_expr = _coerce_date_to_datetime(left) return sge.TimestampSub( this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND") ) if dtypes.is_time_or_date_like(left.dtype) and dtypes.is_time_or_date_like( right.dtype ): - left_expr = left.expr - if left.dtype == dtypes.DATE_DTYPE: - left_expr = sge.Cast(this=left_expr, to="DATETIME") - right_expr = right.expr - if right.dtype == dtypes.DATE_DTYPE: - right_expr = sge.Cast(this=right_expr, to="DATETIME") + left_expr = _coerce_date_to_datetime(left) + right_expr = _coerce_date_to_datetime(right) return sge.TimestampDiff( this=left_expr, expression=right_expr, unit=sge.Var(this="MICROSECOND") ) @@ -215,14 +207,15 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression: return sge.func("OBJ.MAKE_REF", left.expr, right.expr) -def _coerce_bools( - left: TypedExpr, right: TypedExpr -) -> tuple[sge.Expression, sge.Expression]: - """Coerce boolean expressions to INT64 for binary operations.""" - left_expr = left.expr - if left.dtype == dtypes.BOOL_DTYPE: - left_expr = sge.Cast(this=left_expr, to="INT64") - right_expr = right.expr - if right.dtype == dtypes.BOOL_DTYPE: - right_expr = sge.Cast(this=right_expr, to="INT64") - return left_expr, right_expr +def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression: + """Coerce boolean expression to integer.""" + if typed_expr.dtype == dtypes.BOOL_DTYPE: + return sge.Cast(this=typed_expr.expr, to="INT64") + return typed_expr.expr + + +def _coerce_date_to_datetime(typed_expr: TypedExpr) -> sge.Expression: + """Coerce date expression to datetime.""" + if typed_expr.dtype == dtypes.DATE_DTYPE: + return sge.Cast(this=typed_expr.expr, to="DATETIME") + return typed_expr.expr