Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 25 additions & 31 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2686,47 +2686,41 @@ impl DefaultPhysicalPlanner {
schema: &Schema,
) -> Result<PlanAsyncExpr> {
let mut async_map = AsyncMapper::new(num_input_columns);
match &physical_expr {
let new_physical_expr = match physical_expr {
PlannedExprResult::ExprWithName(exprs) => {
exprs
let new_exprs = exprs
.iter()
.try_for_each(|(expr, _)| async_map.find_references(expr, schema))?;
.map(|(expr, name)| {
// find_and_map will:
// 1. Identify nested async UDFs bottom-up
// 2. Rewrite them to Columns referencing the async_map
// 3. Return the fully rewritten expression
let new_expr = async_map.find_and_map(expr, schema)?;
Ok((new_expr, name.clone()))
})
.collect::<Result<Vec<_>>>()?;
PlannedExprResult::ExprWithName(new_exprs)
}
PlannedExprResult::Expr(exprs) => {
exprs
let new_exprs = exprs
.iter()
.try_for_each(|expr| async_map.find_references(expr, schema))?;
.map(|expr| {
let new_expr = async_map.find_and_map(expr, schema)?;
Ok(new_expr)
})
.collect::<Result<Vec<_>>>()?;
PlannedExprResult::Expr(new_exprs)
}
}
};

if async_map.is_empty() {
return Ok(PlanAsyncExpr::Sync(physical_expr));
// If no async exprs found, result is the same as input
// (though find_and_map returns clones, structural equality holds)
return Ok(PlanAsyncExpr::Sync(new_physical_expr));
}

let new_exprs = match physical_expr {
PlannedExprResult::ExprWithName(exprs) => PlannedExprResult::ExprWithName(
exprs
.iter()
.map(|(expr, column_name)| {
let new_expr = Arc::clone(expr)
.transform_up(|e| Ok(async_map.map_expr(e)))?;
Ok((new_expr.data, column_name.to_string()))
})
.collect::<Result<_>>()?,
),
PlannedExprResult::Expr(exprs) => PlannedExprResult::Expr(
exprs
.iter()
.map(|expr| {
let new_expr = Arc::clone(expr)
.transform_up(|e| Ok(async_map.map_expr(e)))?;
Ok(new_expr.data)
})
.collect::<Result<_>>()?,
),
};
// rewrite the projection's expressions in terms of the columns with the result of async evaluation
Ok(PlanAsyncExpr::Async(async_map, new_exprs))
// Pass the rewritten expressions
Ok(PlanAsyncExpr::Async(async_map, new_physical_expr))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use std::sync::Arc;
use arrow::array::{Int32Array, RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use async_trait::async_trait;
use datafusion::prelude::*;
use datafusion::dataframe::DataFrame;
use datafusion::execution::context::SessionContext;
use datafusion_common::Result;
use datafusion_common::test_util::format_batches;
use datafusion_common::{Result, assert_batches_eq};
use datafusion_expr::async_udf::{AsyncScalarUDF, AsyncScalarUDFImpl};
use datafusion_expr::{
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
Expand Down Expand Up @@ -66,24 +67,24 @@ fn register_table_and_udf() -> Result<SessionContext> {
async fn test_async_udf_with_non_modular_batch_size() -> Result<()> {
let ctx = register_table_and_udf()?;

let df = ctx
let df: DataFrame = ctx
.sql("SELECT id, test_async_udf(prompt) as result FROM test_table")
.await?;

let result = df.collect().await?;

assert_batches_eq!(
&[
"+----+---------+",
"| id | result |",
"+----+---------+",
"| 0 | prompt0 |",
"| 1 | prompt1 |",
"| 2 | prompt2 |",
"+----+---------+"
],
&result
);
let result: Vec<RecordBatch> = df.collect().await?;

let result_str = format_batches(&result)?.to_string();
let expected = [
"+----+---------+",
"| id | result |",
"+----+---------+",
"| 0 | prompt0 |",
"| 1 | prompt1 |",
"| 2 | prompt2 |",
"+----+---------+",
]
.join("\n");
assert_eq!(result_str.trim(), expected.trim());

Ok(())
}
Expand All @@ -93,13 +94,13 @@ async fn test_async_udf_with_non_modular_batch_size() -> Result<()> {
async fn test_async_udf_metrics() -> Result<()> {
let ctx = register_table_and_udf()?;

let df = ctx
let df: DataFrame = ctx
.sql(
"EXPLAIN ANALYZE SELECT id, test_async_udf(prompt) as result FROM test_table",
)
.await?;

let result = df.collect().await?;
let result: Vec<RecordBatch> = df.collect().await?;

let explain_analyze_str = format_batches(&result)?.to_string();
let async_func_exec_without_metrics =
Expand All @@ -113,6 +114,43 @@ async fn test_async_udf_metrics() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_nested_async_udf() -> Result<()> {
let ctx = register_table_and_udf()?;

let df: DataFrame = ctx
.sql(
"SELECT id, test_async_udf(test_async_udf(prompt)) as result FROM test_table",
)
.await?;

let result: Result<Vec<RecordBatch>> = df.collect().await;

// This is expected to succeed now
match &result {
Ok(batches) => {
// Check results
let result_str = format_batches(batches)?.to_string();
let expected = [
"+----+---------+",
"| id | result |",
"+----+---------+",
"| 0 | prompt0 |",
"| 1 | prompt1 |",
"| 2 | prompt2 |",
"+----+---------+",
]
.join("\n");
assert_eq!(result_str.trim(), expected.trim());
}
Err(e) => {
panic!("Nested async UDF failed: {e}");
}
}

Ok(())
}

#[derive(Debug, PartialEq, Eq, Hash, Clone)]
struct TestAsyncUDFImpl {
batch_size: usize,
Expand Down
Loading