Skip to content

Commit 97f7c2c

Browse files
committed
feat: refactor TableProvider integration and add new TableProvider class
1 parent 3b95a1d commit 97f7c2c

4 files changed

Lines changed: 108 additions & 8 deletions

File tree

python/datafusion/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from . import functions, object_store, substrait, unparser
3434

3535
# The following imports are okay to remain as opaque to the user.
36-
from ._internal import Config, TableProvider
36+
from ._internal import Config
3737
from .catalog import Catalog, Database, Table
3838
from .col import col, column
3939
from .common import (
@@ -54,6 +54,7 @@
5454
from .io import read_avro, read_csv, read_json, read_parquet
5555
from .plan import ExecutionPlan, LogicalPlan
5656
from .record_batch import RecordBatch, RecordBatchStream
57+
from .table_provider import TableProvider
5758
from .user_defined import (
5859
Accumulator,
5960
AggregateUDF,

python/datafusion/dataframe.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@
5252
import polars as pl
5353
import pyarrow as pa
5454

55-
from datafusion._internal import TableProvider
5655
from datafusion._internal import expr as expr_internal
56+
from datafusion.table_provider import TableProvider
5757

5858
from enum import Enum
5959

@@ -316,7 +316,9 @@ def into_view(self) -> TableProvider:
316316
``TableProvider.from_dataframe`` calls this method under the hood,
317317
and the older ``TableProvider.from_view`` helper is deprecated.
318318
"""
319-
return self.df.into_view()
319+
from datafusion.table_provider import TableProvider as _TableProvider
320+
321+
return _TableProvider(self.df.into_view())
320322

321323
def __getitem__(self, key: str | list[str]) -> DataFrame:
322324
"""Return a new :py:class:`DataFrame` with the specified column or columns.
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Wrapper helpers for :class:`datafusion._internal.TableProvider`."""
18+
19+
from __future__ import annotations
20+
21+
from typing import Any
22+
23+
import datafusion._internal as df_internal
24+
25+
_InternalTableProvider = df_internal.TableProvider
26+
27+
28+
class TableProvider:
    """High level wrapper around :class:`datafusion._internal.TableProvider`.

    The wrapper stores exactly one low level provider and forwards every
    attribute it does not define itself to that provider, so it can be used
    wherever the low level object's attributes are expected.
    """

    __slots__ = ("_table_provider",)

    def __init__(self, table_provider: _InternalTableProvider | TableProvider) -> None:
        """Wrap a low level :class:`~datafusion._internal.TableProvider`.

        Args:
            table_provider: The low level provider to wrap. An existing
                wrapper is also accepted; its inner provider is reused so
                wrappers never nest.

        Raises:
            TypeError: If ``table_provider`` is neither a wrapper nor a low
                level provider instance.
        """
        if isinstance(table_provider, TableProvider):
            table_provider = table_provider._table_provider

        if not isinstance(table_provider, _InternalTableProvider):
            msg = "Expected a datafusion._internal.TableProvider instance."
            raise TypeError(msg)

        self._table_provider = table_provider

    # ------------------------------------------------------------------
    # constructors
    # ------------------------------------------------------------------
    @classmethod
    def _wrap(cls, provider: _InternalTableProvider | TableProvider) -> TableProvider:
        """Return ``provider`` unchanged if already a ``cls``, else wrap it."""
        if isinstance(provider, cls):
            return provider
        return cls(provider)

    @staticmethod
    def _unwrap_dataframe(df: Any) -> Any:
        """Return the inner ``.df`` of a high level DataFrame, else ``df`` itself."""
        # Imported lazily: ``datafusion.dataframe`` imports this module at
        # import time, so a top-level import here would be circular.
        from datafusion.dataframe import DataFrame as DataFrameWrapper

        if isinstance(df, DataFrameWrapper):
            return df.df
        return df

    @classmethod
    def from_capsule(cls, capsule: Any) -> TableProvider:
        """Create a :class:`TableProvider` from a PyCapsule."""
        provider = _InternalTableProvider.from_capsule(capsule)
        return cls(provider)

    @classmethod
    def from_dataframe(cls, df: Any) -> TableProvider:
        """Create a :class:`TableProvider` from a :class:`DataFrame`.

        Both the high level ``datafusion.dataframe.DataFrame`` and the object
        stored in its ``df`` attribute are accepted.
        """
        provider = _InternalTableProvider.from_dataframe(cls._unwrap_dataframe(df))
        return cls(provider)

    @classmethod
    def from_view(cls, df: Any) -> TableProvider:
        """Create a :class:`TableProvider` from a DataFrame view.

        ``DataFrame.into_view``'s documentation describes this helper as
        deprecated in favour of :meth:`from_dataframe`.
        """
        provider = _InternalTableProvider.from_view(cls._unwrap_dataframe(df))
        return cls(provider)

    # ------------------------------------------------------------------
    # passthrough helpers
    # ------------------------------------------------------------------
    def __getattr__(self, name: str) -> Any:
        """Delegate attribute lookup to the wrapped provider.

        Raises:
            AttributeError: If the wrapped provider lacks ``name``, or if the
                instance has no wrapped provider yet.
        """
        # ``__getattr__`` only runs after normal lookup has failed. Being asked
        # for the slot itself means it is unset (instance built via
        # ``__new__``, e.g. by copy/pickle); reading it below would re-enter
        # this method until RecursionError, so fail with AttributeError here.
        if name == "_table_provider":
            msg = f"{type(self).__name__!r} object has no attribute {name!r}"
            raise AttributeError(msg)
        return getattr(self._table_provider, name)

    def __repr__(self) -> str:  # pragma: no cover - simple delegation
        """Return a representation of the wrapped provider."""
        return repr(self._table_provider)

    def __datafusion_table_provider__(self) -> Any:
        """Expose the wrapped provider for FFI integrations."""
        return self._table_provider.__datafusion_table_provider__()
95+
96+
97+
__all__ = ["TableProvider"]

python/tests/test_context.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def test_table_provider_from_capsule(ctx):
350350

351351

352352
def test_table_provider_from_dataframe(ctx):
353-
df = ctx.from_pydict({"a": [1, 2]}).df
353+
df = ctx.from_pydict({"a": [1, 2]})
354354
provider = TableProvider.from_dataframe(df)
355355
ctx.register_table("from_dataframe_tbl", provider)
356356
result = ctx.sql("SELECT * FROM from_dataframe_tbl").collect()
@@ -380,12 +380,12 @@ def test_table_provider_from_capsule_invalid():
380380

381381
def test_register_table_with_dataframe_errors(ctx):
382382
df = ctx.from_pydict({"a": [1]})
383-
with pytest.raises(Exception) as exc_info: # noqa: B017
383+
with pytest.raises(Exception) as exc_info:
384384
ctx.register_table("bad", df)
385385

386-
assert (
387-
str(exc_info.value)
388-
== 'Expected a Table or TableProvider. Convert DataFrames with "DataFrame.into_view()" or "TableProvider.from_dataframe()".'
386+
assert str(exc_info.value) == (
387+
"Expected a Table or TableProvider. Convert DataFrames with "
388+
'"DataFrame.into_view()" or "TableProvider.from_dataframe()".'
389389
)
390390

391391

0 commit comments

Comments
 (0)