Skip to content

Commit 98e50b7

Browse files
virtualdrwgk
andauthored
fix: avoid copy constructor instantiation in shared_ptr fallback cast (#6028)
* tests: add regressions for shared_ptr reference_internal fallback * fix: avoid copy constructor instantiation in shared_ptr fallback cast * Remove stray empty line * tests: rename PyTorch shared_ptr regression test files * refactor: add cast_non_owning helper for reference-like casts Name the non-owning generic cast path so callers do not have to rediscover that reference-like policies must pass null copy/move constructor callbacks. This keeps the shared_ptr reference_internal fallback self-documenting and points future maintainers toward the safe API. Made-with: Cursor * tests: guard deprecated-copy warning probes with __has_warning Use __has_warning for the Clang-only regression test so older compiler jobs skip unsupported warning groups instead of failing with -Wunknown-warning-option. A simple __clang_major__ >= 13 guard would be shorter, but it bakes in a version cutoff; __has_warning is slightly more verbose while being more robust to vendor builds, backports, and future packaging differences. Made-with: Cursor --------- Co-authored-by: Ralf W. Grosse-Kunstleve <rgrossekunst@nvidia.com>
1 parent bfd6cbd commit 98e50b7

File tree

6 files changed

+113
-1
lines changed

6 files changed

+113
-1
lines changed

include/pybind11/cast.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1027,7 +1027,7 @@ struct copyable_holder_caster<
10271027
}
10281028

10291029
if (parent) {
1030-
return type_caster_base<type>::cast(
1030+
return type_caster_generic::cast_non_owning(
10311031
srcs, return_value_policy::reference_internal, parent);
10321032
}
10331033

include/pybind11/detail/type_caster_base.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,18 @@ class type_caster_generic {
10041004
return cast(srcs, policy, parent, copy_constructor, move_constructor, existing_holder);
10051005
}
10061006

1007+
static handle cast_non_owning(const cast_sources &srcs,
1008+
return_value_policy policy,
1009+
handle parent,
1010+
const void *existing_holder = nullptr) {
1011+
// Reference-like policies alias an existing C++ object instead of creating
1012+
// a new one, so copy/move constructor callbacks must remain null here.
1013+
assert(policy == return_value_policy::reference
1014+
|| policy == return_value_policy::reference_internal
1015+
|| policy == return_value_policy::automatic_reference);
1016+
return cast(srcs, policy, parent, nullptr, nullptr, existing_holder);
1017+
}
1018+
10071019
PYBIND11_NOINLINE static handle cast(const cast_sources &srcs,
10081020
return_value_policy policy,
10091021
handle parent,

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ set(PYBIND11_TEST_FILES
167167
test_operator_overloading
168168
test_pickling
169169
test_potentially_slicing_weak_ptr
170+
test_pytorch_shared_ptr_cast_regression
170171
test_python_multiple_inheritance
171172
test_pytypes
172173
test_scoped_critical_section

tests/test_class_sh_property.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,15 @@ def test_non_smart_holder_member_type_with_smart_holder_owner_aliases_member():
204204
legacy = obj.legacy
205205
legacy.value = 13
206206
assert obj.legacy.value == 13
207+
208+
209+
def test_non_smart_holder_member_type_with_smart_holder_owner_aliases_member_multiple_reads():
210+
obj = m.ShWithSimpleStructMember()
211+
212+
a = obj.legacy
213+
b = obj.legacy
214+
215+
a.value = 13
216+
217+
assert b.value == 13
218+
assert obj.legacy.value == 13
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#include "pybind11_tests.h"
2+
3+
#include <memory>
4+
#include <string>
5+
6+
#if defined(__clang__)
7+
# if __has_warning("-Wdeprecated-copy-with-user-provided-dtor")
8+
# pragma clang diagnostic error "-Wdeprecated-copy-with-user-provided-dtor"
9+
# endif
10+
# if __has_warning("-Wdeprecated-copy-with-dtor")
11+
# pragma clang diagnostic error "-Wdeprecated-copy-with-dtor"
12+
# endif
13+
#endif
14+
15+
namespace test_pytorch_regressions {
16+
17+
// Directly extracted from PyTorch patterns that regressed in CI.
18+
struct TracingState : std::enable_shared_from_this<TracingState> {
19+
TracingState() = default;
20+
~TracingState() = default;
21+
int value = 0;
22+
};
23+
24+
const std::shared_ptr<TracingState> &get_tracing_state() {
25+
static std::shared_ptr<TracingState> state = std::make_shared<TracingState>();
26+
return state;
27+
}
28+
29+
struct InterfaceType {
30+
~InterfaceType() = default;
31+
int value = 0;
32+
};
33+
using InterfaceTypePtr = std::shared_ptr<InterfaceType>;
34+
35+
struct CompilationUnit {
36+
InterfaceTypePtr iface = std::make_shared<InterfaceType>();
37+
38+
InterfaceTypePtr get_interface(const std::string &) const { return iface; }
39+
};
40+
41+
} // namespace test_pytorch_regressions
42+
43+
TEST_SUBMODULE(pybind11_pytorch_regressions, m) {
44+
using namespace test_pytorch_regressions;
45+
46+
py::class_<TracingState, std::shared_ptr<TracingState>>(m, "TracingState")
47+
.def(py::init<>())
48+
.def_readwrite("value", &TracingState::value);
49+
50+
m.def("_get_tracing_state", []() { return get_tracing_state(); });
51+
52+
py::class_<InterfaceType, InterfaceTypePtr>(m, "InterfaceType")
53+
.def(py::init<>())
54+
.def_readwrite("value", &InterfaceType::value);
55+
56+
py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(m, "CompilationUnit")
57+
.def(py::init<>())
58+
.def("get_interface",
59+
[](const std::shared_ptr<CompilationUnit> &self, const std::string &name) {
60+
return self->get_interface(name);
61+
});
62+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from __future__ import annotations
2+
3+
from pybind11_tests import pybind11_pytorch_regressions as m
4+
5+
6+
def test_pytorch_like_get_tracing_state_aliases_singleton_shared_ptr():
7+
a = m._get_tracing_state()
8+
b = m._get_tracing_state()
9+
10+
a.value = 17
11+
12+
assert b.value == 17
13+
assert m._get_tracing_state().value == 17
14+
15+
16+
def test_pytorch_like_compilation_unit_get_interface_aliases_member_shared_ptr():
17+
cu = m.CompilationUnit()
18+
19+
a = cu.get_interface("iface")
20+
b = cu.get_interface("iface")
21+
22+
a.value = 23
23+
24+
assert b.value == 23
25+
assert cu.get_interface("iface").value == 23

0 commit comments

Comments
 (0)