diff --git a/examples/datafusion-ffi-example/Cargo.lock b/examples/datafusion-ffi-example/Cargo.lock
index 075ebd5a..e5a1ca8d 100644
--- a/examples/datafusion-ffi-example/Cargo.lock
+++ b/examples/datafusion-ffi-example/Cargo.lock
@@ -1448,6 +1448,7 @@ dependencies = [
  "arrow",
  "arrow-array",
  "arrow-schema",
+ "async-trait",
  "datafusion",
  "datafusion-ffi",
  "pyo3",
diff --git a/examples/datafusion-ffi-example/Cargo.toml b/examples/datafusion-ffi-example/Cargo.toml
index 0e17567b..31916355 100644
--- a/examples/datafusion-ffi-example/Cargo.toml
+++ b/examples/datafusion-ffi-example/Cargo.toml
@@ -27,6 +27,7 @@ pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"]
 arrow = { version = "55.0.0" }
 arrow-array = { version = "55.0.0" }
 arrow-schema = { version = "55.0.0" }
+async-trait = "0.1.88"
 
 [build-dependencies]
 pyo3-build-config = "0.23"
diff --git a/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py
new file mode 100644
index 00000000..517b983c
--- /dev/null
+++ b/examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pyarrow as pa
+import pytest
+
+from datafusion import SessionContext, Table
+from datafusion.context import PyCatalogProvider, PySchemaProvider
+from datafusion_ffi_example import MyCatalogProvider
+
+
+def test_catalog_provider():
+    ctx = SessionContext()
+
+    my_catalog_name = "my_catalog"
+    expected_schema_name = "my_schema"
+    expected_table_name = "my_table"
+    expected_table_columns = ["units", "price"]
+
+    catalog_provider = MyCatalogProvider()
+    ctx.register_catalog_provider(my_catalog_name, catalog_provider)
+    my_catalog = ctx.catalog(my_catalog_name)
+
+    my_catalog_schemas = my_catalog.names()
+    assert expected_schema_name in my_catalog_schemas
+    my_database = my_catalog.database(expected_schema_name)
+    assert expected_table_name in my_database.names()
+    my_table = my_database.table(expected_table_name)
+    assert expected_table_columns == my_table.schema.names
+
+    result = ctx.table(
+        f"{my_catalog_name}.{expected_schema_name}.{expected_table_name}"
+    ).collect()
+    assert len(result) == 2
+
+    col0_result = [r.column(0) for r in result]
+    col1_result = [r.column(1) for r in result]
+    expected_col0 = [
+        pa.array([10, 20, 30], type=pa.int32()),
+        pa.array([5, 7], type=pa.int32()),
+    ]
+    expected_col1 = [
+        pa.array([1.0, 2.0, 5.0], type=pa.float64()),
+        pa.array([1.5, 2.5], type=pa.float64()),
+    ]
+    assert col0_result == expected_col0
+    assert col1_result == expected_col1
+
+
+class MyPyCatalogProvider(PyCatalogProvider):
+    my_schemas = ["my_schema"]
+
+    def schema_names(self) -> list[str]:
+        return self.my_schemas
+
+    def schema(self, name: str) -> PySchemaProvider:
+        return MyPySchemaProvider()
+
+
+class MyPySchemaProvider(PySchemaProvider):
+    my_tables = ["table1", "table2", "table3"]
+
+    def table_names(self) -> list[str]:
+        return self.my_tables
+
+    def table_exist(self, table_name: str) -> bool:
+        return table_name in self.my_tables
+
+    def table(self, table_name: str) -> Table:
+        raise RuntimeError(f"Cannot get table: {table_name}")
+
+    def register_table(self, table: Table) -> None:
+        raise RuntimeError(f"Cannot register {table} as a table")
+
+    def deregister_table(self, table_name: str) -> None:
+        raise RuntimeError(f"Cannot deregister table: {table_name}")
+
+
+def test_python_catalog_provider():
+    ctx = SessionContext()
+
+    my_catalog_name = "my_py_catalog"
+    expected_schema_name = "my_schema"
+    my_py_catalog_provider = MyPyCatalogProvider()
+    ctx.register_catalog_provider(my_catalog_name, my_py_catalog_provider)
+    my_py_catalog = ctx.catalog(my_catalog_name)
+    assert set(MyPyCatalogProvider.my_schemas) == set(my_py_catalog.names())
+
+    my_database = my_py_catalog.database(expected_schema_name)
+    assert set(MyPySchemaProvider.my_tables) == set(my_database.names())
+
+    # A non-compliant provider (a plain str) must be rejected at the Python level.
+    with pytest.raises(TypeError):
+        ctx.register_catalog_provider(my_catalog_name, "non_compliant_provider")
+
+
+if __name__ == "__main__":
+    test_catalog_provider()
+    test_python_catalog_provider()
diff --git a/examples/datafusion-ffi-example/src/catalog_provider.rs b/examples/datafusion-ffi-example/src/catalog_provider.rs
new file mode 100644
index 00000000..ddffe8c6
--- /dev/null
+++ b/examples/datafusion-ffi-example/src/catalog_provider.rs
@@ -0,0 +1,174 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
+use std::{any::Any, sync::Arc};
+
+use arrow::datatypes::Schema;
+use async_trait::async_trait;
+use datafusion::{
+    catalog::{
+        CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider,
+        TableProvider,
+    },
+    datasource::MemTable,
+    error::{DataFusionError, Result},
+};
+use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
+use pyo3::types::PyCapsule;
+
+pub fn my_table() -> Arc<dyn TableProvider + 'static> {
+    use arrow::datatypes::{DataType, Field};
+    use datafusion::common::record_batch;
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("units", DataType::Int32, true),
+        Field::new("price", DataType::Float64, true),
+    ]));
+
+    let partitions = vec![
+        record_batch!(
+            ("units", Int32, vec![10, 20, 30]),
+            ("price", Float64, vec![1.0, 2.0, 5.0])
+        )
+        .unwrap(),
+        record_batch!(
+            ("units", Int32, vec![5, 7]),
+            ("price", Float64, vec![1.5, 2.5])
+        )
+        .unwrap(),
+    ];
+
+    Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap())
+}
+
+#[derive(Debug)]
+pub struct FixedSchemaProvider {
+    inner: MemorySchemaProvider,
+}
+
+impl Default for FixedSchemaProvider {
+    fn default() -> Self {
+        let inner = MemorySchemaProvider::new();
+
+        let table = my_table();
+
+        let _ = inner
+            .register_table("my_table".to_string(), table)
+            .unwrap();
+
+        Self { inner }
+    }
+}
+
+#[async_trait]
+impl SchemaProvider for FixedSchemaProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn table_names(&self) -> Vec<String> {
+        self.inner.table_names()
+    }
+
+    async fn table(
+        &self,
+        name: &str,
+    ) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
+        self.inner.table(name).await
+    }
+
+    fn register_table(
+        &self,
+        name: String,
+        table: Arc<dyn TableProvider>,
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
+        self.inner.register_table(name, table)
+    }
+
+    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
+        self.inner.deregister_table(name)
+    }
+
+    fn table_exist(&self, name: &str) -> bool {
+        self.inner.table_exist(name)
+    }
+}
+
+/// This catalog provider is intended only for unit tests. It prepopulates a single
+/// schema, `my_schema`, containing one in-memory table, `my_table`.
+#[pyclass(name = "MyCatalogProvider", module = "datafusion_ffi_example", subclass)] +#[derive(Debug)] +pub(crate) struct MyCatalogProvider { + inner: MemoryCatalogProvider, +} + +impl Default for MyCatalogProvider { + fn default() -> Self { + let inner = MemoryCatalogProvider::new(); + + let schema_name: &str = "my_schema"; + let _ = inner.register_schema(schema_name, Arc::new(FixedSchemaProvider::default())); + + Self { inner } + } +} + +impl CatalogProvider for MyCatalogProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema_names(&self) -> Vec { + self.inner.schema_names() + } + + fn schema(&self, name: &str) -> Option> { + self.inner.schema(name) + } + + fn register_schema( + &self, + name: &str, + schema: Arc, + ) -> Result>> { + self.inner.register_schema(name, schema) + } + + fn deregister_schema( + &self, + name: &str, + cascade: bool, + ) -> Result>> { + self.inner.deregister_schema(name, cascade) + } +} + +#[pymethods] +impl MyCatalogProvider { + #[new] + pub fn new() -> Self { + Self { + inner: Default::default(), + } + } + + pub fn __datafusion_catalog_provider__<'py>( + &self, + py: Python<'py>, + ) -> PyResult> { + let name = cr"datafusion_catalog_provider".into(); + let catalog_provider = FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None); + + PyCapsule::new(py, catalog_provider, Some(name)) + } +} diff --git a/examples/datafusion-ffi-example/src/lib.rs b/examples/datafusion-ffi-example/src/lib.rs index ae08c3b6..3a4cf224 100644 --- a/examples/datafusion-ffi-example/src/lib.rs +++ b/examples/datafusion-ffi-example/src/lib.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::catalog_provider::MyCatalogProvider; use crate::table_function::MyTableFunction; use crate::table_provider::MyTableProvider; use pyo3::prelude::*; +pub(crate) mod catalog_provider; pub(crate) mod table_function; pub(crate) mod table_provider; @@ -26,5 +28,6 @@ pub(crate) mod table_provider; fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 5b99b0d2..15ae5488 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -20,7 +20,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable import pyarrow as pa @@ -79,6 +79,41 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self) -> object: ... # noqa: D105 +@runtime_checkable +class CatalogProviderExportable(Protocol): + """Type hint for object that has __datafusion_catalog_provider__ PyCapsule. + + https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html + """ + + def __datafusion_catalog_provider__(self) -> object: ... # noqa: D105 + +@runtime_checkable +class PySchemaProvider(Protocol): + def table_names(self) -> list[str]: + ... + + def register_table(self, table: Table) -> None: + ... + + def deregister_table(self, table_name: str) -> None: + ... + + def table_exist(self, table_name: str) -> bool: + ... + + def table(self, table_name: str) -> Table: + ... + + +@runtime_checkable +class PyCatalogProvider(Protocol): + def schema_names(self) -> list[str]: + ... + + def schema(self, name: str) -> PySchemaProvider: + ... 
 class SessionConfig:
     """Session configuration options."""
 
@@ -749,6 +784,17 @@ def deregister_table(self, name: str) -> None:
         """Remove a table from the session."""
         self.ctx.deregister_table(name)
 
+    def register_catalog_provider(
+        self, name: str, provider: PyCatalogProvider | CatalogProviderExportable
+    ) -> None:
+        """Register a catalog provider."""
+        if not isinstance(provider, (PyCatalogProvider, CatalogProviderExportable)):
+            msg = (
+                "Expected provider to implement PyCatalogProvider or "
+                f"CatalogProviderExportable, but got {type(provider)} instead."
+            )
+            raise TypeError(msg)
+
+        self.ctx.register_catalog_provider(name, provider)
+
     def register_table_provider(
         self, name: str, provider: TableProviderExportable
     ) -> None:
diff --git a/src/catalog.rs b/src/catalog.rs
index 83f8d08c..1aa94430 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -15,9 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::any::Any;
 use std::collections::HashSet;
 use std::sync::Arc;
-
+use async_trait::async_trait;
 use pyo3::exceptions::PyKeyError;
 use pyo3::prelude::*;
 
@@ -28,6 +29,8 @@ use datafusion::{
     catalog::{CatalogProvider, SchemaProvider},
     datasource::{TableProvider, TableType},
 };
+use datafusion::common::DataFusionError;
+use pyo3::Py;
 
 #[pyclass(name = "Catalog", module = "datafusion", subclass)]
 pub struct PyCatalog {
@@ -44,6 +47,30 @@ pub struct PyTable {
     pub table: Arc<dyn TableProvider>,
 }
 
+#[derive(Debug)]
+#[pyclass(name = "CatalogProvider", module = "datafusion", subclass)]
+pub struct PyCatalogProvider {
+    py_obj: Py<PyAny>,
+}
+
+#[derive(Debug)]
+#[pyclass(name = "SchemaProvider", module = "datafusion", subclass)]
+pub struct PySchemaProvider {
+    py_obj: Py<PyAny>,
+}
+
+impl PyCatalogProvider {
+    pub fn new(py_obj: Py<PyAny>) -> Self {
+        Self { py_obj }
+    }
+}
+
+impl PySchemaProvider {
+    pub fn new(py_obj: Py<PyAny>) -> Self {
+        Self { py_obj }
+    }
+}
+
 impl PyCatalog {
     pub fn new(catalog: Arc<dyn CatalogProvider>) -> Self {
         Self { catalog }
@@ -145,3 +172,120 @@ impl PyTable {
     // fn has_exact_statistics
     // fn supports_filter_pushdown
 }
+
+#[async_trait]
+impl SchemaProvider for PySchemaProvider {
+    fn owner_name(&self) -> Option<&str> {
+        // TODO: expose the owner name reported by the Python object; returning a
+        // borrowed &str that originates in Python is awkward with PyO3.
+        None
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn table_names(&self) -> Vec<String> {
+        Python::with_gil(|py| {
+            let obj = self.py_obj.bind_borrowed(py);
+            obj.call_method0("table_names")
+                .and_then(|res| res.extract::<Vec<String>>())
+                .unwrap_or_else(|err| {
+                    eprintln!("Error calling table_names: {}", err);
+                    vec![]
+                })
+        })
+    }
+
+    async fn table(
+        &self,
+        _name: &str,
+    ) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
+        Err(DataFusionError::NotImplemented(
+            "Python SchemaProvider does not support `table` yet".to_string(),
+        ))
+    }
+
+    fn table_exist(&self, table_name: &str) -> bool {
+        Python::with_gil(|py| {
+            let obj = self.py_obj.bind_borrowed(py);
+            obj.call_method1("table_exist", (table_name,))
+                .and_then(|res| res.extract::<bool>())
+                .unwrap_or_else(|err| {
+                    eprintln!("Error calling table_exist: {}", err);
+                    false
+                })
+        })
+    }
+
+    fn register_table(
+        &self,
+        _name: String,
+        _table: Arc<dyn TableProvider>,
+    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
+        Err(DataFusionError::NotImplemented(
+            "Python SchemaProvider does not support `register_table`".to_string(),
+        ))
+    }
+
+    fn deregister_table(
+        &self,
+        _name: &str,
+    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
+        Err(DataFusionError::NotImplemented(
+            "Python SchemaProvider does not support `deregister_table`".to_string(),
+        ))
+    }
+}
+
+impl CatalogProvider for PyCatalogProvider {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema_names(&self) -> Vec<String> {
+        Python::with_gil(|py| {
+            let obj = self.py_obj.bind_borrowed(py);
+            obj.call_method0("schema_names")
+                .and_then(|res| res.extract::<Vec<String>>())
+                .unwrap_or_else(|err| {
+                    eprintln!("Error calling schema_names: {}", err);
+                    vec![]
+                })
+        })
+    }
+
+    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
+        Python::with_gil(|py| {
+            let obj = self.py_obj.bind_borrowed(py);
+            match obj.call_method1("schema", (name,)) {
+                Ok(py_schema) => {
+                    // Wrap the returned Python object so DataFusion can drive it
+                    // through the SchemaProvider trait.
+                    let schema_provider: PyResult<Py<PyAny>> = py_schema.extract();
+                    match schema_provider {
+                        Ok(py_obj) => {
+                            let rust_provider =
+                                Arc::new(PySchemaProvider { py_obj }) as Arc<dyn SchemaProvider>;
+                            Some(rust_provider)
+                        }
+                        Err(err) => {
+                            eprintln!("Failed to extract schema provider: {err}");
+                            None
+                        }
+                    }
+                }
+                Err(err) => {
+                    eprintln!("Error calling schema('{}'): {}", name, err);
+                    None
+                }
+            }
+        })
+    }
+
+    fn register_schema(
+        &self,
+        _name: &str,
+        _schema: Arc<dyn SchemaProvider>,
+    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
+        Err(DataFusionError::NotImplemented(
+            "Python CatalogProvider does not support `register_schema`".to_string(),
+        ))
+    }
+
+    fn deregister_schema(
+        &self,
+        _name: &str,
+        _cascade: bool,
+    ) -> datafusion::common::Result<Option<Arc<dyn SchemaProvider>>> {
+        Err(DataFusionError::NotImplemented(
+            "Python CatalogProvider does not support `deregister_schema`".to_string(),
+        ))
+    }
+}
diff --git a/src/context.rs b/src/context.rs
index 6ce1f12b..e4411483 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -31,7 +31,7 @@ use uuid::Uuid;
 use pyo3::exceptions::{PyKeyError, PyValueError};
 use pyo3::prelude::*;
 
-use crate::catalog::{PyCatalog, PyTable};
+use crate::catalog::{PyCatalog, PyCatalogProvider, PyTable};
 use crate::dataframe::PyDataFrame;
 use crate::dataset::Dataset;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
@@ -49,6 +49,7 @@ use crate::utils::{get_global_ctx, get_tokio_runtime, validate_pycapsule, wait_f
 use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::catalog::CatalogProvider;
 use datafusion::common::TableReference;
 use datafusion::common::{exec_err, ScalarValue};
 use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
@@ -70,6 +71,7 @@ use datafusion::prelude::{
     AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
 };
 use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
+use datafusion_ffi::catalog_provider::{FFI_CatalogProvider, ForeignCatalogProvider};
 use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
 use tokio::task::JoinHandle;
 
@@ -614,6 +616,46 @@ impl PySessionContext {
         Ok(())
     }
 
+    pub fn register_catalog_provider(
+        &mut self,
+        name: &str,
+        provider: Bound<'_, PyAny>,
+    ) -> PyDataFusionResult<()> {
+        if provider.hasattr("__datafusion_catalog_provider__")? {
+            let capsule = provider
+                .getattr("__datafusion_catalog_provider__")?
+                .call0()?;
+            let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
+            validate_pycapsule(capsule, "datafusion_catalog_provider")?;
+
+            let provider = unsafe { capsule.reference::<FFI_CatalogProvider>() };
+            let provider: ForeignCatalogProvider = provider.into();
+
+            let option: Option<Arc<dyn CatalogProvider>> =
+                self.ctx.register_catalog(name, Arc::new(provider));
+            match option {
+                Some(existing) => {
+                    println!(
+                        "Catalog '{}' already existed, schema names: {:?}",
+                        name,
+                        existing.schema_names()
+                    );
+                }
+                None => {
+                    println!("Catalog '{}' registered successfully", name);
+                }
+            }
+
+            Ok(())
+        } else {
+            let python_provider = PyCatalogProvider::new(provider.into());
+            let arc_provider = Arc::new(python_provider);
+            let option: Option<Arc<dyn CatalogProvider>> =
+                self.ctx.register_catalog(name, arc_provider);
+            match option {
+                Some(existing) => {
+                    println!(
+                        "Catalog '{}' already existed in python catalog, schema names: {:?}",
+                        name,
+                        existing.schema_names()
+                    );
+                }
+                None => {
+                    println!(
+                        "Catalog '{}' registered successfully from python catalog",
+                        name
+                    );
+                }
+            }
+            Ok(())
+        }
+    }
+
     /// Construct datafusion dataframe from Arrow Table
     pub fn register_table_provider(
         &mut self,
diff --git a/src/lib.rs b/src/lib.rs
index 1293eee3..cc09f9ef 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -97,6 +97,7 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<catalog::PyCatalog>()?;
     m.add_class::<catalog::PyDatabase>()?;
     m.add_class::<catalog::PyTable>()?;
+    m.add_class::<catalog::PyCatalogProvider>()?;
     // Register `common` as a submodule. Matching `datafusion-common` https://docs.rs/datafusion-common/latest/datafusion_common/
     let common = PyModule::new(py, "common")?;