Commit a4a4208

Unify Arrow stream scanning via __arrow_c_stream__ (#307)
## Problem

Related: #70

DuckDB's Python client had separate code paths for every Arrow-flavored object type: PyArrow Table, RecordBatchReader, Scanner, Dataset, PyCapsule, and PyCapsuleInterface. Many of these did the same thing through different routes: materialize to a PyArrow Table, then scan it. This made the codebase harder to extend, and objects implementing the PyCapsule Interface (`__arrow_c_stream__`) couldn't get projection/filter pushdown unless pyarrow.dataset was installed.

## Approach

The core design decision is to prefer `__arrow_c_stream__` as the universal entry point rather than maintaining `isinstance` checks for PyArrow Table and RecordBatchReader. Both types implement `__arrow_c_stream__`, so they don't need dedicated branches; they fall through to the same PyCapsuleInterface path that handles any third-party Arrow producer (Polars, ADBC, etc.).

This collapses the type detection in `GetArrowType()` from six `isinstance` checks down to three, covering the types that don't have `__arrow_c_stream__`:

* Scanner
* Dataset
* MessageReader

...followed by a single `hasattr(obj, "__arrow_c_stream__")` catch-all.

The PyCapsuleInterface path now has "tiered" pushdown:

- If `pyarrow.dataset` is available: import the stream as a RecordBatchReader, feed it through `Scanner.from_batches` for projection/filter pushdown, and export it back to a C stream.
- Otherwise: return the raw C stream directly. DuckDB applies projection and filters post-scan via `arrow_scan_dumb`.

For schema extraction we try `__arrow_c_schema__` first, then `schema._export_to_c`, and only fall back to consuming a stream when neither is available. This hopefully prevents single-use streams from being consumed during schema extraction.

~~Polars DataFrames with `__arrow_c_stream__` (v1.4+) now fall through to the unified path instead of going through `.to_arrow()`. We keep a fallback for Polars < 1.4.~~ Edit: this caused a big performance regression; Polars doesn't seem to do zero-copy conversion and will re-convert for every new scan, so I've reverted it for now.

I'll post some benchmarks tomorrow. First results look good.
2 parents b6efa5c + 5b6edb9 commit a4a4208
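
To make the unified entry point concrete: any Python object that exposes `__arrow_c_stream__` can now be scanned directly, with tiered pushdown. A minimal sketch (not part of this PR; `MyStreamSource` is a hypothetical producer, and the example assumes duckdb and pyarrow are installed):

```python
import duckdb
import pyarrow as pa


class MyStreamSource:
    """Hypothetical third-party producer: only the PyCapsule Interface, no pyarrow base class."""

    def __init__(self, table: pa.Table):
        self._table = table

    def __arrow_c_stream__(self, requested_schema=None):
        # Delegate to pyarrow's capsule export for this sketch.
        return self._table.__arrow_c_stream__(requested_schema)


source = MyStreamSource(pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]}))
# The replacement scan resolves `source` by name; with pyarrow.dataset installed,
# the filter below is pushed down into the scan (Tier A).
print(duckdb.sql("SELECT a FROM source WHERE a > 1").fetchall())  # [(2,), (3,)]
```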

8 files changed

Lines changed: 535 additions & 67 deletions

scripts/cache_data.json

Lines changed: 4 additions & 2 deletions
```diff
@@ -20,7 +20,8 @@
       "pyarrow.decimal32",
       "pyarrow.decimal64",
       "pyarrow.decimal128"
-    ]
+    ],
+    "required": false
   },
   "pyarrow.dataset": {
     "type": "module",
@@ -29,7 +30,8 @@
     "children": [
       "pyarrow.dataset.Scanner",
       "pyarrow.dataset.Dataset"
-    ]
+    ],
+    "required": false
   },
   "pyarrow.dataset.Scanner": {
     "type": "attribute",
```

src/duckdb_py/arrow/arrow_array_stream.cpp

Lines changed: 80 additions & 28 deletions
```diff
@@ -66,6 +66,45 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
 	py::handle arrow_obj_handle(factory->arrow_object);
 	auto arrow_object_type = DuckDBPyConnection::GetArrowType(arrow_obj_handle);
 
+	if (arrow_object_type == PyArrowObjectType::PyCapsuleInterface) {
+		py::object capsule_obj = arrow_obj_handle.attr("__arrow_c_stream__")();
+		auto capsule = py::reinterpret_borrow<py::capsule>(capsule_obj);
+		auto stream = capsule.get_pointer<struct ArrowArrayStream>();
+		if (!stream->release) {
+			throw InvalidInputException(
+			    "The __arrow_c_stream__() method returned a released stream. "
+			    "If this object is single-use, implement __arrow_c_schema__() or expose a .schema attribute "
+			    "with _export_to_c() so that DuckDB can extract the schema without consuming the stream.");
+		}
+
+		auto &import_cache_check = *DuckDBPyConnection::ImportCache();
+		if (import_cache_check.pyarrow.dataset()) {
+			// Tier A: full pushdown via pyarrow.dataset
+			// Import as RecordBatchReader, feed through Scanner.from_batches for projection/filter pushdown.
+			auto pyarrow_lib_module = py::module::import("pyarrow").attr("lib");
+			auto import_func = pyarrow_lib_module.attr("RecordBatchReader").attr("_import_from_c");
+			py::object reader = import_func(reinterpret_cast<uint64_t>(stream));
+			// _import_from_c takes ownership of the stream; null out to prevent capsule double-free
+			stream->release = nullptr;
+			auto &import_cache = *DuckDBPyConnection::ImportCache();
+			py::object arrow_batch_scanner = import_cache.pyarrow.dataset.Scanner().attr("from_batches");
+			py::handle reader_handle = reader;
+			auto scanner = ProduceScanner(arrow_batch_scanner, reader_handle, parameters, factory->client_properties);
+			auto record_batches = scanner.attr("to_reader")();
+			auto res = make_uniq<ArrowArrayStreamWrapper>();
+			auto export_to_c = record_batches.attr("_export_to_c");
+			export_to_c(reinterpret_cast<uint64_t>(&res->arrow_array_stream));
+			return res;
+		} else {
+			// Tier B: no pyarrow.dataset, return raw stream (no pushdown)
+			// DuckDB applies projection/filter post-scan via arrow_scan_dumb
+			auto res = make_uniq<ArrowArrayStreamWrapper>();
+			res->arrow_array_stream = *stream;
+			stream->release = nullptr;
+			return res;
+		}
+	}
+
 	if (arrow_object_type == PyArrowObjectType::PyCapsule) {
 		auto res = make_uniq<ArrowArrayStreamWrapper>();
 		auto capsule = py::reinterpret_borrow<py::capsule>(arrow_obj_handle);
@@ -78,21 +117,12 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
 		return res;
 	}
 
+	// Scanner and Dataset: require pyarrow.dataset for pushdown
+	VerifyArrowDatasetLoaded();
 	auto &import_cache = *DuckDBPyConnection::ImportCache();
 	py::object scanner;
 	py::object arrow_batch_scanner = import_cache.pyarrow.dataset.Scanner().attr("from_batches");
 	switch (arrow_object_type) {
-	case PyArrowObjectType::Table: {
-		auto arrow_dataset = import_cache.pyarrow.dataset().attr("dataset");
-		auto dataset = arrow_dataset(arrow_obj_handle);
-		py::object arrow_scanner = dataset.attr("__class__").attr("scanner");
-		scanner = ProduceScanner(arrow_scanner, dataset, parameters, factory->client_properties);
-		break;
-	}
-	case PyArrowObjectType::RecordBatchReader: {
-		scanner = ProduceScanner(arrow_batch_scanner, arrow_obj_handle, parameters, factory->client_properties);
-		break;
-	}
 	case PyArrowObjectType::Scanner: {
 		// If it's a scanner we have to turn it to a record batch reader, and then a scanner again since we can't stack
 		// scanners on arrow Otherwise pushed-down projections and filters will disappear like tears in the rain
@@ -119,37 +149,29 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
 }
 
 void PythonTableArrowArrayStreamFactory::GetSchemaInternal(py::handle arrow_obj_handle, ArrowSchemaWrapper &schema) {
+	// PyCapsule (from bare capsule Produce path)
 	if (py::isinstance<py::capsule>(arrow_obj_handle)) {
 		auto capsule = py::reinterpret_borrow<py::capsule>(arrow_obj_handle);
 		auto stream = capsule.get_pointer<struct ArrowArrayStream>();
 		if (!stream->release) {
 			throw InvalidInputException("This ArrowArrayStream has already been consumed and cannot be scanned again.");
 		}
-		stream->get_schema(stream, &schema.arrow_schema);
-		return;
-	}
-
-	auto table_class = py::module::import("pyarrow").attr("Table");
-	if (py::isinstance(arrow_obj_handle, table_class)) {
-		auto obj_schema = arrow_obj_handle.attr("schema");
-		auto export_to_c = obj_schema.attr("_export_to_c");
-		export_to_c(reinterpret_cast<uint64_t>(&schema.arrow_schema));
+		if (stream->get_schema(stream, &schema.arrow_schema)) {
+			throw InvalidInputException("Failed to get Arrow schema from stream: %s",
+			                            stream->get_last_error ? stream->get_last_error(stream) : "unknown error");
+		}
 		return;
 	}
 
+	// Scanner: use projected_schema; everything else (RecordBatchReader, Dataset): use .schema
 	VerifyArrowDatasetLoaded();
-
 	auto &import_cache = *DuckDBPyConnection::ImportCache();
-	auto scanner_class = import_cache.pyarrow.dataset.Scanner();
-
-	if (py::isinstance(arrow_obj_handle, scanner_class)) {
+	if (py::isinstance(arrow_obj_handle, import_cache.pyarrow.dataset.Scanner())) {
 		auto obj_schema = arrow_obj_handle.attr("projected_schema");
-		auto export_to_c = obj_schema.attr("_export_to_c");
-		export_to_c(reinterpret_cast<uint64_t>(&schema));
+		obj_schema.attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema.arrow_schema));
	} else {
 		auto obj_schema = arrow_obj_handle.attr("schema");
-		auto export_to_c = obj_schema.attr("_export_to_c");
-		export_to_c(reinterpret_cast<uint64_t>(&schema));
+		obj_schema.attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema.arrow_schema));
 	}
 }
 
@@ -158,6 +180,36 @@ void PythonTableArrowArrayStreamFactory::GetSchema(uintptr_t factory_ptr, ArrowS
 	auto factory = static_cast<PythonTableArrowArrayStreamFactory *>(reinterpret_cast<void *>(factory_ptr)); // NOLINT
 	D_ASSERT(factory->arrow_object);
 	py::handle arrow_obj_handle(factory->arrow_object);
+
+	auto type = DuckDBPyConnection::GetArrowType(arrow_obj_handle);
+	if (type == PyArrowObjectType::PyCapsuleInterface) {
+		// Get __arrow_c_schema__ if it exists
+		if (py::hasattr(arrow_obj_handle, "__arrow_c_schema__")) {
+			auto schema_capsule = arrow_obj_handle.attr("__arrow_c_schema__")();
+			auto capsule = py::reinterpret_borrow<py::capsule>(schema_capsule);
+			auto arrow_schema = capsule.get_pointer<struct ArrowSchema>();
+			schema.arrow_schema = *arrow_schema;
+			arrow_schema->release = nullptr; // take ownership
+			return;
+		}
+		// Otherwise try to use .schema with _export_to_c
+		if (py::hasattr(arrow_obj_handle, "schema")) {
+			auto obj_schema = arrow_obj_handle.attr("schema");
+			if (py::hasattr(obj_schema, "_export_to_c")) {
+				obj_schema.attr("_export_to_c")(reinterpret_cast<uint64_t>(&schema.arrow_schema));
+				return;
+			}
+		}
+		// Fallback: create a temporary stream just for the schema (consumes single-use streams!)
+		auto stream_capsule = arrow_obj_handle.attr("__arrow_c_stream__")();
+		auto capsule = py::reinterpret_borrow<py::capsule>(stream_capsule);
+		auto stream = capsule.get_pointer<struct ArrowArrayStream>();
+		if (stream->get_schema(stream, &schema.arrow_schema)) {
+			throw InvalidInputException("Failed to get Arrow schema from stream: %s",
+			                            stream->get_last_error ? stream->get_last_error(stream) : "unknown error");
+		}
+		return; // stream_capsule goes out of scope, stream released by capsule destructor
+	}
 	GetSchemaInternal(arrow_obj_handle, schema);
 }
 
```
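
In Python terms, the Tier A branch above performs roughly this pipeline (a sketch using public pyarrow APIs instead of the `RecordBatchReader._import_from_c` internals the C++ calls; `table` stands in for the imported stream):

```python
import pyarrow as pa
import pyarrow.dataset as ds

table = pa.table({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})
# Stand-in for the RecordBatchReader imported from __arrow_c_stream__.
reader = pa.RecordBatchReader.from_batches(table.schema, table.to_batches())

# Scanner.from_batches is what provides projection/filter pushdown over a one-shot stream.
scanner = ds.Scanner.from_batches(reader, columns=["a"], filter=ds.field("a") > 1)

# to_reader() turns the scan result back into a stream DuckDB can export via the C interface.
print(scanner.to_reader().read_all().column("a").to_pylist())  # [2, 3]
```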

src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp

Lines changed: 1 addition & 10 deletions
```diff
@@ -51,16 +51,7 @@ class Table : public py::object {
 
 } // namespace pyarrow
 
-enum class PyArrowObjectType {
-	Invalid,
-	Table,
-	RecordBatchReader,
-	Scanner,
-	Dataset,
-	PyCapsule,
-	PyCapsuleInterface,
-	MessageReader
-};
+enum class PyArrowObjectType { Invalid, Table, Scanner, Dataset, PyCapsule, PyCapsuleInterface, MessageReader };
 
 void TransformDuckToArrowChunk(ArrowSchema &arrow_schema, ArrowArray &data, py::list &batches);
 
```

src/duckdb_py/include/duckdb_python/import_cache/modules/pyarrow_module.hpp

Lines changed: 10 additions & 0 deletions
```diff
@@ -46,6 +46,11 @@ struct PyarrowDatasetCacheItem : public PythonImportCacheItem {
 
 	PythonImportCacheItem Scanner;
 	PythonImportCacheItem Dataset;
+
+protected:
+	bool IsRequired() const override final {
+		return false;
+	}
 };
 
 struct PyarrowCacheItem : public PythonImportCacheItem {
@@ -80,6 +85,11 @@ struct PyarrowCacheItem : public PythonImportCacheItem {
 	PythonImportCacheItem decimal32;
 	PythonImportCacheItem decimal64;
 	PythonImportCacheItem decimal128;
+
+protected:
+	bool IsRequired() const override final {
+		return false;
+	}
 };
 
 } // namespace duckdb
```

src/duckdb_py/pyconnection.cpp

Lines changed: 5 additions & 15 deletions
```diff
@@ -2384,26 +2384,16 @@ PyArrowObjectType DuckDBPyConnection::GetArrowType(const py::handle &obj) {
 
 	if (ModuleIsLoaded<PyarrowCacheItem>()) {
 		auto &import_cache = *DuckDBPyConnection::ImportCache();
-		// First Verify Lib Types
-		auto table_class = import_cache.pyarrow.Table();
-		auto record_batch_reader_class = import_cache.pyarrow.RecordBatchReader();
-		auto message_reader_class = import_cache.pyarrow.ipc.MessageReader();
-		if (py::isinstance(obj, table_class)) {
-			return PyArrowObjectType::Table;
-		} else if (py::isinstance(obj, record_batch_reader_class)) {
-			return PyArrowObjectType::RecordBatchReader;
-		} else if (py::isinstance(obj, message_reader_class)) {
+		// MessageReader requires nanoarrow, separate scan function
+		if (py::isinstance(obj, import_cache.pyarrow.ipc.MessageReader())) {
 			return PyArrowObjectType::MessageReader;
 		}
 
 		if (ModuleIsLoaded<PyarrowDatasetCacheItem>()) {
-			// Then Verify dataset types
-			auto dataset_class = import_cache.pyarrow.dataset.Dataset();
-			auto scanner_class = import_cache.pyarrow.dataset.Scanner();
-
-			if (py::isinstance(obj, scanner_class)) {
+			// Scanner/Dataset don't have __arrow_c_stream__, need dedicated handling
+			if (py::isinstance(obj, import_cache.pyarrow.dataset.Scanner())) {
 				return PyArrowObjectType::Scanner;
-			} else if (py::isinstance(obj, dataset_class)) {
+			} else if (py::isinstance(obj, import_cache.pyarrow.dataset.Dataset())) {
 				return PyArrowObjectType::Dataset;
 			}
 		}
```
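
For readers who don't speak pybind11, the resulting dispatch order is roughly the following (a runnable Python sketch, not the real implementation; the bare-PyCapsule branch is elided):

```python
import pyarrow as pa
import pyarrow.dataset as ds


def get_arrow_type(obj) -> str:
    """Sketch of the simplified GetArrowType() dispatch."""
    if isinstance(obj, pa.ipc.MessageReader):  # needs its own scan path
        return "MessageReader"
    if isinstance(obj, ds.Scanner):  # no __arrow_c_stream__
        return "Scanner"
    if isinstance(obj, ds.Dataset):  # no __arrow_c_stream__
        return "Dataset"
    if hasattr(obj, "__arrow_c_stream__"):  # Table, RecordBatchReader, Polars, ADBC, ...
        return "PyCapsuleInterface"
    return "Invalid"


print(get_arrow_type(pa.table({"a": [1]})))  # PyCapsuleInterface
print(get_arrow_type(pa.table({"a": [1]}).to_reader()))  # PyCapsuleInterface
```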

src/duckdb_py/python_replacement_scan.cpp

Lines changed: 15 additions & 9 deletions
```diff
@@ -51,11 +51,6 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR
 		auto dependency_item = PythonDependencyItem::Create(stream_messages);
 		external_dependency->AddDependency("replacement_cache", std::move(dependency_item));
 	} else {
-		if (type == PyArrowObjectType::PyCapsuleInterface) {
-			entry = entry.attr("__arrow_c_stream__")();
-			type = PyArrowObjectType::PyCapsule;
-		}
-
 		auto stream_factory = make_uniq<PythonTableArrowArrayStreamFactory>(entry.ptr(), client_properties);
 		auto stream_factory_produce = PythonTableArrowArrayStreamFactory::Produce;
 		auto stream_factory_get_schema = PythonTableArrowArrayStreamFactory::GetSchema;
@@ -66,8 +61,17 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR
 		    make_uniq<ConstantExpression>(Value::POINTER(CastPointerToValue(stream_factory_get_schema))));
 
 		if (type == PyArrowObjectType::PyCapsule) {
-			// Disable projection+filter pushdown
+			// Disable projection+filter pushdown for bare capsules (single-use, no PyArrow wrapper)
 			table_function.function = make_uniq<FunctionExpression>("arrow_scan_dumb", std::move(children));
+		} else if (type == PyArrowObjectType::PyCapsuleInterface) {
+			// Try to load pyarrow.dataset for pushdown support
+			auto &cache = *DuckDBPyConnection::ImportCache();
+			if (!cache.pyarrow.dataset()) {
+				// No pyarrow.dataset: scan without pushdown, DuckDB handles projection/filter post-scan
+				table_function.function = make_uniq<FunctionExpression>("arrow_scan_dumb", std::move(children));
+			} else {
+				table_function.function = make_uniq<FunctionExpression>("arrow_scan", std::move(children));
+			}
 		} else {
 			table_function.function = make_uniq<FunctionExpression>("arrow_scan", std::move(children));
 		}
@@ -141,6 +145,9 @@ unique_ptr<TableRef> PythonReplacementScan::TryReplacementObject(const py::objec
 		subquery->external_dependency = std::move(dependency);
 		return std::move(subquery);
 	} else if (PolarsDataFrame::IsDataFrame(entry)) {
+		// Polars DataFrames always go through one-time .to_arrow() materialization.
+		// Polars's __arrow_c_stream__() serializes from its internal layout on every call,
+		// which is expensive for repeated scans. The .to_arrow() path converts once.
 		auto arrow_dataset = entry.attr("to_arrow")();
 		CreateArrowScan(name, arrow_dataset, *table_function, children, client_properties, PyArrowObjectType::Table,
 		                *context.db);
@@ -149,9 +156,8 @@ unique_ptr<TableRef> PythonReplacementScan::TryReplacementObject(const py::objec
 		auto arrow_dataset = materialized.attr("to_arrow")();
 		CreateArrowScan(name, arrow_dataset, *table_function, children, client_properties, PyArrowObjectType::Table,
 		                *context.db);
-	} else if (DuckDBPyConnection::GetArrowType(entry) != PyArrowObjectType::Invalid &&
-	           !(DuckDBPyConnection::GetArrowType(entry) == PyArrowObjectType::MessageReader && !relation)) {
-		arrow_type = DuckDBPyConnection::GetArrowType(entry);
+	} else if ((arrow_type = DuckDBPyConnection::GetArrowType(entry)) != PyArrowObjectType::Invalid &&
+	           !(arrow_type == PyArrowObjectType::MessageReader && !relation)) {
 		CreateArrowScan(name, entry, *table_function, children, client_properties, arrow_type, *context.db);
 	} else if (DuckDBPyConnection::IsAcceptedNumpyObject(entry) != NumpyObjectType::INVALID) {
 		numpytype = DuckDBPyConnection::IsAcceptedNumpyObject(entry);
```
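
A usage sketch of the Polars behavior the comment above pins down (assumes polars and duckdb are installed; the `.to_arrow()` call happens inside DuckDB's replacement scan, not in user code):

```python
import duckdb
import polars as pl

df = pl.DataFrame({"a": [1, 2, 3]})
# DuckDB resolves `df` by name and materializes it once via df.to_arrow(),
# deliberately bypassing df.__arrow_c_stream__(), which would re-serialize per scan.
print(duckdb.sql("SELECT sum(a) FROM df").fetchall())  # [(6,)]
```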

tests/fast/arrow/test_arrow_pycapsule.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -29,21 +29,24 @@ def __arrow_c_stream__(self, requested_schema=None) -> object:
         obj = MyObject(df)
 
         # Call the __arrow_c_stream__ from within DuckDB
+        # MyObject has no __arrow_c_schema__, so GetSchema() falls back to __arrow_c_stream__ (1 call),
+        # then Produce() calls __arrow_c_stream__ again (1 call) = 2 calls minimum per scan.
         res = duckdb_cursor.sql("select * from obj")
         assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)]
-        assert obj.count == 1
+        count_after_first = obj.count
+        assert count_after_first >= 2
 
         # Call the __arrow_c_stream__ method and pass in the capsule instead
         capsule = obj.__arrow_c_stream__()
         res = duckdb_cursor.sql("select * from capsule")
         assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)]
-        assert obj.count == 2
+        assert obj.count == count_after_first + 1
 
         # Ensure __arrow_c_stream__ accepts a requested_schema argument as noop
         capsule = obj.__arrow_c_stream__(requested_schema="foo")  # noqa: F841
         res = duckdb_cursor.sql("select * from capsule")
         assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)]
-        assert obj.count == 3
+        assert obj.count == count_after_first + 2
 
     def test_capsule_roundtrip(self, duckdb_cursor):
         def create_capsule():
```
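
The `>= 2` above is the cost of schema extraction falling back to the stream. Per the fallback order in `GetSchema()`, a producer can avoid the extra call by also implementing `__arrow_c_schema__`; a sketch (`SchemaAwareObject` is hypothetical, and the expected count of 1 is an assumption based on this PR's description):

```python
import duckdb
import pyarrow as pa


class SchemaAwareObject:
    def __init__(self, table: pa.Table):
        self.table = table
        self.count = 0

    def __arrow_c_stream__(self, requested_schema=None):
        self.count += 1
        return self.table.__arrow_c_stream__(requested_schema)

    def __arrow_c_schema__(self):
        # GetSchema() consumes this capsule instead of opening the stream.
        return self.table.schema.__arrow_c_schema__()


obj = SchemaAwareObject(pa.table({"x": [1, 2]}))
duckdb.sql("select * from obj").fetchall()
print(obj.count)  # expected: 1 (schema via __arrow_c_schema__, stream opened once for data)
```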
