diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 642afeef7..3a3adb3e7 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -59,7 +59,8 @@ use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_python_util::{ create_logical_extension_capsule, create_physical_extension_capsule, ffi_logical_codec_from_pycapsule, get_global_ctx, get_tokio_runtime, - physical_codec_from_pycapsule, spawn_future, wait_for_future, + physical_codec_from_pycapsule, physical_optimizer_rule_from_pycapsule, spawn_future, + wait_for_future, }; use object_store::ObjectStore; use pyo3::IntoPyObjectExt; @@ -375,11 +376,12 @@ pub struct PySessionContext { #[pymethods] impl PySessionContext { - #[pyo3(signature = (config=None, runtime=None))] + #[pyo3(signature = (config=None, runtime=None, physical_optimizer_rules=None))] #[new] pub fn new( config: Option, runtime: Option, + physical_optimizer_rules: Option>>, ) -> PyDataFusionResult { let config = if let Some(c) = config { c.config @@ -392,11 +394,15 @@ impl PySessionContext { RuntimeEnvBuilder::default() }; let runtime = Arc::new(runtime_env_builder.build()?); - let session_state = SessionStateBuilder::new() + let mut state_builder = SessionStateBuilder::new() .with_config(config) .with_runtime_env(runtime) - .with_default_features() - .build(); + .with_default_features(); + for rule in physical_optimizer_rules.unwrap_or_default() { + let rule = physical_optimizer_rule_from_pycapsule(&rule)?; + state_builder = state_builder.with_physical_optimizer_rule(rule); + } + let session_state = state_builder.build(); let ctx = Arc::new(SessionContext::new_with_state(session_state)); Ok(PySessionContext { ctx, diff --git a/crates/util/src/lib.rs b/crates/util/src/lib.rs index 72dc9aafc..28c8834e9 100644 --- a/crates/util/src/lib.rs +++ b/crates/util/src/lib.rs @@ -24,7 +24,9 @@ use datafusion::datasource::TableProvider; use datafusion::execution::TaskContext; use datafusion::execution::context::SessionContext; use datafusion::logical_expr::Volatility; +use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion_ffi::execution::FFI_TaskContextProvider; +use datafusion_ffi::physical_optimizer::FFI_PhysicalOptimizerRule; use datafusion_ffi::proto::logical_extension_codec::FFI_LogicalExtensionCodec; use datafusion_ffi::proto::physical_extension_codec::FFI_PhysicalExtensionCodec; use datafusion_ffi::table_provider::FFI_TableProvider; @@ -332,6 +334,13 @@ from_pycapsule!( dyn PhysicalExtensionCodec ); +from_pycapsule!( + physical_optimizer_rule_from_pycapsule, + "datafusion_physical_optimizer_rule", + FFI_PhysicalOptimizerRule, + dyn PhysicalOptimizerRule + Send + Sync +); + try_from_pycapsule!( task_context_from_pycapsule, "datafusion_task_context_provider", diff --git a/examples/datafusion-ffi-example/python/tests/_test_physical_optimizer_rule.py b/examples/datafusion-ffi-example/python/tests/_test_physical_optimizer_rule.py new file mode 100644 index 000000000..1eee07dcb --- /dev/null +++ b/examples/datafusion-ffi-example/python/tests/_test_physical_optimizer_rule.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext +from datafusion_ffi_example import MyPhysicalOptimizerRule + + +def test_ffi_physical_optimizer_rule_runs_during_planning(): + """A rule supplied via physical_optimizer_rules is invoked while the + physical plan is built, and the query still returns correct results.""" + rule = MyPhysicalOptimizerRule() + ctx = SessionContext(physical_optimizer_rules=[rule]) + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3])], + names=["a"], + ) + ctx.register_record_batches("t", [[batch]]) + + before = rule.optimize_calls() + result = ctx.sql("SELECT a FROM t").collect() + after = rule.optimize_calls() + + assert after > before, ( + f"Expected user FFI physical optimizer rule to fire, " + f"before={before} after={after}" + ) + assert result[0].column(0).to_pylist() == [1, 2, 3] + + +def test_ffi_physical_optimizer_rule_export(): + """The rule object exposes the FFI capsule entry point.""" + rule = MyPhysicalOptimizerRule() + capsule = rule.__datafusion_physical_optimizer_rule__() + assert capsule is not None diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index 3323ac982..eccf7b81a 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -22,6 +22,7 @@ use crate::catalog_provider::{FixedSchemaProvider, MyCatalogProvider, MyCatalogP use crate::config::MyConfig; use crate::logical_extension_codec::MyLogicalExtensionCodec; use crate::physical_extension_codec::MyPhysicalExtensionCodec; +use crate::physical_optimizer::MyPhysicalOptimizerRule; use crate::scalar_udf::IsNullUDF; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; @@ -33,6 +34,7 @@ pub(crate) mod catalog_provider; pub(crate) mod config; pub(crate) mod logical_extension_codec; pub(crate) mod physical_extension_codec; +pub(crate) mod physical_optimizer; pub(crate) mod scalar_udf; pub(crate) mod table_function; pub(crate) mod table_provider; @@ -55,5 +57,6 @@ fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/examples/datafusion-ffi-example/src/physical_optimizer.rs b/examples/datafusion-ffi-example/src/physical_optimizer.rs new file mode 100644 index 000000000..0acd1bb4a --- /dev/null +++ b/examples/datafusion-ffi-example/src/physical_optimizer.rs @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use datafusion::common::Result; +use datafusion::common::config::ConfigOptions; +use datafusion::physical_optimizer::PhysicalOptimizerRule; +use datafusion::physical_plan::ExecutionPlan; +use datafusion_ffi::physical_optimizer::FFI_PhysicalOptimizerRule; +use datafusion_python_util::get_tokio_runtime; +use pyo3::prelude::*; +use pyo3::types::PyCapsule; + +/// A physical optimizer rule that leaves every plan unchanged but bumps a +/// shared counter each time it runs. Tests use the counter to prove that a +/// session built with this rule actually routed physical planning through a +/// user-supplied [`PhysicalOptimizerRule`] over FFI. +#[derive(Debug)] +struct CountingPhysicalOptimizerRule { + optimize_calls: Arc, +} + +impl PhysicalOptimizerRule for CountingPhysicalOptimizerRule { + fn optimize( + &self, + plan: Arc, + _config: &ConfigOptions, + ) -> Result> { + self.optimize_calls.fetch_add(1, Ordering::SeqCst); + Ok(plan) + } + + fn name(&self) -> &str { + "counting_physical_optimizer_rule" + } + + fn schema_check(&self) -> bool { + // The plan is returned unchanged, so the schema is preserved. + true + } +} + +/// Python-visible handle that produces an [`FFI_PhysicalOptimizerRule`] and +/// exposes the shared call counter. +#[pyclass( + from_py_object, + name = "MyPhysicalOptimizerRule", + module = "datafusion_ffi_example", + subclass +)] +#[derive(Debug, Default, Clone)] +pub(crate) struct MyPhysicalOptimizerRule { + optimize_calls: Arc, +} + +#[pymethods] +impl MyPhysicalOptimizerRule { + #[new] + fn new() -> Self { + Self::default() + } + + fn optimize_calls(&self) -> usize { + self.optimize_calls.load(Ordering::SeqCst) + } + + fn __datafusion_physical_optimizer_rule__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let rule: Arc = + Arc::new(CountingPhysicalOptimizerRule { + optimize_calls: Arc::clone(&self.optimize_calls), + }); + + let runtime = get_tokio_runtime().handle().clone(); + let ffi = FFI_PhysicalOptimizerRule::new(rule, Some(runtime)); + + let name = cr"datafusion_physical_optimizer_rule".into(); + PyCapsule::new(py, ffi, Some(name)) + } +} diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 5c3501941..f8fb016d7 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -130,6 +130,16 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self, session: Any) -> object: ... # noqa: D105 +class PhysicalOptimizerRuleExportable(Protocol): + """Type hint for object that has __datafusion_physical_optimizer_rule__ PyCapsule. + + The method returns a PyCapsule wrapping an ``FFI_PhysicalOptimizerRule``, + typically produced by a separate compiled extension. + """ + + def __datafusion_physical_optimizer_rule__(self) -> object: ... # noqa: D105 + + class SessionConfig: """Session configuration options.""" @@ -524,6 +534,7 @@ def __init__( self, config: SessionConfig | None = None, runtime: RuntimeEnvBuilder | None = None, + physical_optimizer_rules: list[PhysicalOptimizerRuleExportable] | None = None, ) -> None: """Main interface for executing queries with DataFusion. @@ -534,6 +545,11 @@ def __init__( Args: config: Session configuration options. runtime: Runtime configuration options. + physical_optimizer_rules: User-defined physical optimizer rules to + append to the default set, each a + :class:`PhysicalOptimizerRuleExportable`. There is no upstream + API to add physical rules to a live context, so these can only + be supplied at construction time. Example usage: @@ -544,11 +560,21 @@ def __init__( ctx = SessionContext() df = ctx.read_csv("data.csv") + + To register a physical optimizer rule supplied by a compiled + extension, pass it via ``physical_optimizer_rules``:: + + from datafusion import SessionContext + from my_extension import MyPhysicalOptimizerRule + + ctx = SessionContext( + physical_optimizer_rules=[MyPhysicalOptimizerRule()] + ) """ config = config.config_internal if config is not None else None runtime = runtime.config_internal if runtime is not None else None - self.ctx = SessionContextInternal(config, runtime) + self.ctx = SessionContextInternal(config, runtime, physical_optimizer_rules) def __repr__(self) -> str: """Print a string representation of the Session Context."""