Skip to content

Commit 6709abb

Browse files
authored
Allow function pointer extraction from overloaded functions (#2944)
* Add a failure test for overloaded functions * Allow function pointer extraction from overloaded functions
1 parent e0c1dad commit 6709abb

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

include/pybind11/functional.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,17 @@ struct type_caster<std::function<Return(Args...)>> {
4646
auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
4747
auto rec = (function_record *) c;
4848

49-
if (rec && rec->is_stateless &&
50-
same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
51-
struct capture { function_type f; };
52-
value = ((capture *) &rec->data)->f;
53-
return true;
49+
while (rec != nullptr) {
50+
if (rec->is_stateless
51+
&& same_type(typeid(function_type),
52+
*reinterpret_cast<const std::type_info *>(rec->data[1]))) {
53+
struct capture {
54+
function_type f;
55+
};
56+
value = ((capture *) &rec->data)->f;
57+
return true;
58+
}
59+
rec = rec->next;
5460
}
5561
}
5662

tests/test_callbacks.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ TEST_SUBMODULE(callbacks, m) {
9797
// test_cpp_function_roundtrip
9898
/* Test if passing a function pointer from C++ -> Python -> C++ yields the original pointer */
9999
m.def("dummy_function", &dummy_function);
100+
m.def("dummy_function_overloaded", [](int i, int j) { return i + j; });
101+
m.def("dummy_function_overloaded", &dummy_function);
100102
m.def("dummy_function2", [](int i, int j) { return i + j; });
101103
m.def("roundtrip", [](std::function<int(int)> f, bool expect_none = false) {
102104
if (expect_none && f)

tests/test_callbacks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def test_cpp_function_roundtrip():
9393
m.test_dummy_function(m.roundtrip(m.dummy_function))
9494
== "matches dummy_function: eval(1) = 2"
9595
)
96+
assert (
97+
m.test_dummy_function(m.dummy_function_overloaded)
98+
== "matches dummy_function: eval(1) = 2"
99+
)
96100
assert m.roundtrip(None, expect_none=True) is None
97101
assert (
98102
m.test_dummy_function(lambda x: x + 2)

0 commit comments

Comments
 (0)