Skip to content

Commit 7094be5

Browse files
Cache arrow object type to avoid repeated detection (#321)
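This caches the Arrow object type on the stream factory when it is constructed, so the type no longer has to be re-detected every time a stream is produced or a schema is requested.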
2 parents 02c9283 + 05a1ec6 commit 7094be5

4 files changed

Lines changed: 11 additions & 8 deletions

File tree

src/duckdb_py/arrow/arrow_array_stream.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ unique_ptr<ArrowArrayStreamWrapper> PythonTableArrowArrayStreamFactory::Produce(
6464
auto factory = static_cast<PythonTableArrowArrayStreamFactory *>(reinterpret_cast<void *>(factory_ptr)); // NOLINT
6565
D_ASSERT(factory->arrow_object);
6666
py::handle arrow_obj_handle(factory->arrow_object);
67-
auto arrow_object_type = DuckDBPyConnection::GetArrowType(arrow_obj_handle);
67+
auto arrow_object_type = factory->cached_arrow_type;
6868

69-
if (arrow_object_type == PyArrowObjectType::PyCapsuleInterface) {
69+
if (arrow_object_type == PyArrowObjectType::PyCapsuleInterface || arrow_object_type == PyArrowObjectType::Table) {
7070
py::object capsule_obj = arrow_obj_handle.attr("__arrow_c_stream__")();
7171
auto capsule = py::reinterpret_borrow<py::capsule>(capsule_obj);
7272
auto stream = capsule.get_pointer<struct ArrowArrayStream>();
@@ -181,8 +181,8 @@ void PythonTableArrowArrayStreamFactory::GetSchema(uintptr_t factory_ptr, ArrowS
181181
D_ASSERT(factory->arrow_object);
182182
py::handle arrow_obj_handle(factory->arrow_object);
183183

184-
auto type = DuckDBPyConnection::GetArrowType(arrow_obj_handle);
185-
if (type == PyArrowObjectType::PyCapsuleInterface) {
184+
auto type = factory->cached_arrow_type;
185+
if (type == PyArrowObjectType::PyCapsuleInterface || type == PyArrowObjectType::Table) {
186186
// Get __arrow_c_schema__ if it exists
187187
if (py::hasattr(arrow_obj_handle, "__arrow_c_schema__")) {
188188
auto schema_capsule = arrow_obj_handle.attr("__arrow_c_schema__")();

src/duckdb_py/include/duckdb_python/arrow/arrow_array_stream.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,9 @@ PyArrowObjectType GetArrowType(const py::handle &obj);
5959

6060
class PythonTableArrowArrayStreamFactory {
6161
public:
62-
explicit PythonTableArrowArrayStreamFactory(PyObject *arrow_table, const ClientProperties &client_properties_p)
63-
: arrow_object(arrow_table), client_properties(client_properties_p) {};
62+
explicit PythonTableArrowArrayStreamFactory(PyObject *arrow_table, const ClientProperties &client_properties_p,
63+
PyArrowObjectType arrow_type_p)
64+
: arrow_object(arrow_table), client_properties(client_properties_p), cached_arrow_type(arrow_type_p) {};
6465

6566
//! Produces an Arrow Scanner, should be only called once when initializing Scan States
6667
static unique_ptr<ArrowArrayStreamWrapper> Produce(uintptr_t factory, ArrowStreamParameters &parameters);
@@ -73,6 +74,7 @@ class PythonTableArrowArrayStreamFactory {
7374
PyObject *arrow_object;
7475

7576
const ClientProperties client_properties;
77+
const PyArrowObjectType cached_arrow_type;
7678

7779
private:
7880
static py::object ProduceScanner(py::object &arrow_scanner, py::handle &arrow_obj_handle,

src/duckdb_py/python_replacement_scan.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR
5151
auto dependency_item = PythonDependencyItem::Create(stream_messages);
5252
external_dependency->AddDependency("replacement_cache", std::move(dependency_item));
5353
} else {
54-
auto stream_factory = make_uniq<PythonTableArrowArrayStreamFactory>(entry.ptr(), client_properties);
54+
auto stream_factory = make_uniq<PythonTableArrowArrayStreamFactory>(entry.ptr(), client_properties, type);
5555
auto stream_factory_produce = PythonTableArrowArrayStreamFactory::Produce;
5656
auto stream_factory_get_schema = PythonTableArrowArrayStreamFactory::GetSchema;
5757

src/duckdb_py/python_udf.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ static void ConvertArrowTableToVector(const py::object &table, Vector &out, Clie
7474
D_ASSERT(py::gil_check());
7575
py::gil_scoped_release gil;
7676

77-
auto stream_factory = make_uniq<PythonTableArrowArrayStreamFactory>(ptr, context.GetClientProperties());
77+
auto stream_factory =
78+
make_uniq<PythonTableArrowArrayStreamFactory>(ptr, context.GetClientProperties(), PyArrowObjectType::Table);
7879
auto stream_factory_produce = PythonTableArrowArrayStreamFactory::Produce;
7980
auto stream_factory_get_schema = PythonTableArrowArrayStreamFactory::GetSchema;
8081

0 commit comments

Comments
 (0)