Skip to content

Commit de27c05

Browse files
committed
Use PhysicalArrowCollector for __arrow_c_stream__ on relations
1 parent f65eb74 commit de27c05

3 files changed

Lines changed: 179 additions & 1 deletion

File tree

src/duckdb_py/pyrelation.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,19 @@ py::object DuckDBPyRelation::ToArrowCapsule(const py::object &requested_schema)
992992
if (!rel) {
993993
return py::none();
994994
}
995+
// The PyCapsule protocol doesn't allow custom parameters, so we use the same
996+
// default batch size as fetch_arrow_table / fetch_record_batch.
997+
idx_t batch_size = 1000000;
998+
auto &config = ClientConfig::GetConfig(*rel->context->GetContext());
999+
ScopedConfigSetting scoped_setting(
1000+
config,
1001+
[&batch_size](ClientConfig &config) {
1002+
config.get_result_collector = [&batch_size](ClientContext &context,
1003+
PreparedStatementData &data) -> PhysicalOperator & {
1004+
return PhysicalArrowCollector::Create(context, data, batch_size);
1005+
};
1006+
},
1007+
[](ClientConfig &config) { config.get_result_collector = nullptr; });
9951008
ExecuteOrThrow();
9961009
}
9971010
AssertResultOpen();
@@ -1003,7 +1016,8 @@ py::object DuckDBPyRelation::ToArrowCapsule(const py::object &requested_schema)
10031016
PolarsDataFrame DuckDBPyRelation::ToPolars(idx_t batch_size, bool lazy) {
10041017
if (!lazy) {
10051018
auto arrow = ToArrowTableInternal(batch_size, true);
1006-
return py::cast<PolarsDataFrame>(pybind11::module_::import("polars").attr("DataFrame")(arrow));
1019+
return py::cast<PolarsDataFrame>(
1020+
pybind11::module_::import("polars").attr("from_arrow")(arrow, py::arg("rechunk") = false));
10071021
}
10081022
auto &import_cache = *DuckDBPyConnection::ImportCache();
10091023
auto lazy_frame_produce = import_cache.duckdb.polars_io.duckdb_source();

src/duckdb_py/pyresult.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,81 @@ duckdb::pyarrow::RecordBatchReader DuckDBPyResult::FetchRecordBatchReader(idx_t
496496
return py::cast<duckdb::pyarrow::RecordBatchReader>(record_batch_reader);
497497
}
498498

499+
// Wraps pre-built Arrow arrays from an ArrowQueryResult into an ArrowArrayStream.
500+
// This avoids the double-materialization that happens when using ResultArrowArrayStreamWrapper
501+
// with an ArrowQueryResult (which throws NotImplementedException from FetchInternal).
502+
struct ArrowQueryResultStreamWrapper {
503+
ArrowQueryResultStreamWrapper(unique_ptr<QueryResult> result_p) : result(std::move(result_p)), index(0) {
504+
auto &arrow_result = result->Cast<ArrowQueryResult>();
505+
arrays = arrow_result.ConsumeArrays();
506+
types = result->types;
507+
names = result->names;
508+
client_properties = result->client_properties;
509+
510+
stream.private_data = this;
511+
stream.get_schema = GetSchema;
512+
stream.get_next = GetNext;
513+
stream.release = Release;
514+
stream.get_last_error = GetLastError;
515+
}
516+
517+
static int GetSchema(ArrowArrayStream *stream, ArrowSchema *out) {
518+
if (!stream->release) {
519+
return -1;
520+
}
521+
auto self = reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
522+
out->release = nullptr;
523+
try {
524+
ArrowConverter::ToArrowSchema(out, self->types, self->names, self->client_properties);
525+
} catch (std::runtime_error &e) {
526+
self->last_error = e.what();
527+
return -1;
528+
}
529+
return 0;
530+
}
531+
532+
static int GetNext(ArrowArrayStream *stream, ArrowArray *out) {
533+
if (!stream->release) {
534+
return -1;
535+
}
536+
auto self = reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
537+
if (self->index >= self->arrays.size()) {
538+
out->release = nullptr;
539+
return 0;
540+
}
541+
*out = self->arrays[self->index]->arrow_array;
542+
self->arrays[self->index]->arrow_array.release = nullptr;
543+
self->index++;
544+
return 0;
545+
}
546+
547+
static void Release(ArrowArrayStream *stream) {
548+
if (!stream || !stream->release) {
549+
return;
550+
}
551+
stream->release = nullptr;
552+
delete reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
553+
}
554+
555+
static const char *GetLastError(ArrowArrayStream *stream) {
556+
if (!stream->release) {
557+
return "stream was released";
558+
}
559+
auto self = reinterpret_cast<ArrowQueryResultStreamWrapper *>(stream->private_data);
560+
return self->last_error.c_str();
561+
}
562+
563+
ArrowArrayStream stream;
564+
unique_ptr<QueryResult> result;
565+
vector<unique_ptr<ArrowArrayWrapper>> arrays;
566+
vector<LogicalType> types;
567+
vector<string> names;
568+
ClientProperties client_properties;
569+
idx_t index;
570+
string last_error;
571+
};
572+
573+
// Destructor for capsules that own a heap-allocated ArrowArrayStream (slow path).
499574
static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) {
500575
auto data = PyCapsule_GetPointer(object, "arrow_array_stream");
501576
if (!data) {
@@ -508,7 +583,28 @@ static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) {
508583
delete stream;
509584
}
510585

586+
// Destructor for capsules pointing at an embedded ArrowArrayStream (fast path).
587+
// The stream is owned by an ArrowQueryResultStreamWrapper; Release() frees both.
588+
static void ArrowArrayStreamEmbeddedPyCapsuleDestructor(PyObject *object) {
589+
auto data = PyCapsule_GetPointer(object, "arrow_array_stream");
590+
if (!data) {
591+
return;
592+
}
593+
auto stream = reinterpret_cast<ArrowArrayStream *>(data);
594+
if (stream->release) {
595+
stream->release(stream);
596+
}
597+
}
598+
511599
py::object DuckDBPyResult::FetchArrowCapsule(idx_t rows_per_batch) {
600+
if (result && result->type == QueryResultType::ARROW_RESULT) {
601+
// Fast path: yield pre-built Arrow arrays directly.
602+
// The wrapper is heap-allocated; Release() deletes it via private_data.
603+
// The capsule points at the embedded stream field — no separate heap allocation needed.
604+
auto wrapper = new ArrowQueryResultStreamWrapper(std::move(result));
605+
return py::capsule(&wrapper->stream, "arrow_array_stream", ArrowArrayStreamEmbeddedPyCapsuleDestructor);
606+
}
607+
// Existing slow path for MaterializedQueryResult / StreamQueryResult
512608
auto stream_p = FetchArrowArrayStream(rows_per_batch);
513609
auto stream = new ArrowArrayStream();
514610
*stream = stream_p;

tests/fast/arrow/test_arrow_pycapsule.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import duckdb
44

5+
pa = pytest.importorskip("pyarrow")
56
pl = pytest.importorskip("polars")
67

78

@@ -11,6 +12,73 @@ def polars_supports_capsule():
1112
return Version(pl.__version__) >= Version("1.4.1")
1213

1314

15+
class TestArrowPyCapsuleExport:
    """Exercises the PyCapsule export path (rel.__arrow_c_stream__).

    Each test builds a reference table with to_arrow_table() and compares it with the table
    pyarrow builds from the relation itself, covering the fast path
    (PhysicalArrowCollector + ArrowQueryResultStreamWrapper) and the slow-path fallback.
    """

    def test_capsule_matches_to_arrow_table(self):
        """Scalar columns, including one with NULLs, match to_arrow_table()."""
        con = duckdb.connect()
        query = """
            SELECT
                i AS int_col,
                i::DOUBLE AS double_col,
                'row_' || i::VARCHAR AS str_col,
                i % 2 = 0 AS bool_col,
                CASE WHEN i % 3 = 0 THEN NULL ELSE i END AS nullable_col
            FROM range(1000) t(i)
        """
        reference = con.sql(query).to_arrow_table()
        exported = pa.table(con.sql(query))
        assert exported.equals(reference)

    def test_capsule_matches_to_arrow_table_nested_types(self):
        """Struct, list and map columns match to_arrow_table()."""
        con = duckdb.connect()
        query = """
            SELECT
                {'x': i, 'y': i::VARCHAR} AS struct_col,
                [i, i+1, i+2] AS list_col,
                MAP {i::VARCHAR: i*10} AS map_col,
            FROM range(100) t(i)
        """
        reference = con.sql(query).to_arrow_table()
        exported = pa.table(con.sql(query))
        assert exported.equals(reference)

    def test_capsule_multi_batch(self):
        """A result larger than the 1M batch size is exported completely and correctly."""
        con = duckdb.connect()
        query = "SELECT i, i::DOUBLE AS d FROM range(1500000) t(i)"
        reference = con.sql(query).to_arrow_table()
        exported = pa.table(con.sql(query))
        assert exported.num_rows == 1500000
        assert exported.equals(reference)

    def test_capsule_empty_result(self):
        """A query returning no rows yields an empty table that keeps the expected schema."""
        con = duckdb.connect()
        query = "SELECT i AS a, i::VARCHAR AS b FROM range(10) t(i) WHERE i < 0"
        reference = con.sql(query).to_arrow_table()
        exported = pa.table(con.sql(query))
        assert exported.num_rows == 0
        assert exported.schema.equals(reference.schema)

    def test_capsule_slow_path_after_execute(self):
        """A relation executed beforehand (MaterializedQueryResult, slow path) still exports correctly."""
        con = duckdb.connect()
        query = "SELECT i, i::DOUBLE AS d FROM range(500) t(i)"
        reference = con.sql(query).to_arrow_table()

        relation = con.sql(query)
        # Executing up front forces MaterializedCollector rather than PhysicalArrowCollector.
        relation.execute()
        exported = pa.table(relation)
        assert exported.equals(reference)
80+
81+
1482
@pytest.mark.skipif(
1583
not polars_supports_capsule(), reason="Polars version does not support the Arrow PyCapsule interface"
1684
)

0 commit comments

Comments
 (0)