diff --git a/src/duckdb_py/pyresult.cpp b/src/duckdb_py/pyresult.cpp index 34b5f6ff..91118add 100644 --- a/src/duckdb_py/pyresult.cpp +++ b/src/duckdb_py/pyresult.cpp @@ -496,16 +496,134 @@ duckdb::pyarrow::RecordBatchReader DuckDBPyResult::FetchRecordBatchReader(idx_t return py::cast(record_batch_reader); } +// Holds owned copies of the string data for a deep-copied ArrowSchema node. +struct ArrowSchemaCopyData { + string format; + string name; + string metadata; +}; + +static void ReleaseCopiedArrowSchema(ArrowSchema *schema) { + if (!schema || !schema->release) { + return; + } + for (int64_t i = 0; i < schema->n_children; i++) { + if (schema->children[i]->release) { + schema->children[i]->release(schema->children[i]); + } + delete schema->children[i]; + } + delete[] schema->children; + if (schema->dictionary) { + if (schema->dictionary->release) { + schema->dictionary->release(schema->dictionary); + } + delete schema->dictionary; + } + delete reinterpret_cast(schema->private_data); + schema->release = nullptr; +} + +static idx_t ArrowMetadataSize(const char *metadata) { + if (!metadata) { + return 0; + } + // Arrow metadata format: int32 num_entries, then for each entry: + // int32 key_len, key_bytes, int32 value_len, value_bytes + auto ptr = metadata; + int32_t num_entries; + memcpy(&num_entries, ptr, sizeof(int32_t)); + ptr += sizeof(int32_t); + for (int32_t i = 0; i < num_entries; i++) { + int32_t len; + memcpy(&len, ptr, sizeof(int32_t)); + ptr += sizeof(int32_t) + len; + memcpy(&len, ptr, sizeof(int32_t)); + ptr += sizeof(int32_t) + len; + } + return ptr - metadata; +} + +// Deep-copy an ArrowSchema. The Arrow C Data Interface specifies that get_schema +// transfers ownership to the caller, so each call must produce an independent copy. +// Each node owns its string data via an ArrowSchemaCopyData in private_data. 
+static int ArrowSchemaDeepCopy(const ArrowSchema &source, ArrowSchema *out, string &error) { + out->release = nullptr; + try { + auto data = new ArrowSchemaCopyData(); + data->format = source.format ? source.format : ""; + data->name = source.name ? source.name : ""; + if (source.metadata) { + auto metadata_size = ArrowMetadataSize(source.metadata); + data->metadata.assign(source.metadata, metadata_size); + } + + out->format = data->format.c_str(); + out->name = data->name.c_str(); + out->metadata = source.metadata ? data->metadata.data() : nullptr; + out->flags = source.flags; + out->n_children = source.n_children; + out->dictionary = nullptr; + out->private_data = data; + out->release = ReleaseCopiedArrowSchema; + + if (source.n_children > 0) { + out->children = new ArrowSchema *[source.n_children]; + for (int64_t i = 0; i < source.n_children; i++) { + out->children[i] = new ArrowSchema(); + auto rc = ArrowSchemaDeepCopy(*source.children[i], out->children[i], error); + if (rc != 0) { + for (int64_t j = 0; j <= i; j++) { + if (out->children[j]->release) { + out->children[j]->release(out->children[j]); + } + delete out->children[j]; + } + delete[] out->children; + out->children = nullptr; + out->n_children = 0; + // Release the partially constructed node + delete data; + out->private_data = nullptr; + out->release = nullptr; + return rc; + } + } + } else { + out->children = nullptr; + } + + if (source.dictionary) { + out->dictionary = new ArrowSchema(); + auto rc = ArrowSchemaDeepCopy(*source.dictionary, out->dictionary, error); + if (rc != 0) { + delete out->dictionary; + out->dictionary = nullptr; + return rc; + } + } + } catch (std::exception &e) { + error = e.what(); + return -1; + } + return 0; +} + // Wraps pre-built Arrow arrays from an ArrowQueryResult into an ArrowArrayStream. 
// This avoids the double-materialization that happens when using ResultArrowArrayStreamWrapper // with an ArrowQueryResult (which throws NotImplementedException from FetchInternal). +// +// The schema is cached eagerly in the constructor (while the ClientContext is still alive) +// so that get_schema can be called after the originating connection has been destroyed. +// ToArrowSchema needs a live ClientContext for transaction access and catalog lookups +// (e.g. CRS conversion for GEOMETRY types). struct ArrowQueryResultStreamWrapper { ArrowQueryResultStreamWrapper(unique_ptr result_p) : result(std::move(result_p)), index(0) { auto &arrow_result = result->Cast(); arrays = arrow_result.ConsumeArrays(); - types = result->types; - names = result->names; - client_properties = result->client_properties; + + cached_schema.release = nullptr; + ArrowConverter::ToArrowSchema(&cached_schema, result->types, result->names, result->client_properties); stream.private_data = this; stream.get_schema = GetSchema; @@ -514,19 +632,18 @@ struct ArrowQueryResultStreamWrapper { stream.get_last_error = GetLastError; } + ~ArrowQueryResultStreamWrapper() { + if (cached_schema.release) { + cached_schema.release(&cached_schema); + } + } + static int GetSchema(ArrowArrayStream *stream, ArrowSchema *out) { if (!stream->release) { return -1; } auto self = reinterpret_cast(stream->private_data); - out->release = nullptr; - try { - ArrowConverter::ToArrowSchema(out, self->types, self->names, self->client_properties); - } catch (std::runtime_error &e) { - self->last_error = e.what(); - return -1; - } - return 0; + return ArrowSchemaDeepCopy(self->cached_schema, out, self->last_error); } static int GetNext(ArrowArrayStream *stream, ArrowArray *out) { @@ -563,14 +680,89 @@ struct ArrowQueryResultStreamWrapper { ArrowArrayStream stream; unique_ptr result; vector> arrays; - vector types; - vector names; - ClientProperties client_properties; + ArrowSchema cached_schema; idx_t index; string last_error; 
}; -// Destructor for capsules that own a heap-allocated ArrowArrayStream (slow path). +// Wraps an ArrowArrayStream and caches its schema eagerly. +// Used for the slow path (MaterializedQueryResult / StreamQueryResult) where the +// inner stream is a ResultArrowArrayStreamWrapper from DuckDB core. That wrapper's +// get_schema calls ToArrowSchema which needs a live ClientContext, so we fetch it +// once at construction time and return copies from cache afterwards. +struct SchemaCachingStreamWrapper { + SchemaCachingStreamWrapper(ArrowArrayStream inner_p) : inner(inner_p) { + inner_p.release = nullptr; + + cached_schema.release = nullptr; + if (inner.get_schema(&inner, &cached_schema)) { + schema_error = inner.get_last_error(&inner); + schema_ok = false; + } else { + schema_ok = true; + } + + stream.private_data = this; + stream.get_schema = GetSchema; + stream.get_next = GetNext; + stream.release = Release; + stream.get_last_error = GetLastError; + } + + ~SchemaCachingStreamWrapper() { + if (cached_schema.release) { + cached_schema.release(&cached_schema); + } + if (inner.release) { + inner.release(&inner); + } + } + + static int GetSchema(ArrowArrayStream *stream, ArrowSchema *out) { + if (!stream->release) { + return -1; + } + auto self = reinterpret_cast(stream->private_data); + if (!self->schema_ok) { + return -1; + } + return ArrowSchemaDeepCopy(self->cached_schema, out, self->schema_error); + } + + static int GetNext(ArrowArrayStream *stream, ArrowArray *out) { + if (!stream->release) { + return -1; + } + auto self = reinterpret_cast(stream->private_data); + return self->inner.get_next(&self->inner, out); + } + + static void Release(ArrowArrayStream *stream) { + if (!stream || !stream->release) { + return; + } + stream->release = nullptr; + delete reinterpret_cast(stream->private_data); + } + + static const char *GetLastError(ArrowArrayStream *stream) { + if (!stream->release) { + return "stream was released"; + } + auto self = 
reinterpret_cast(stream->private_data); + if (!self->schema_error.empty()) { + return self->schema_error.c_str(); + } + return self->inner.get_last_error(&self->inner); + } + + ArrowArrayStream stream; + ArrowArrayStream inner; + ArrowSchema cached_schema; + bool schema_ok; + string schema_error; +}; + static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) { auto data = PyCapsule_GetPointer(object, "arrow_array_stream"); if (!data) { @@ -586,19 +778,19 @@ static void ArrowArrayStreamPyCapsuleDestructor(PyObject *object) { py::object DuckDBPyResult::FetchArrowCapsule(idx_t rows_per_batch) { if (result && result->type == QueryResultType::ARROW_RESULT) { // Fast path: yield pre-built Arrow arrays directly. - // The wrapper is heap-allocated; Release() deletes it via private_data. - // We heap-allocate a separate ArrowArrayStream for the capsule so that the capsule - // holds a stable pointer even after the wrapper is consumed and deleted by a scan. auto wrapper = new ArrowQueryResultStreamWrapper(std::move(result)); auto stream = new ArrowArrayStream(); *stream = wrapper->stream; wrapper->stream.release = nullptr; return py::capsule(stream, "arrow_array_stream", ArrowArrayStreamPyCapsuleDestructor); } - // Existing slow path for MaterializedQueryResult / StreamQueryResult - auto stream_p = FetchArrowArrayStream(rows_per_batch); + // Slow path: wrap in SchemaCachingStreamWrapper so the schema is fetched + // eagerly while the ClientContext is still alive. 
+ auto inner_stream = FetchArrowArrayStream(rows_per_batch); + auto wrapper = new SchemaCachingStreamWrapper(inner_stream); auto stream = new ArrowArrayStream(); - *stream = stream_p; + *stream = wrapper->stream; + wrapper->stream.release = nullptr; return py::capsule(stream, "arrow_array_stream", ArrowArrayStreamPyCapsuleDestructor); } diff --git a/tests/fast/arrow/test_arrow_connection_lifetime.py b/tests/fast/arrow/test_arrow_connection_lifetime.py new file mode 100644 index 00000000..68a43c4e --- /dev/null +++ b/tests/fast/arrow/test_arrow_connection_lifetime.py @@ -0,0 +1,45 @@ +"""Tests that Arrow streams remain valid after their originating connection is destroyed. + +The Arrow PyCapsule paths produce lazy streams — schema and data are consumed +later. If the stream wrapper holds only a non-owning pointer to the +ClientContext and the connection is GC'd in between, the pointer dangles and we +crash (mutex-lock-on-destroyed-object). + +Each test creates a capsule from a short-lived connection, destroys that +connection, then consumes the capsule from a *different* connection. 
"""

import gc

import pytest

import duckdb

# Skip the whole module when pyarrow is unavailable; the capsules are consumed
# through DuckDB's Arrow integration.
pa = pytest.importorskip("pyarrow")

# Expected rows for SQL below: one tuple (i, i + 1, -i) per i in 0..99.
EXPECTED = [(i, i + 1, -i) for i in range(100)]
SQL = "SELECT i, i + 1 AS j, -i AS k FROM range(100) t(i)"


class TestArrowConnectionLifetime:
    """Capsules must stay valid after the originating connection is destroyed."""

    def test_capsule_fast_path_survives_connection_gc(self):
        """__arrow_c_stream__ fast path (ArrowQueryResult): connection destroyed before capsule is consumed."""
        conn = duckdb.connect()
        capsule = conn.sql(SQL).__arrow_c_stream__()  # noqa: F841
        del conn
        gc.collect()
        # NOTE(review): `FROM capsule` presumably resolves the local variable via
        # DuckDB's Python replacement scan — confirm against the client docs.
        result = duckdb.connect().sql("SELECT * FROM capsule").fetchall()
        assert result == EXPECTED

    def test_capsule_slow_path_survives_connection_gc(self):
        """__arrow_c_stream__ slow path (MaterializedQueryResult): connection destroyed before capsule is consumed."""
        conn = duckdb.connect()
        rel = conn.sql(SQL)
        rel.execute()  # forces MaterializedQueryResult, not ArrowQueryResult
        capsule = rel.__arrow_c_stream__()  # noqa: F841
        del rel, conn
        gc.collect()
        result = duckdb.connect().sql("SELECT * FROM capsule").fetchall()
        assert result == EXPECTED