Skip to content

Commit 267d8df

Browse files
feat: optimize CASE WHEN for divide-by-zero protection pattern
- Adds a specialization for the common pattern: CASE WHEN y > 0 THEN x / y ELSE NULL END - Add EvalMethod::DivideByZeroProtection variant - Add pattern detection in find_best_eval_method() - Implement safe_divide using Arrow kernels - Handle CastExpr wrapping on divisor
1 parent e5e7636 commit 267d8df

File tree

1 file changed

+230
-2
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+230
-2
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 230 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717

1818
mod literal_lookup_table;
1919

20-
use super::{Column, Literal};
20+
use super::{CastExpr, Column, Literal};
2121
use crate::PhysicalExpr;
22-
use crate::expressions::{lit, try_cast};
22+
use crate::expressions::{BinaryExpr, lit, try_cast};
2323
use arrow::array::*;
2424
use arrow::compute::kernels::zip::zip;
2525
use arrow::compute::{
@@ -33,6 +33,7 @@ use datafusion_common::{
3333
internal_datafusion_err, internal_err,
3434
};
3535
use datafusion_expr::ColumnarValue;
36+
use datafusion_expr_common::operator::Operator;
3637
use indexmap::{IndexMap, IndexSet};
3738
use std::borrow::Cow;
3839
use std::hash::Hash;
@@ -81,6 +82,14 @@ enum EvalMethod {
8182
///
8283
/// See [`LiteralLookupTable`] for more details
8384
WithExprScalarLookupTable(LiteralLookupTable),
85+
86+
/// This is a specialization for divide-by-zero protection pattern:
87+
/// CASE WHEN y > 0 THEN x / y ELSE NULL END
88+
/// CASE WHEN y != 0 THEN x / y ELSE NULL END
89+
///
90+
/// Instead of evaluating the full CASE expression, it is preferred to directly perform division
91+
/// that return NULL when the divisor is zero.
92+
DivideByZeroProtection,
8493
}
8594

8695
/// Implementing hash so we can use `derive` on [`EvalMethod`].
@@ -647,6 +656,19 @@ impl CaseExpr {
647656
return Ok(EvalMethod::WithExpression(body.project()?));
648657
}
649658

659+
// Check for divide-by-zero protection pattern:
660+
// CASE WHEN y > 0 THEN x / y ELSE NULL END
661+
if body.when_then_expr.len() == 1 && body.else_expr.is_none() {
662+
let (when_expr, then_expr) = &body.when_then_expr[0];
663+
664+
if let Some(checked_operand) = Self::extract_non_zero_operand(when_expr)
665+
&& let Some((_numerator, divisor)) = Self::extract_division_operands(then_expr)
666+
&& divisor.eq(&checked_operand)
667+
{
668+
return Ok(EvalMethod::DivideByZeroProtection);
669+
}
670+
}
671+
650672
Ok(
651673
if body.when_then_expr.len() == 1
652674
&& is_cheap_and_infallible(&(body.when_then_expr[0].1))
@@ -681,6 +703,67 @@ impl CaseExpr {
681703
pub fn else_expr(&self) -> Option<&Arc<dyn PhysicalExpr>> {
682704
self.body.else_expr.as_ref()
683705
}
706+
707+
/// Extract the operand being checked for non-zero from a comparison expression.
708+
/// Return Some(operand) for patterns like `y > 0`, `y != 0`, `0 < y`, `0 != y`.
709+
fn extract_non_zero_operand(
710+
expr: &Arc<dyn PhysicalExpr>,
711+
) -> Option<Arc<dyn PhysicalExpr>> {
712+
let binary = expr.as_any().downcast_ref::<BinaryExpr>()?;
713+
714+
match binary.op() {
715+
// y > 0 or y != 0
716+
Operator::Gt | Operator::NotEq if Self::is_literal_zero(binary.right()) => {
717+
Some(Arc::clone(binary.left()))
718+
}
719+
// 0 < y or 0 != y
720+
Operator::Lt | Operator::NotEq if Self::is_literal_zero(binary.left()) => {
721+
Some(Arc::clone(binary.right()))
722+
}
723+
_ => None,
724+
}
725+
}
726+
727+
/// Extract (numerator, divisor) from a division expression.
728+
fn extract_division_operands(
729+
expr: &Arc<dyn PhysicalExpr>,
730+
) -> Option<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> {
731+
let binary = expr.as_any().downcast_ref::<BinaryExpr>()?;
732+
733+
if binary.op() == &Operator::Divide {
734+
let divisor =
735+
if let Some(cast) = binary.right().as_any().downcast_ref::<CastExpr>() {
736+
Arc::clone(cast.expr())
737+
} else {
738+
Arc::clone(binary.right())
739+
};
740+
Some((Arc::clone(binary.left()), divisor))
741+
} else {
742+
None
743+
}
744+
}
745+
746+
/// Check if an expression is a literal zero value
747+
fn is_literal_zero(expr: &Arc<dyn PhysicalExpr>) -> bool {
748+
if let Some(lit) = expr.as_any().downcast_ref::<Literal>() {
749+
match lit.value() {
750+
ScalarValue::Int8(Some(0))
751+
| ScalarValue::Int16(Some(0))
752+
| ScalarValue::Int32(Some(0))
753+
| ScalarValue::Int64(Some(0))
754+
| ScalarValue::UInt8(Some(0))
755+
| ScalarValue::UInt16(Some(0))
756+
| ScalarValue::UInt32(Some(0))
757+
| ScalarValue::UInt64(Some(0)) => true,
758+
ScalarValue::Float16(Some(v)) if v.to_f32() == 0.0 => true,
759+
ScalarValue::Float32(Some(v)) if *v == 0.0 => true,
760+
ScalarValue::Float64(Some(v)) if *v == 0.0 => true,
761+
_ => false,
762+
}
763+
} else {
764+
false
765+
}
766+
}
684767
}
685768

686769
impl CaseBody {
@@ -1170,6 +1253,19 @@ impl CaseExpr {
11701253

11711254
Ok(result)
11721255
}
1256+
1257+
fn divide_by_zero_protection(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
1258+
let then_expr = &self.body.when_then_expr[0].1;
1259+
let binary = then_expr
1260+
.as_any()
1261+
.downcast_ref::<BinaryExpr>()
1262+
.expect("then expression should be a binary expression");
1263+
1264+
let numerator = binary.left().evaluate(batch)?;
1265+
let divisor = binary.right().evaluate(batch)?;
1266+
1267+
safe_divide(&numerator, &divisor)
1268+
}
11731269
}
11741270

11751271
impl PhysicalExpr for CaseExpr {
@@ -1268,6 +1364,7 @@ impl PhysicalExpr for CaseExpr {
12681364
EvalMethod::WithExprScalarLookupTable(lookup_table) => {
12691365
self.with_lookup_table(batch, lookup_table)
12701366
}
1367+
EvalMethod::DivideByZeroProtection => self.divide_by_zero_protection(batch),
12711368
}
12721369
}
12731370

@@ -1389,6 +1486,78 @@ fn replace_with_null(
13891486
Ok(with_null)
13901487
}
13911488

1489+
fn safe_divide(
1490+
numerator: &ColumnarValue,
1491+
divisor: &ColumnarValue,
1492+
) -> Result<ColumnarValue> {
1493+
if let ColumnarValue::Scalar(div_scalar) = divisor
1494+
&& is_scalar_zero(div_scalar)
1495+
{
1496+
let data_type = numerator.data_type();
1497+
return match numerator {
1498+
ColumnarValue::Array(arr) => {
1499+
Ok(ColumnarValue::Array(new_null_array(&data_type, arr.len())))
1500+
}
1501+
ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar(
1502+
ScalarValue::try_new_null(&data_type)?,
1503+
)),
1504+
};
1505+
}
1506+
1507+
let num_rows = match (numerator, divisor) {
1508+
(ColumnarValue::Array(arr), _) => arr.len(),
1509+
(_, ColumnarValue::Array(arr)) => arr.len(),
1510+
_ => 1,
1511+
};
1512+
1513+
let num_array = numerator.clone().into_array(num_rows)?;
1514+
let div_array = divisor.clone().into_array(num_rows)?;
1515+
1516+
let result = safe_divide_arrays(&num_array, &div_array)?;
1517+
1518+
if matches!(numerator, ColumnarValue::Scalar(_))
1519+
&& matches!(divisor, ColumnarValue::Scalar(_))
1520+
{
1521+
Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
1522+
&result, 0,
1523+
)?))
1524+
} else {
1525+
Ok(ColumnarValue::Array(result))
1526+
}
1527+
}
1528+
1529+
fn safe_divide_arrays(numerator: &ArrayRef, divisor: &ArrayRef) -> Result<ArrayRef> {
1530+
use arrow::compute::kernels::cmp::eq;
1531+
use arrow::compute::kernels::numeric::div;
1532+
1533+
let zero = ScalarValue::new_zero(divisor.data_type())?.to_scalar()?;
1534+
let zero_mask = eq(divisor, &zero)?;
1535+
1536+
let ones = ScalarValue::new_one(divisor.data_type())?.to_scalar()?;
1537+
let safe_divisor = zip(&zero_mask, &ones, divisor)?;
1538+
1539+
let result = div(&numerator, &safe_divisor)?;
1540+
1541+
Ok(nullif(&result, &zero_mask)?)
1542+
}
1543+
1544+
fn is_scalar_zero(scalar: &ScalarValue) -> bool {
1545+
match scalar {
1546+
ScalarValue::Int8(Some(0))
1547+
| ScalarValue::Int16(Some(0))
1548+
| ScalarValue::Int32(Some(0))
1549+
| ScalarValue::Int64(Some(0))
1550+
| ScalarValue::UInt8(Some(0))
1551+
| ScalarValue::UInt16(Some(0))
1552+
| ScalarValue::UInt32(Some(0))
1553+
| ScalarValue::UInt64(Some(0)) => true,
1554+
ScalarValue::Float16(Some(v)) if v.to_f32() == 0.0 => true,
1555+
ScalarValue::Float32(Some(v)) if *v == 0.0 => true,
1556+
ScalarValue::Float64(Some(v)) if *v == 0.0 => true,
1557+
_ => false,
1558+
}
1559+
}
1560+
13921561
/// Create a CASE expression
13931562
pub fn case(
13941563
expr: Option<Arc<dyn PhysicalExpr>>,
@@ -2298,6 +2467,65 @@ mod tests {
22982467
Ok(())
22992468
}
23002469

2470+
#[test]
2471+
fn test_divide_by_zero_protection_specialization() -> Result<()> {
2472+
let batch = case_test_batch1()?;
2473+
let schema = batch.schema();
2474+
2475+
// CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE NULL END
2476+
let when = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
2477+
let then = binary(
2478+
lit(25.0f64),
2479+
Operator::Divide,
2480+
cast(col("a", &schema)?, &schema, Float64)?,
2481+
&schema,
2482+
)?;
2483+
2484+
let expr = CaseExpr::try_new(None, vec![(when, then)], None)?;
2485+
2486+
assert!(
2487+
matches!(expr.eval_method, EvalMethod::DivideByZeroProtection),
2488+
"Expected DivideByZeroProtection, got {:?}",
2489+
expr.eval_method
2490+
);
2491+
2492+
let result = expr
2493+
.evaluate(&batch)?
2494+
.into_array(batch.num_rows())
2495+
.expect("Failed to convert to array");
2496+
let result = as_float64_array(&result)?;
2497+
2498+
let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
2499+
assert_eq!(expected, result);
2500+
2501+
Ok(())
2502+
}
2503+
2504+
#[test]
2505+
fn test_divide_by_zero_protection_specialization_not_applied() -> Result<()> {
2506+
let batch = case_test_batch1()?;
2507+
let schema = batch.schema();
2508+
2509+
// CASE WHEN a > 0 THEN b / c ELSE NULL END
2510+
// Divisor (c) != checked operand (a), should NOT use specialization
2511+
let when = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
2512+
let then = binary(
2513+
col("b", &schema)?,
2514+
Operator::Divide,
2515+
col("c", &schema)?,
2516+
&schema,
2517+
)?;
2518+
2519+
let expr = CaseExpr::try_new(None, vec![(when, then)], None)?;
2520+
2521+
assert!(
2522+
!matches!(expr.eval_method, EvalMethod::DivideByZeroProtection),
2523+
"Should NOT use DivideByZeroProtection when divisor doesn't match"
2524+
);
2525+
2526+
Ok(())
2527+
}
2528+
23012529
fn make_col(name: &str, index: usize) -> Arc<dyn PhysicalExpr> {
23022530
Arc::new(Column::new(name, index))
23032531
}

0 commit comments

Comments
 (0)