diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 758317d3d2798..2e0203bdd7988 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -17,9 +17,9 @@ mod literal_lookup_table; -use super::{Column, Literal}; +use super::{CastExpr, Column, Literal}; use crate::PhysicalExpr; -use crate::expressions::{lit, try_cast}; +use crate::expressions::{BinaryExpr, lit, try_cast}; use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{ @@ -33,6 +33,7 @@ use datafusion_common::{ internal_datafusion_err, internal_err, }; use datafusion_expr::ColumnarValue; +use datafusion_expr_common::operator::Operator; use indexmap::{IndexMap, IndexSet}; use std::borrow::Cow; use std::hash::Hash; @@ -81,6 +82,14 @@ enum EvalMethod { /// /// See [`LiteralLookupTable`] for more details WithExprScalarLookupTable(LiteralLookupTable), + + /// This is a specialization for divide-by-zero protection pattern: + /// CASE WHEN y > 0 THEN x / y ELSE NULL END + /// CASE WHEN y != 0 THEN x / y ELSE NULL END + /// + /// Instead of evaluating the full CASE expression, it is preferred to directly perform division + /// that return NULL when the divisor is zero. + DivideByZeroProtection, } /// Implementing hash so we can use `derive` on [`EvalMethod`]. @@ -647,6 +656,20 @@ impl CaseExpr { return Ok(EvalMethod::WithExpression(body.project()?)); } + // Check for divide-by-zero protection pattern: + // CASE WHEN y > 0 THEN x / y ELSE NULL END + if body.when_then_expr.len() == 1 && body.else_expr.is_none() { + let (when_expr, then_expr) = &body.when_then_expr[0]; + + if let Some(checked_operand) = Self::extract_non_zero_operand(when_expr) + && let Some((_numerator, divisor)) = + Self::extract_division_operands(then_expr) + && divisor.eq(&checked_operand) + { + return Ok(EvalMethod::DivideByZeroProtection); + } + } + Ok( if body.when_then_expr.len() == 1 && is_cheap_and_infallible(&(body.when_then_expr[0].1)) @@ -681,6 +704,67 @@ impl CaseExpr { pub fn else_expr(&self) -> Option<&Arc> { self.body.else_expr.as_ref() } + + /// Extract the operand being checked for non-zero from a comparison expression. + /// Return Some(operand) for patterns like `y > 0`, `y != 0`, `0 < y`, `0 != y`. + fn extract_non_zero_operand( + expr: &Arc, + ) -> Option> { + let binary = expr.as_any().downcast_ref::()?; + + match binary.op() { + // y > 0 or y != 0 + Operator::Gt | Operator::NotEq if Self::is_literal_zero(binary.right()) => { + Some(Arc::clone(binary.left())) + } + // 0 < y or 0 != y + Operator::Lt | Operator::NotEq if Self::is_literal_zero(binary.left()) => { + Some(Arc::clone(binary.right())) + } + _ => None, + } + } + + /// Extract (numerator, divisor) from a division expression. + fn extract_division_operands( + expr: &Arc, + ) -> Option<(Arc, Arc)> { + let binary = expr.as_any().downcast_ref::()?; + + if binary.op() == &Operator::Divide { + let divisor = + if let Some(cast) = binary.right().as_any().downcast_ref::() { + Arc::clone(cast.expr()) + } else { + Arc::clone(binary.right()) + }; + Some((Arc::clone(binary.left()), divisor)) + } else { + None + } + } + + /// Check if an expression is a literal zero value + fn is_literal_zero(expr: &Arc) -> bool { + if let Some(lit) = expr.as_any().downcast_ref::() { + match lit.value() { + ScalarValue::Int8(Some(0)) + | ScalarValue::Int16(Some(0)) + | ScalarValue::Int32(Some(0)) + | ScalarValue::Int64(Some(0)) + | ScalarValue::UInt8(Some(0)) + | ScalarValue::UInt16(Some(0)) + | ScalarValue::UInt32(Some(0)) + | ScalarValue::UInt64(Some(0)) => true, + ScalarValue::Float16(Some(v)) if v.to_f32() == 0.0 => true, + ScalarValue::Float32(Some(v)) if *v == 0.0 => true, + ScalarValue::Float64(Some(v)) if *v == 0.0 => true, + _ => false, + } + } else { + false + } + } } impl CaseBody { @@ -1170,6 +1254,19 @@ impl CaseExpr { Ok(result) } + + fn divide_by_zero_protection(&self, batch: &RecordBatch) -> Result { + let then_expr = &self.body.when_then_expr[0].1; + let binary = then_expr + .as_any() + .downcast_ref::() + .expect("then expression should be a binary expression"); + + let numerator = binary.left().evaluate(batch)?; + let divisor = binary.right().evaluate(batch)?; + + safe_divide(&numerator, &divisor) + } } impl PhysicalExpr for CaseExpr { @@ -1268,6 +1365,7 @@ impl PhysicalExpr for CaseExpr { EvalMethod::WithExprScalarLookupTable(lookup_table) => { self.with_lookup_table(batch, lookup_table) } + EvalMethod::DivideByZeroProtection => self.divide_by_zero_protection(batch), } } @@ -1389,6 +1487,78 @@ fn replace_with_null( Ok(with_null) } +fn safe_divide( + numerator: &ColumnarValue, + divisor: &ColumnarValue, +) -> Result { + if let ColumnarValue::Scalar(div_scalar) = divisor + && is_scalar_zero(div_scalar) + { + let data_type = numerator.data_type(); + return match numerator { + ColumnarValue::Array(arr) => { + Ok(ColumnarValue::Array(new_null_array(&data_type, arr.len()))) + } + ColumnarValue::Scalar(_) => Ok(ColumnarValue::Scalar( + ScalarValue::try_new_null(&data_type)?, + )), + }; + } + + let num_rows = match (numerator, divisor) { + (ColumnarValue::Array(arr), _) => arr.len(), + (_, ColumnarValue::Array(arr)) => arr.len(), + _ => 1, + }; + + let num_array = numerator.clone().into_array(num_rows)?; + let div_array = divisor.clone().into_array(num_rows)?; + + let result = safe_divide_arrays(&num_array, &div_array)?; + + if matches!(numerator, ColumnarValue::Scalar(_)) + && matches!(divisor, ColumnarValue::Scalar(_)) + { + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &result, 0, + )?)) + } else { + Ok(ColumnarValue::Array(result)) + } +} + +fn safe_divide_arrays(numerator: &ArrayRef, divisor: &ArrayRef) -> Result { + use arrow::compute::kernels::cmp::eq; + use arrow::compute::kernels::numeric::div; + + let zero = ScalarValue::new_zero(divisor.data_type())?.to_scalar()?; + let zero_mask = eq(divisor, &zero)?; + + let ones = ScalarValue::new_one(divisor.data_type())?.to_scalar()?; + let safe_divisor = zip(&zero_mask, &ones, divisor)?; + + let result = div(&numerator, &safe_divisor)?; + + Ok(nullif(&result, &zero_mask)?) +} + +fn is_scalar_zero(scalar: &ScalarValue) -> bool { + match scalar { + ScalarValue::Int8(Some(0)) + | ScalarValue::Int16(Some(0)) + | ScalarValue::Int32(Some(0)) + | ScalarValue::Int64(Some(0)) + | ScalarValue::UInt8(Some(0)) + | ScalarValue::UInt16(Some(0)) + | ScalarValue::UInt32(Some(0)) + | ScalarValue::UInt64(Some(0)) => true, + ScalarValue::Float16(Some(v)) if v.to_f32() == 0.0 => true, + ScalarValue::Float32(Some(v)) if *v == 0.0 => true, + ScalarValue::Float64(Some(v)) if *v == 0.0 => true, + _ => false, + } +} + /// Create a CASE expression pub fn case( expr: Option>, @@ -2298,6 +2468,65 @@ mod tests { Ok(()) } + #[test] + fn test_divide_by_zero_protection_specialization() -> Result<()> { + let batch = case_test_batch1()?; + let schema = batch.schema(); + + // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE NULL END + let when = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?; + let then = binary( + lit(25.0f64), + Operator::Divide, + cast(col("a", &schema)?, &schema, Float64)?, + &schema, + )?; + + let expr = CaseExpr::try_new(None, vec![(when, then)], None)?; + + assert!( + matches!(expr.eval_method, EvalMethod::DivideByZeroProtection), + "Expected DivideByZeroProtection, got {:?}", + expr.eval_method + ); + + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_float64_array(&result)?; + + let expected = &Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]); + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn test_divide_by_zero_protection_specialization_not_applied() -> Result<()> { + let batch = case_test_batch1()?; + let schema = batch.schema(); + + // CASE WHEN a > 0 THEN b / c ELSE NULL END + // Divisor (c) != checked operand (a), should NOT use specialization + let when = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?; + let then = binary( + col("b", &schema)?, + Operator::Divide, + col("c", &schema)?, + &schema, + )?; + + let expr = CaseExpr::try_new(None, vec![(when, then)], None)?; + + assert!( + !matches!(expr.eval_method, EvalMethod::DivideByZeroProtection), + "Should NOT use DivideByZeroProtection when divisor doesn't match" + ); + + Ok(()) + } + fn make_col(name: &str, index: usize) -> Arc { Arc::new(Column::new(name, index)) }