1717
1818mod literal_lookup_table;
1919
20- use super :: { Column , Literal } ;
20+ use super :: { CastExpr , Column , Literal } ;
2121use crate :: PhysicalExpr ;
22- use crate :: expressions:: { lit, try_cast} ;
22+ use crate :: expressions:: { BinaryExpr , lit, try_cast} ;
2323use arrow:: array:: * ;
2424use arrow:: compute:: kernels:: zip:: zip;
2525use arrow:: compute:: {
@@ -33,6 +33,7 @@ use datafusion_common::{
3333 internal_datafusion_err, internal_err,
3434} ;
3535use datafusion_expr:: ColumnarValue ;
36+ use datafusion_expr_common:: operator:: Operator ;
3637use indexmap:: { IndexMap , IndexSet } ;
3738use std:: borrow:: Cow ;
3839use std:: hash:: Hash ;
@@ -81,6 +82,14 @@ enum EvalMethod {
8182 ///
8283 /// See [`LiteralLookupTable`] for more details
8384 WithExprScalarLookupTable ( LiteralLookupTable ) ,
85+
86+ /// This is a specialization for divide-by-zero protection pattern:
87+ /// CASE WHEN y > 0 THEN x / y ELSE NULL END
88+ /// CASE WHEN y != 0 THEN x / y ELSE NULL END
89+ ///
90+ /// Instead of evaluating the full CASE expression, the division is performed directly
91+ /// and returns NULL wherever the divisor is zero.
92+ DivideByZeroProtection ,
8493}
8594
8695/// Implementing hash so we can use `derive` on [`EvalMethod`].
@@ -647,6 +656,19 @@ impl CaseExpr {
647656 return Ok ( EvalMethod :: WithExpression ( body. project ( ) ?) ) ;
648657 }
649658
659+ // Check for divide-by-zero protection pattern:
660+ // CASE WHEN y > 0 THEN x / y ELSE NULL END
661+ if body. when_then_expr . len ( ) == 1 && body. else_expr . is_none ( ) {
662+ let ( when_expr, then_expr) = & body. when_then_expr [ 0 ] ;
663+
664+ if let Some ( checked_operand) = Self :: extract_non_zero_operand ( when_expr)
665+ && let Some ( ( _numerator, divisor) ) = Self :: extract_division_operands ( then_expr)
666+ && divisor. eq ( & checked_operand)
667+ {
668+ return Ok ( EvalMethod :: DivideByZeroProtection ) ;
669+ }
670+ }
671+
650672 Ok (
651673 if body. when_then_expr . len ( ) == 1
652674 && is_cheap_and_infallible ( & ( body. when_then_expr [ 0 ] . 1 ) )
@@ -681,6 +703,67 @@ impl CaseExpr {
681703 pub fn else_expr ( & self ) -> Option < & Arc < dyn PhysicalExpr > > {
682704 self . body . else_expr . as_ref ( )
683705 }
706+
707+ /// Extract the operand being checked for non-zero from a comparison expression.
708+ /// Return Some(operand) for patterns like `y > 0`, `y != 0`, `0 < y`, `0 != y`.
709+ fn extract_non_zero_operand (
710+ expr : & Arc < dyn PhysicalExpr > ,
711+ ) -> Option < Arc < dyn PhysicalExpr > > {
712+ let binary = expr. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) ?;
713+
714+ match binary. op ( ) {
715+ // y > 0 or y != 0
716+ Operator :: Gt | Operator :: NotEq if Self :: is_literal_zero ( binary. right ( ) ) => {
717+ Some ( Arc :: clone ( binary. left ( ) ) )
718+ }
719+ // 0 < y or 0 != y
720+ Operator :: Lt | Operator :: NotEq if Self :: is_literal_zero ( binary. left ( ) ) => {
721+ Some ( Arc :: clone ( binary. right ( ) ) )
722+ }
723+ _ => None ,
724+ }
725+ }
726+
727+ /// Extract (numerator, divisor) from a division expression.
728+ fn extract_division_operands (
729+ expr : & Arc < dyn PhysicalExpr > ,
730+ ) -> Option < ( Arc < dyn PhysicalExpr > , Arc < dyn PhysicalExpr > ) > {
731+ let binary = expr. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) ?;
732+
733+ if binary. op ( ) == & Operator :: Divide {
734+ let divisor =
735+ if let Some ( cast) = binary. right ( ) . as_any ( ) . downcast_ref :: < CastExpr > ( ) {
736+ Arc :: clone ( cast. expr ( ) )
737+ } else {
738+ Arc :: clone ( binary. right ( ) )
739+ } ;
740+ Some ( ( Arc :: clone ( binary. left ( ) ) , divisor) )
741+ } else {
742+ None
743+ }
744+ }
745+
746+ /// Check if an expression is a literal zero value
747+ fn is_literal_zero ( expr : & Arc < dyn PhysicalExpr > ) -> bool {
748+ if let Some ( lit) = expr. as_any ( ) . downcast_ref :: < Literal > ( ) {
749+ match lit. value ( ) {
750+ ScalarValue :: Int8 ( Some ( 0 ) )
751+ | ScalarValue :: Int16 ( Some ( 0 ) )
752+ | ScalarValue :: Int32 ( Some ( 0 ) )
753+ | ScalarValue :: Int64 ( Some ( 0 ) )
754+ | ScalarValue :: UInt8 ( Some ( 0 ) )
755+ | ScalarValue :: UInt16 ( Some ( 0 ) )
756+ | ScalarValue :: UInt32 ( Some ( 0 ) )
757+ | ScalarValue :: UInt64 ( Some ( 0 ) ) => true ,
758+ ScalarValue :: Float16 ( Some ( v) ) if v. to_f32 ( ) == 0.0 => true ,
759+ ScalarValue :: Float32 ( Some ( v) ) if * v == 0.0 => true ,
760+ ScalarValue :: Float64 ( Some ( v) ) if * v == 0.0 => true ,
761+ _ => false ,
762+ }
763+ } else {
764+ false
765+ }
766+ }
684767}
685768
686769impl CaseBody {
@@ -1170,6 +1253,19 @@ impl CaseExpr {
11701253
11711254 Ok ( result)
11721255 }
1256+
1257+ fn divide_by_zero_protection ( & self , batch : & RecordBatch ) -> Result < ColumnarValue > {
1258+ let then_expr = & self . body . when_then_expr [ 0 ] . 1 ;
1259+ let binary = then_expr
1260+ . as_any ( )
1261+ . downcast_ref :: < BinaryExpr > ( )
1262+ . expect ( "then expression should be a binary expression" ) ;
1263+
1264+ let numerator = binary. left ( ) . evaluate ( batch) ?;
1265+ let divisor = binary. right ( ) . evaluate ( batch) ?;
1266+
1267+ safe_divide ( & numerator, & divisor)
1268+ }
11731269}
11741270
11751271impl PhysicalExpr for CaseExpr {
@@ -1268,6 +1364,7 @@ impl PhysicalExpr for CaseExpr {
12681364 EvalMethod :: WithExprScalarLookupTable ( lookup_table) => {
12691365 self . with_lookup_table ( batch, lookup_table)
12701366 }
1367+ EvalMethod :: DivideByZeroProtection => self . divide_by_zero_protection ( batch) ,
12711368 }
12721369 }
12731370
@@ -1389,6 +1486,78 @@ fn replace_with_null(
13891486 Ok ( with_null)
13901487}
13911488
1489+ fn safe_divide (
1490+ numerator : & ColumnarValue ,
1491+ divisor : & ColumnarValue ,
1492+ ) -> Result < ColumnarValue > {
1493+ if let ColumnarValue :: Scalar ( div_scalar) = divisor
1494+ && is_scalar_zero ( div_scalar)
1495+ {
1496+ let data_type = numerator. data_type ( ) ;
1497+ return match numerator {
1498+ ColumnarValue :: Array ( arr) => {
1499+ Ok ( ColumnarValue :: Array ( new_null_array ( & data_type, arr. len ( ) ) ) )
1500+ }
1501+ ColumnarValue :: Scalar ( _) => Ok ( ColumnarValue :: Scalar (
1502+ ScalarValue :: try_new_null ( & data_type) ?,
1503+ ) ) ,
1504+ } ;
1505+ }
1506+
1507+ let num_rows = match ( numerator, divisor) {
1508+ ( ColumnarValue :: Array ( arr) , _) => arr. len ( ) ,
1509+ ( _, ColumnarValue :: Array ( arr) ) => arr. len ( ) ,
1510+ _ => 1 ,
1511+ } ;
1512+
1513+ let num_array = numerator. clone ( ) . into_array ( num_rows) ?;
1514+ let div_array = divisor. clone ( ) . into_array ( num_rows) ?;
1515+
1516+ let result = safe_divide_arrays ( & num_array, & div_array) ?;
1517+
1518+ if matches ! ( numerator, ColumnarValue :: Scalar ( _) )
1519+ && matches ! ( divisor, ColumnarValue :: Scalar ( _) )
1520+ {
1521+ Ok ( ColumnarValue :: Scalar ( ScalarValue :: try_from_array (
1522+ & result, 0 ,
1523+ ) ?) )
1524+ } else {
1525+ Ok ( ColumnarValue :: Array ( result) )
1526+ }
1527+ }
1528+
1529+ fn safe_divide_arrays ( numerator : & ArrayRef , divisor : & ArrayRef ) -> Result < ArrayRef > {
1530+ use arrow:: compute:: kernels:: cmp:: eq;
1531+ use arrow:: compute:: kernels:: numeric:: div;
1532+
1533+ let zero = ScalarValue :: new_zero ( divisor. data_type ( ) ) ?. to_scalar ( ) ?;
1534+ let zero_mask = eq ( divisor, & zero) ?;
1535+
1536+ let ones = ScalarValue :: new_one ( divisor. data_type ( ) ) ?. to_scalar ( ) ?;
1537+ let safe_divisor = zip ( & zero_mask, & ones, divisor) ?;
1538+
1539+ let result = div ( & numerator, & safe_divisor) ?;
1540+
1541+ Ok ( nullif ( & result, & zero_mask) ?)
1542+ }
1543+
1544+ fn is_scalar_zero ( scalar : & ScalarValue ) -> bool {
1545+ match scalar {
1546+ ScalarValue :: Int8 ( Some ( 0 ) )
1547+ | ScalarValue :: Int16 ( Some ( 0 ) )
1548+ | ScalarValue :: Int32 ( Some ( 0 ) )
1549+ | ScalarValue :: Int64 ( Some ( 0 ) )
1550+ | ScalarValue :: UInt8 ( Some ( 0 ) )
1551+ | ScalarValue :: UInt16 ( Some ( 0 ) )
1552+ | ScalarValue :: UInt32 ( Some ( 0 ) )
1553+ | ScalarValue :: UInt64 ( Some ( 0 ) ) => true ,
1554+ ScalarValue :: Float16 ( Some ( v) ) if v. to_f32 ( ) == 0.0 => true ,
1555+ ScalarValue :: Float32 ( Some ( v) ) if * v == 0.0 => true ,
1556+ ScalarValue :: Float64 ( Some ( v) ) if * v == 0.0 => true ,
1557+ _ => false ,
1558+ }
1559+ }
1560+
13921561/// Create a CASE expression
13931562pub fn case (
13941563 expr : Option < Arc < dyn PhysicalExpr > > ,
@@ -2298,6 +2467,65 @@ mod tests {
22982467 Ok ( ( ) )
22992468 }
23002469
/// A guard on `a` combined with a division by `CAST(a AS Float64)` must select
/// the `DivideByZeroProtection` evaluation method and evaluate to the expected
/// values for `case_test_batch1`.
#[test]
fn test_divide_by_zero_protection_specialization() -> Result<()> {
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    // CASE WHEN a > 0 THEN 25.0 / cast(a, float64) ELSE NULL END
    let guard = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
    let quotient = binary(
        lit(25.0f64),
        Operator::Divide,
        cast(col("a", &schema)?, &schema, Float64)?,
        &schema,
    )?;

    let case_expr = CaseExpr::try_new(None, vec![(guard, quotient)], None)?;

    assert!(
        matches!(case_expr.eval_method, EvalMethod::DivideByZeroProtection),
        "Expected DivideByZeroProtection, got {:?}",
        case_expr.eval_method
    );

    let evaluated = case_expr.evaluate(&batch)?;
    let array = evaluated
        .into_array(batch.num_rows())
        .expect("Failed to convert to array");
    let actual = as_float64_array(&array)?;

    let expected = Float64Array::from(vec![Some(25.0), None, None, Some(5.0)]);
    assert_eq!(&expected, actual);

    Ok(())
}
2503+
/// When the guarded column is not the divisor, the specialization must not be
/// selected.
#[test]
fn test_divide_by_zero_protection_specialization_not_applied() -> Result<()> {
    let batch = case_test_batch1()?;
    let schema = batch.schema();

    // CASE WHEN a > 0 THEN b / c ELSE NULL END
    // The guard checks `a` but the divisor is `c`, so the pattern does not match.
    let guard = binary(col("a", &schema)?, Operator::Gt, lit(0i32), &schema)?;
    let quotient = binary(
        col("b", &schema)?,
        Operator::Divide,
        col("c", &schema)?,
        &schema,
    )?;

    let case_expr = CaseExpr::try_new(None, vec![(guard, quotient)], None)?;
    let specialized =
        matches!(case_expr.eval_method, EvalMethod::DivideByZeroProtection);

    assert!(
        !specialized,
        "Should NOT use DivideByZeroProtection when divisor doesn't match"
    );

    Ok(())
}
2528+
23012529 fn make_col ( name : & str , index : usize ) -> Arc < dyn PhysicalExpr > {
23022530 Arc :: new ( Column :: new ( name, index) )
23032531 }
0 commit comments