Skip to content

Exposing FFI to python #1137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/datafusion-ffi-example/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions examples/datafusion-ffi-example/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py39"]
arrow = { version = "55.0.0" }
arrow-array = { version = "55.0.0" }
arrow-schema = { version = "55.0.0" }
async-trait = "0.1.88"

[build-dependencies]
pyo3-build-config = "0.23"
Expand Down
117 changes: 117 additions & 0 deletions examples/datafusion-ffi-example/python/tests/_test_catalog_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pyarrow as pa

from datafusion import SessionContext, Table
from datafusion_ffi_example import MyCatalogProvider

from datafusion.context import PyCatalogProvider, PySchemaProvider


def test_catalog_provider():
ctx = SessionContext()

my_catalog_name = "my_catalog"
expected_schema_name = "my_schema"
expected_table_name = "my_table"
expected_table_columns = ["units", "price"]

catalog_provider = MyCatalogProvider()
ctx.register_catalog_provider(my_catalog_name, catalog_provider)
my_catalog = ctx.catalog(my_catalog_name)

my_catalog_schemas = my_catalog.names()
assert expected_schema_name in my_catalog_schemas
my_database = my_catalog.database(expected_schema_name)
assert expected_table_name in my_database.names()
my_table = my_database.table(expected_table_name)
assert expected_table_columns == my_table.schema.names

result = ctx.table(
f"{my_catalog_name}.{expected_schema_name}.{expected_table_name}"
).collect()
assert len(result) == 2

col0_result = [r.column(0) for r in result]
col1_result = [r.column(1) for r in result]
expected_col0 = [
pa.array([10, 20, 30], type=pa.int32()),
pa.array([5, 7], type=pa.int32()),
]
expected_col1 = [
pa.array([1, 2, 5], type=pa.float64()),
pa.array([1.5, 2.5], type=pa.float64()),
]
assert col0_result == expected_col0
assert col1_result == expected_col1


class MyPyCatalogProvider(PyCatalogProvider):
my_schemas = ['my_schema']

def schema_names(self) -> list[str]:
return self.my_schemas

def schema(self, name: str) -> PySchemaProvider:
return MyPySchemaProvider()


class MyPySchemaProvider(PySchemaProvider):
my_tables = ['table1', 'table2', 'table3']

def table_names(self) -> list[str]:
return self.my_tables

def table_exist(self, table_name: str) -> bool:
return table_name in self.my_tables

def table(self, table_name: str) -> Table:
raise RuntimeError(f"Can not get table: {table_name}")

def register_table(self, table: Table) -> None:
raise RuntimeError(f"Can not register {table} as table")

def deregister_table(self, table_name: str) -> None:
raise RuntimeError(f"Can not deregister table: {table_name}")


def test_python_catalog_provider():
ctx = SessionContext()

my_catalog_name = "my_py_catalog"
expected_schema_name = "my_schema"
my_py_catalog_provider = MyPyCatalogProvider()
ctx.register_catalog_provider(my_catalog_name, my_py_catalog_provider)
my_py_catalog = ctx.catalog(my_catalog_name)
assert MyPyCatalogProvider.my_schemas == my_py_catalog.names()

my_database = my_py_catalog.database(expected_schema_name)
assert set(MyPySchemaProvider.my_tables) == my_database.names()

# asserting a non-compliant provider fails at the python level as expected
try:
ctx.register_catalog_provider(my_catalog_name, "non_compliant_provider")
except TypeError:
# expect a TypeError because we can not register a str as a catalog provider
pass


if __name__ == "__main__":
test_python_catalog_provider()
174 changes: 174 additions & 0 deletions examples/datafusion-ffi-example/src/catalog_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use pyo3::{pyclass, pymethods, Bound, PyResult, Python};
use std::{any::Any, fmt::Debug, sync::Arc};

use arrow::datatypes::Schema;
use async_trait::async_trait;
use datafusion::{
catalog::{
CatalogProvider, MemoryCatalogProvider, MemorySchemaProvider, SchemaProvider, TableProvider,
},
common::exec_err,
datasource::MemTable,
error::{DataFusionError, Result},
};
use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
use pyo3::types::PyCapsule;

pub fn my_table() -> Arc<dyn TableProvider + 'static> {
use arrow::datatypes::{DataType, Field};
use datafusion::common::record_batch;

let schema = Arc::new(Schema::new(vec![
Field::new("units", DataType::Int32, true),
Field::new("price", DataType::Float64, true),
]));

let partitions = vec![
record_batch!(
("units", Int32, vec![10, 20, 30]),
("price", Float64, vec![1.0, 2.0, 5.0])
)
.unwrap(),
record_batch!(
("units", Int32, vec![5, 7]),
("price", Float64, vec![1.5, 2.5])
)
.unwrap(),
];

Arc::new(MemTable::try_new(schema, vec![partitions]).unwrap())
}

#[derive(Debug)]
pub struct FixedSchemaProvider {
inner: MemorySchemaProvider,
}

impl Default for FixedSchemaProvider {
fn default() -> Self {
let inner = MemorySchemaProvider::new();

let table = my_table();

let _ = inner.register_table("my_table".to_string(), table).unwrap();

Self { inner }
}
}

#[async_trait]
impl SchemaProvider for FixedSchemaProvider {
fn as_any(&self) -> &dyn Any {
self
}

fn table_names(&self) -> Vec<String> {
self.inner.table_names()
}

async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
self.inner.table(name).await
}

fn register_table(
&self,
name: String,
table: Arc<dyn TableProvider>,
) -> Result<Option<Arc<dyn TableProvider>>> {
self.inner.register_table(name, table)
}

fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
self.inner.deregister_table(name)
}

fn table_exist(&self, name: &str) -> bool {
self.inner.table_exist(name)
}
}

/// This catalog provider is intended only for unit tests. It prepopulates with one
/// schema and only allows for schemas named after four types of fruit.
#[pyclass(name = "MyCatalogProvider", module = "datafusion_ffi_example", subclass)]
#[derive(Debug)]
pub(crate) struct MyCatalogProvider {
inner: MemoryCatalogProvider,
}

impl Default for MyCatalogProvider {
fn default() -> Self {
let inner = MemoryCatalogProvider::new();

let schema_name: &str = "my_schema";
let _ = inner.register_schema(schema_name, Arc::new(FixedSchemaProvider::default()));

Self { inner }
}
}

impl CatalogProvider for MyCatalogProvider {
fn as_any(&self) -> &dyn Any {
self
}

fn schema_names(&self) -> Vec<String> {
self.inner.schema_names()
}

fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
self.inner.schema(name)
}

fn register_schema(
&self,
name: &str,
schema: Arc<dyn SchemaProvider>,
) -> Result<Option<Arc<dyn SchemaProvider>>> {
self.inner.register_schema(name, schema)
}

fn deregister_schema(
&self,
name: &str,
cascade: bool,
) -> Result<Option<Arc<dyn SchemaProvider>>> {
self.inner.deregister_schema(name, cascade)
}
}

#[pymethods]
impl MyCatalogProvider {
#[new]
pub fn new() -> Self {
Self {
inner: Default::default(),
}
}

pub fn __datafusion_catalog_provider__<'py>(
&self,
py: Python<'py>,
) -> PyResult<Bound<'py, PyCapsule>> {
let name = cr"datafusion_catalog_provider".into();
let catalog_provider = FFI_CatalogProvider::new(Arc::new(MyCatalogProvider::default()), None);

PyCapsule::new(py, catalog_provider, Some(name))
}
}
3 changes: 3 additions & 0 deletions examples/datafusion-ffi-example/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,19 @@
// specific language governing permissions and limitations
// under the License.

use crate::catalog_provider::MyCatalogProvider;
use crate::table_function::MyTableFunction;
use crate::table_provider::MyTableProvider;
use pyo3::prelude::*;

pub(crate) mod catalog_provider;
pub(crate) mod table_function;
pub(crate) mod table_provider;

#[pymodule]
fn datafusion_ffi_example(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<MyTableProvider>()?;
m.add_class::<MyTableFunction>()?;
m.add_class::<MyCatalogProvider>()?;
Ok(())
}
48 changes: 47 additions & 1 deletion python/datafusion/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Protocol
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable

import pyarrow as pa

Expand Down Expand Up @@ -79,6 +79,41 @@

def __datafusion_table_provider__(self) -> object: ... # noqa: D105

@runtime_checkable
class CatalogProviderExportable(Protocol):
"""Type hint for object that has __datafusion_catalog_provider__ PyCapsule.

https://docs.rs/datafusion/latest/datafusion/catalog/trait.CatalogProvider.html
"""

def __datafusion_catalog_provider__(self) -> object: ... # noqa: D105

@runtime_checkable
class PySchemaProvider(Protocol):

Check failure on line 92 in python/datafusion/context.py

View workflow job for this annotation

GitHub Actions / build

Ruff (D101)

python/datafusion/context.py:92:7: D101 Missing docstring in public class
def table_names(self) -> list[str]:

Check failure on line 93 in python/datafusion/context.py

View workflow job for this annotation

GitHub Actions / build

Ruff (D102)

python/datafusion/context.py:93:9: D102 Missing docstring in public method
...

def register_table(self, table: Table) -> None:

Check failure on line 96 in python/datafusion/context.py

View workflow job for this annotation

GitHub Actions / build

Ruff (D102)

python/datafusion/context.py:96:9: D102 Missing docstring in public method
...

def deregister_table(self, table_name: str) -> None:

Check failure on line 99 in python/datafusion/context.py

View workflow job for this annotation

GitHub Actions / build

Ruff (D102)

python/datafusion/context.py:99:9: D102 Missing docstring in public method
...

def table_exist(self, table_name: str) -> bool:

Check failure on line 102 in python/datafusion/context.py

View workflow job for this annotation

GitHub Actions / build

Ruff (D102)

python/datafusion/context.py:102:9: D102 Missing docstring in public method
...

def table(self, table_name: str) -> Table:

Check failure on line 105 in python/datafusion/context.py

View workflow job for this annotation

GitHub Actions / build

Ruff (D102)

python/datafusion/context.py:105:9: D102 Missing docstring in public method
...


@runtime_checkable
class PyCatalogProvider(Protocol):

Check failure on line 110 in python/datafusion/context.py

View workflow job for this annotation

GitHub Actions / build

Ruff (D101)

python/datafusion/context.py:110:7: D101 Missing docstring in public class
def schema_names(self) -> list[str]:

Check failure on line 111 in python/datafusion/context.py

View workflow job for this annotation

GitHub Actions / build

Ruff (D102)

python/datafusion/context.py:111:9: D102 Missing docstring in public method
...

def schema(self, name: str) -> PySchemaProvider:

Check failure on line 114 in python/datafusion/context.py

View workflow job for this annotation

GitHub Actions / build

Ruff (D102)

python/datafusion/context.py:114:9: D102 Missing docstring in public method
...


class SessionConfig:
"""Session configuration options."""
Expand Down Expand Up @@ -749,6 +784,17 @@
"""Remove a table from the session."""
self.ctx.deregister_table(name)

def register_catalog_provider(
self, name: str, provider: PyCatalogProvider | CatalogProviderExportable
) -> None:
"""Register a catalog provider."""
if not isinstance(provider, (PyCatalogProvider, CatalogProviderExportable)):
raise TypeError(
f"Expected provider to be CatalogProviderProtocol or rust version exposed through python, but got {type(provider)} instead."

Check failure on line 793 in python/datafusion/context.py

View workflow job for this annotation

GitHub Actions / build

Ruff (EM102)

python/datafusion/context.py:793:17: EM102 Exception must not use an f-string literal, assign to variable first
)

self.ctx.register_catalog_provider(name, provider)

def register_table_provider(
self, name: str, provider: TableProviderExportable
) -> None:
Expand Down
Loading
Loading