diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs
index b1aa850284aee..95445cadf9125 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -2686,47 +2686,41 @@ impl DefaultPhysicalPlanner {
         schema: &Schema,
     ) -> Result<PlanAsyncExpr> {
         let mut async_map = AsyncMapper::new(num_input_columns);
-        match &physical_expr {
+        let new_physical_expr = match physical_expr {
             PlannedExprResult::ExprWithName(exprs) => {
-                exprs
+                let new_exprs = exprs
                     .iter()
-                    .try_for_each(|(expr, _)| async_map.find_references(expr, schema))?;
+                    .map(|(expr, name)| {
+                        // find_and_map will:
+                        // 1. Identify nested async UDFs bottom-up
+                        // 2. Rewrite them to Columns referencing the async_map
+                        // 3. Return the fully rewritten expression
+                        let new_expr = async_map.find_and_map(expr, schema)?;
+                        Ok((new_expr, name.clone()))
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                PlannedExprResult::ExprWithName(new_exprs)
             }
             PlannedExprResult::Expr(exprs) => {
-                exprs
+                let new_exprs = exprs
                     .iter()
-                    .try_for_each(|expr| async_map.find_references(expr, schema))?;
+                    .map(|expr| {
+                        let new_expr = async_map.find_and_map(expr, schema)?;
+                        Ok(new_expr)
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                PlannedExprResult::Expr(new_exprs)
             }
-        }
+        };

         if async_map.is_empty() {
-            return Ok(PlanAsyncExpr::Sync(physical_expr));
+            // If no async exprs were found, the result is the same as the input
+            // (find_and_map returns clones, but structural equality holds)
+            return Ok(PlanAsyncExpr::Sync(new_physical_expr));
         }

-        let new_exprs = match physical_expr {
-            PlannedExprResult::ExprWithName(exprs) => PlannedExprResult::ExprWithName(
-                exprs
-                    .iter()
-                    .map(|(expr, column_name)| {
-                        let new_expr = Arc::clone(expr)
-                            .transform_up(|e| Ok(async_map.map_expr(e)))?;
-                        Ok((new_expr.data, column_name.to_string()))
-                    })
-                    .collect::<Result<_>>()?,
-            ),
-            PlannedExprResult::Expr(exprs) => PlannedExprResult::Expr(
-                exprs
-                    .iter()
-                    .map(|expr| {
-                        let new_expr = Arc::clone(expr)
-                            .transform_up(|e| Ok(async_map.map_expr(e)))?;
-                        Ok(new_expr.data)
-                    })
-                    .collect::<Result<_>>()?,
-            ),
-        };
-        // rewrite the projection's expressions in terms of the columns with the result of async evaluation
-        Ok(PlanAsyncExpr::Async(async_map, new_exprs))
+        // Pass the rewritten expressions
+        Ok(PlanAsyncExpr::Async(async_map, new_physical_expr))
     }
 }
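For intuition, the rewrite that `find_and_map` performs can be modeled in a few lines. The sketch below uses toy `Expr`/`Mapper` types invented for illustration, not the real `PhysicalExpr`/`AsyncMapper` API: because the traversal is post-order, an outer async call sees a `Column` where its inner call used to be, which is what makes nesting work.

```rust
// Toy model of the bottom-up rewrite in find_and_map. Hypothetical types,
// not the actual DataFusion API.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Column(usize),        // index into the (virtual) output schema
    AsyncCall(Box<Expr>), // stand-in for an async scalar UDF call
}

struct Mapper {
    num_input_columns: usize,
    async_exprs: Vec<Expr>, // one entry per extracted async call
}

impl Mapper {
    // Post-order traversal: children are rewritten before their parent, so an
    // outer AsyncCall ends up referencing the output slot of its inner call.
    fn find_and_map(&mut self, expr: &Expr) -> Expr {
        match expr {
            Expr::Column(i) => Expr::Column(*i),
            Expr::AsyncCall(child) => {
                let rewritten_child = self.find_and_map(child);
                self.async_exprs
                    .push(Expr::AsyncCall(Box::new(rewritten_child)));
                // Output slots come after the input columns, in extraction order.
                Expr::Column(self.num_input_columns + self.async_exprs.len() - 1)
            }
        }
    }
}

fn main() {
    // async_f(async_f(col0)) over a single input column
    let nested = Expr::AsyncCall(Box::new(Expr::AsyncCall(Box::new(Expr::Column(0)))));
    let mut mapper = Mapper { num_input_columns: 1, async_exprs: vec![] };
    let rewritten = mapper.find_and_map(&nested);

    // Inner call reads the input; outer call reads the inner call's slot.
    assert_eq!(mapper.async_exprs[0], Expr::AsyncCall(Box::new(Expr::Column(0))));
    assert_eq!(mapper.async_exprs[1], Expr::AsyncCall(Box::new(Expr::Column(1))));
    // The projection itself now only reads the outer result.
    assert_eq!(rewritten, Expr::Column(2));
}
```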
diff --git a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs
index 31af4445ace08..6595b38e0ef5f 100644
--- a/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_async_scalar_functions.rs
@@ -20,9 +20,10 @@ use std::sync::Arc;
 use arrow::array::{Int32Array, RecordBatch, StringArray};
 use arrow::datatypes::{DataType, Field, Schema};
 use async_trait::async_trait;
-use datafusion::prelude::*;
+use datafusion::dataframe::DataFrame;
+use datafusion::execution::context::SessionContext;
+use datafusion_common::Result;
 use datafusion_common::test_util::format_batches;
-use datafusion_common::{Result, assert_batches_eq};
 use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
 use datafusion_expr::{
     ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
@@ -66,24 +67,24 @@ fn register_table_and_udf() -> Result<SessionContext> {
 async fn test_async_udf_with_non_modular_batch_size() -> Result<()> {
     let ctx = register_table_and_udf()?;

-    let df = ctx
+    let df: DataFrame = ctx
         .sql("SELECT id, test_async_udf(prompt) as result FROM test_table")
         .await?;
-    let result = df.collect().await?;
-
-    assert_batches_eq!(
-        &[
-            "+----+---------+",
-            "| id | result  |",
-            "+----+---------+",
-            "| 0  | prompt0 |",
-            "| 1  | prompt1 |",
-            "| 2  | prompt2 |",
-            "+----+---------+"
-        ],
-        &result
-    );
+    let result: Vec<RecordBatch> = df.collect().await?;
+
+    let result_str = format_batches(&result)?.to_string();
+    let expected = [
+        "+----+---------+",
+        "| id | result  |",
+        "+----+---------+",
+        "| 0  | prompt0 |",
+        "| 1  | prompt1 |",
+        "| 2  | prompt2 |",
+        "+----+---------+",
+    ]
+    .join("\n");
+    assert_eq!(result_str.trim(), expected.trim());

     Ok(())
 }
@@ -93,13 +94,13 @@ async fn test_async_udf_with_non_modular_batch_size() -> Result<()> {
 async fn test_async_udf_metrics() -> Result<()> {
     let ctx = register_table_and_udf()?;

-    let df = ctx
+    let df: DataFrame = ctx
         .sql(
             "EXPLAIN ANALYZE SELECT id, test_async_udf(prompt) as result FROM test_table",
         )
         .await?;
-    let result = df.collect().await?;
+    let result: Vec<RecordBatch> = df.collect().await?;

     let explain_analyze_str = format_batches(&result)?.to_string();
     let async_func_exec_without_metrics =
@@ -113,6 +114,43 @@ async fn test_async_udf_metrics() -> Result<()> {
     Ok(())
 }

+#[tokio::test]
+async fn test_nested_async_udf() -> Result<()> {
+    let ctx = register_table_and_udf()?;
+
+    let df: DataFrame = ctx
+        .sql(
+            "SELECT id, test_async_udf(test_async_udf(prompt)) as result FROM test_table",
+        )
+        .await?;
+
+    let result: Result<Vec<RecordBatch>> = df.collect().await;
+
+    // This is expected to succeed now
+    match &result {
+        Ok(batches) => {
+            // Check results
+            let result_str = format_batches(batches)?.to_string();
+            let expected = [
+                "+----+---------+",
+                "| id | result  |",
+                "+----+---------+",
+                "| 0  | prompt0 |",
+                "| 1  | prompt1 |",
+                "| 2  | prompt2 |",
+                "+----+---------+",
+            ]
+            .join("\n");
+            assert_eq!(result_str.trim(), expected.trim());
+        }
+        Err(e) => {
+            panic!("Nested async UDF failed: {e}");
+        }
+    }
+
+    Ok(())
+}
+
 #[derive(Debug, PartialEq, Eq, Hash, Clone)]
 struct TestAsyncUDFImpl {
     batch_size: usize,
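The tests move from `assert_batches_eq!` to comparing `format_batches` output against a joined string. A minimal standalone sketch of that pattern (the formatted output is a hard-coded literal here; in the tests it comes from `format_batches`), showing why both sides are trimmed:

```rust
fn main() {
    // Stand-in for format_batches(&result)?.to_string(), which typically
    // ends with a trailing newline.
    let result_str = "+----+---------+\n\
                      | id | result  |\n\
                      +----+---------+\n\
                      | 0  | prompt0 |\n\
                      +----+---------+\n";

    let expected = [
        "+----+---------+",
        "| id | result  |",
        "+----+---------+",
        "| 0  | prompt0 |",
        "+----+---------+",
    ]
    .join("\n");

    // join("\n") produces no trailing newline, so trim both sides before comparing
    assert_eq!(result_str.trim(), expected.trim());
}
```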
diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs
index a61fd95949d1a..b88fd02da5339 100644
--- a/datafusion/physical-plan/src/async_func.rs
+++ b/datafusion/physical-plan/src/async_func.rs
@@ -22,14 +22,16 @@ use crate::{
     DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties,
 };
 use arrow::array::RecordBatch;
-use arrow_schema::{Fields, Schema, SchemaRef};
+use arrow_schema::{Field, Fields, Schema, SchemaRef};
 use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
+use datafusion_physical_expr::expressions::Column;
+
 use datafusion_common::{Result, assert_eq_or_internal_err};
 use datafusion_execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
 use datafusion_physical_expr::ScalarFunctionExpr;
 use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr;
 use datafusion_physical_expr::equivalence::ProjectionMapping;
-use datafusion_physical_expr::expressions::Column;
+
 use datafusion_physical_expr_common::metrics::{BaselineMetrics, RecordOutput};
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use futures::Stream;
@@ -59,10 +61,14 @@ impl AsyncFuncExec {
         async_exprs: Vec<Arc<AsyncFuncExpr>>,
         input: Arc<dyn ExecutionPlan>,
     ) -> Result<Self> {
-        let async_fields = async_exprs
-            .iter()
-            .map(|async_expr| async_expr.field(input.schema().as_ref()))
-            .collect::<Result<Vec<_>>>()?;
+        let mut current_fields = input.schema().fields().to_vec();
+        let mut async_fields = Vec::with_capacity(async_exprs.len());
+        for async_expr in &async_exprs {
+            let current_schema = Schema::new(current_fields.clone());
+            let field = async_expr.field(&current_schema)?;
+            current_fields.push(Arc::new(field.clone()));
+            async_fields.push(field);
+        }

         // compute the output schema: input schema then async expressions
         let fields: Fields = input
@@ -74,10 +80,25 @@ impl AsyncFuncExec {
             .collect();
         let schema = Arc::new(Schema::new(fields));
-        let tuples = async_exprs
-            .iter()
-            .map(|expr| (Arc::clone(&expr.func), expr.name().to_string()))
-            .collect::<Vec<_>>();
+
+        // Only include expressions that map to input columns in the ProjectionMapping.
+        // Expressions referencing newly created async columns cannot be verified
+        // against the input schema.
+        let input_len = input.schema().fields().len();
+        let mut tuples = Vec::new();
+        for expr in &async_exprs {
+            let mut refers_to_new_cols = false;
+            expr.func.apply(&mut |e: &Arc<dyn PhysicalExpr>| {
+                if let Some(col) = e.as_any().downcast_ref::<Column>() {
+                    refers_to_new_cols |= col.index() >= input_len;
+                }
+                Ok(TreeNodeRecursion::Continue)
+            })?;
+
+            if !refers_to_new_cols {
+                tuples.push((Arc::clone(&expr.func), expr.name().to_string()));
+            }
+        }
+
         let async_expr_mapping = ProjectionMapping::try_new(tuples, &input.schema())?;
         let cache =
             AsyncFuncExec::compute_properties(&input, schema, &async_expr_mapping)?;
@@ -216,14 +237,38 @@ impl ExecutionPlan for AsyncFuncExec {
                 async move {
                     let batch = batch?;
                     // append the result of evaluating the async expressions to the output
-                    let mut output_arrays = batch.columns().to_vec();
-                    for async_expr in async_exprs_captured.iter() {
+                    // We must evaluate them in order, adding the results to the batch
+                    // so that subsequent async expressions can access the results of previous ones
+                    let mut output_arrays = Vec::with_capacity(async_exprs_captured.len());
+                    let input_columns = batch.columns().len();
+
+                    for (i, async_expr) in async_exprs_captured.iter().enumerate() {
+                        // Create a batch with the input columns and the async columns evaluated so far.
+                        // We need to construct a schema for this intermediate batch
+                        let current_schema_fields: Vec<_> = schema_captured
+                            .fields()
+                            .iter()
+                            .take(input_columns + i)
+                            .cloned()
+                            .collect();
+                        let current_schema = Arc::new(Schema::new(current_schema_fields));
+
+                        let mut current_columns = batch.columns().to_vec();
+                        current_columns.extend_from_slice(&output_arrays);
+
+                        let current_batch =
+                            RecordBatch::try_new(current_schema, current_columns)?;
+
                         let output = async_expr
-                            .invoke_with_args(&batch, Arc::clone(&config_options))
+                            .invoke_with_args(&current_batch, Arc::clone(&config_options))
                             .await?;
                         output_arrays.push(output.to_array(batch.num_rows())?);
                     }
-                    let batch = RecordBatch::try_new(schema_captured, output_arrays)?;
+
+                    let mut final_columns = batch.columns().to_vec();
+                    final_columns.extend(output_arrays);
+
+                    let batch = RecordBatch::try_new(schema_captured, final_columns)?;
                     Ok(batch.record_output(&baseline_metrics_captured))
                 }
@@ -296,6 +341,8 @@ pub struct AsyncMapper {
     num_input_columns: usize,
     /// the expressions to map
     pub async_exprs: Vec<Arc<AsyncFuncExpr>>,
+    /// the output fields of the async expressions
+    output_fields: Vec<Field>,
 }

 impl AsyncMapper {
@@ -303,6 +350,7 @@ impl AsyncMapper {
         Self {
             num_input_columns,
             async_exprs: Vec::new(),
+            output_fields: Vec::new(),
         }
     }
@@ -315,58 +363,49 @@ impl AsyncMapper {
     }

     /// Finds any references to async functions in the expression and adds them to the map
-    pub fn find_references(
+    /// AND rewrites the expression to use the mapped columns.
+    pub fn find_and_map(
         &mut self,
         physical_expr: &Arc<dyn PhysicalExpr>,
         schema: &Schema,
-    ) -> Result<()> {
-        // recursively look for references to async functions
-        physical_expr.apply(|expr| {
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        let transformed = Arc::clone(physical_expr).transform_up(|expr| {
             if let Some(scalar_func_expr) = expr.as_any().downcast_ref::<ScalarFunctionExpr>()
                 && scalar_func_expr.fun().as_async().is_some()
             {
                 let next_name = self.next_column_name();
-                self.async_exprs.push(Arc::new(AsyncFuncExpr::try_new(
-                    next_name,
-                    Arc::clone(expr),
-                    schema,
-                )?));
+
+                // Construct extended schema including previously mapped async fields
+                let mut current_fields = schema.fields().to_vec();
+                current_fields.extend(
+                    self.output_fields
+                        .iter()
+                        .map(|f: &Field| Arc::new(f.clone())),
+                );
+                let current_schema = Schema::new(current_fields);
+
+                // We use the expression with its children already transformed
+                let async_expr = Arc::new(AsyncFuncExpr::try_new(
+                    next_name.clone(),
+                    Arc::clone(&expr),
+                    &current_schema,
+                )?);
+
+                // Store the output field for subsequent expressions
+                self.output_fields.push(async_expr.field(&current_schema)?);
+                self.async_exprs.push(async_expr);
+
+                // Replace with Column
+                let output_idx = self.num_input_columns + self.async_exprs.len() - 1;
+                Ok(Transformed::yes(Arc::new(Column::new(
+                    &next_name, output_idx,
+                ))))
+            } else {
+                Ok(Transformed::no(expr))
             }
-            Ok(TreeNodeRecursion::Continue)
         })?;
-        Ok(())
-    }
-
-    /// If the expression matches any of the async functions, return the new column
-    pub fn map_expr(
-        &self,
-        expr: Arc<dyn PhysicalExpr>,
-    ) -> Transformed<Arc<dyn PhysicalExpr>> {
-        // find the first matching async function if any
-        let Some(idx) =
-            self.async_exprs
-                .iter()
-                .enumerate()
-                .find_map(|(idx, async_expr)| {
-                    if async_expr.func == Arc::clone(&expr) {
-                        Some(idx)
-                    } else {
-                        None
-                    }
-                })
-        else {
-            return Transformed::no(expr);
-        };
-        // rewrite in terms of the output column
-        Transformed::yes(self.output_column(idx))
-    }
-
-    /// return the output column for the async function at index idx
-    pub fn output_column(&self, idx: usize) -> Arc<dyn PhysicalExpr> {
-        let async_expr = &self.async_exprs[idx];
-        let output_idx = self.num_input_columns + idx;
-        Arc::new(Column::new(async_expr.name(), output_idx))
+        Ok(transformed.data)
     }
 }
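The execution-side change, evaluating async expressions sequentially against a growing batch, can be modeled with plain `Vec<i64>` "columns". This is a hypothetical toy, not the Arrow/DataFusion API, but it shows the invariant the loop above maintains: expression `i` sees the input columns plus the `i` async results produced before it.

```rust
// Toy model of the sequential evaluation loop in AsyncFuncExec:
// each expression reads one column of the current view and appends its result.
fn main() {
    // One input column.
    let input_columns: Vec<Vec<i64>> = vec![vec![-10, 2]];

    // Two "async" expressions: abs over column 0, then abs over column 1
    // (column 1 is the output slot of the first expression).
    let async_exprs: Vec<(usize, fn(i64) -> i64)> = vec![(0, i64::abs), (1, i64::abs)];

    let mut output_arrays: Vec<Vec<i64>> = Vec::with_capacity(async_exprs.len());
    for (col_idx, f) in &async_exprs {
        // current view = input columns ++ async outputs computed so far
        let current: Vec<&Vec<i64>> =
            input_columns.iter().chain(output_arrays.iter()).collect();
        let result: Vec<i64> = current[*col_idx].iter().map(|v| f(*v)).collect();
        output_arrays.push(result);
    }

    // Final batch = input columns followed by all async results.
    assert_eq!(output_arrays[0], vec![10, 2]);
    assert_eq!(output_arrays[1], vec![10, 2]);
}
```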
diff --git a/datafusion/sqllogictest/test_files/async_udf.slt b/datafusion/sqllogictest/test_files/async_udf.slt
index 0708b59e519a0..ca8d579bb5a65 100644
--- a/datafusion/sqllogictest/test_files/async_udf.slt
+++ b/datafusion/sqllogictest/test_files/async_udf.slt
@@ -99,3 +99,11 @@ physical_plan
 01)ProjectionExec: expr=[__async_fn_0@1 as async_abs(data.x)]
 02)--AsyncFuncExec: async_expr=[async_expr(name=__async_fn_0, expr=async_abs(x@0))]
 03)----DataSourceExec: partitions=1, partition_sizes=[1]
+
+# Async udf with nesting
+query I rowsort
+select async_abs(async_abs(x)) from data;
+----
+10
+2
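Assuming the naming and indexing scheme visible in the existing EXPLAIN output above (`__async_fn_N` placed after the input columns), the nested query would plausibly unroll into two chained async expressions; a small sketch of that index arithmetic:

```rust
// Hypothetical sketch based on the plan shown above: with one input column
// x@0, async_abs(async_abs(x)) should unroll to
//   async_expr(name=__async_fn_0, expr=async_abs(x@0))
//   async_expr(name=__async_fn_1, expr=async_abs(__async_fn_0@1))
// with the final projection reading __async_fn_1@2.
fn output_column(num_input_columns: usize, idx: usize) -> (String, usize) {
    (format!("__async_fn_{idx}"), num_input_columns + idx)
}

fn main() {
    assert_eq!(output_column(1, 0), ("__async_fn_0".to_string(), 1));
    assert_eq!(output_column(1, 1), ("__async_fn_1".to_string(), 2));
}
```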