Skip to content

Commit 7c09a6b

Browse files
committed
Avoid heap allocation for function calls with a small number of arguments
We don't have access to llvm::SmallVector or similar, but given the limited subset of the `std::vector` API that `function_call::args{,_convert}` need and the "reserve-then-fill" usage pattern, it is relatively straightforward to implement custom containers that get the job done. This seems to improve the time to call the collatz function in pybind/pybind11_benchmark significantly; the numbers are a little noisy, but there is a clear improvement from "about 60 ns per call" to "about 45 ns per call" on my machine (M4 Max Mac), as measured with `timeit.repeat('collatz(4)', 'from pybind11_benchmark import collatz')`.
1 parent bf2d56e commit 7c09a6b

File tree

6 files changed

+396
-5
lines changed

6 files changed

+396
-5
lines changed

include/pybind11/cast.h

Lines changed: 225 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tuple>
2828
#include <type_traits>
2929
#include <utility>
30+
#include <variant>
3031
#include <vector>
3132

3233
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
@@ -2037,6 +2038,226 @@ using is_pos_only = std::is_same<intrinsic_t<T>, pos_only>;
20372038
// forward declaration (definition in attr.h)
20382039
struct function_record;
20392040

2041+
// small_vector-like container to avoid heap allocation for N or fewer
2042+
// arguments.
2043+
template <std::size_t N>
2044+
struct argument_vector {
2045+
struct handle_array {
2046+
std::array<handle, N> arr;
2047+
std::size_t size = 0;
2048+
};
2049+
2050+
using handle_vector = std::vector<handle>;
2051+
2052+
public:
2053+
argument_vector() = default;
2054+
2055+
argument_vector(const argument_vector &) = delete;
2056+
argument_vector &operator=(const argument_vector &) = delete;
2057+
argument_vector(argument_vector &&) noexcept = default;
2058+
argument_vector &operator=(argument_vector &&) noexcept = default;
2059+
2060+
std::size_t size() const {
2061+
if (std::holds_alternative<handle_array>(repr_)) {
2062+
return std::get<handle_array>(repr_).size;
2063+
} else if (std::holds_alternative<handle_vector>(repr_)) {
2064+
return std::get<handle_vector>(repr_).size();
2065+
} else {
2066+
// NOTE: we know that our variant is never going to be
2067+
// valueless_by_exception() because the relevant move operations
2068+
// don't throw. Marking this case unreachable allows better code
2069+
// generation.
2070+
PYBIND11_UNREACHABLE();
2071+
}
2072+
}
2073+
2074+
handle &operator[](std::size_t idx) {
2075+
assert(idx < size());
2076+
if (std::holds_alternative<handle_array>(repr_)) {
2077+
return std::get<handle_array>(repr_).arr[idx];
2078+
} else if (std::holds_alternative<handle_vector>(repr_)) {
2079+
return std::get<handle_vector>(repr_)[idx];
2080+
} else {
2081+
PYBIND11_UNREACHABLE();
2082+
}
2083+
}
2084+
2085+
handle operator[](std::size_t idx) const {
2086+
assert(idx < size());
2087+
if (std::holds_alternative<handle_array>(repr_)) {
2088+
return std::get<handle_array>(repr_).arr[idx];
2089+
} else if (std::holds_alternative<handle_vector>(repr_)) {
2090+
return std::get<handle_vector>(repr_)[idx];
2091+
} else {
2092+
PYBIND11_UNREACHABLE();
2093+
}
2094+
}
2095+
2096+
void push_back(handle x) {
2097+
if (std::holds_alternative<handle_array>(repr_)) {
2098+
auto &ha = std::get<handle_array>(repr_);
2099+
if (ha.size == N) {
2100+
move_to_vector_with_reserved_size(N + 1);
2101+
std::get<handle_vector>(repr_).push_back(x);
2102+
} else {
2103+
ha.arr[ha.size++] = x;
2104+
}
2105+
} else if (std::holds_alternative<handle_vector>(repr_)) {
2106+
std::get<handle_vector>(repr_).push_back(x);
2107+
} else {
2108+
PYBIND11_UNREACHABLE();
2109+
}
2110+
}
2111+
2112+
template <typename Arg>
2113+
void emplace_back(Arg &&x) {
2114+
push_back(handle(x));
2115+
}
2116+
2117+
void reserve(std::size_t sz) {
2118+
if (std::holds_alternative<handle_array>(repr_)) {
2119+
if (sz > N) {
2120+
move_to_vector_with_reserved_size(sz);
2121+
}
2122+
} else if (std::holds_alternative<handle_vector>(repr_)) {
2123+
std::get<handle_vector>(repr_).reserve(sz);
2124+
} else {
2125+
PYBIND11_UNREACHABLE();
2126+
}
2127+
}
2128+
2129+
private:
2130+
void move_to_vector_with_reserved_size(std::size_t reserved_size) {
2131+
auto &ha = std::get<handle_array>(repr_);
2132+
handle_vector hv;
2133+
hv.reserve(reserved_size);
2134+
std::copy(ha.arr.begin(), ha.arr.begin() + ha.size, std::back_inserter(hv));
2135+
repr_ = std::move(hv);
2136+
}
2137+
std::variant<handle_array, handle_vector> repr_;
2138+
};
2139+
2140+
// small_vector-like container to avoid heap allocation for N or fewer
2141+
// arguments.
2142+
template <std::size_t kRequestedInlineSize>
2143+
struct args_convert_vector {
2144+
private:
2145+
static constexpr auto kBitsPerWord = 8 * sizeof(std::size_t);
2146+
static constexpr auto kWords = (kRequestedInlineSize + kBitsPerWord - 1) / kBitsPerWord;
2147+
static constexpr auto kInlineSize = kWords * kBitsPerWord;
2148+
struct inline_array {
2149+
std::array<std::size_t, kWords> arr;
2150+
std::size_t size = 0;
2151+
};
2152+
using heap_array = std::vector<bool>;
2153+
2154+
public:
2155+
args_convert_vector() = default;
2156+
2157+
args_convert_vector(const args_convert_vector &) = delete;
2158+
args_convert_vector &operator=(const args_convert_vector &) = delete;
2159+
args_convert_vector(args_convert_vector &&) noexcept = default;
2160+
args_convert_vector &operator=(args_convert_vector &&) noexcept = default;
2161+
2162+
args_convert_vector(std::size_t count, bool value) {
2163+
if (count > kInlineSize) {
2164+
repr_ = heap_array(count, value);
2165+
} else {
2166+
auto &inline_arr = std::get<inline_array>(repr_);
2167+
inline_arr.arr.fill(value ? std::size_t(-1) : 0);
2168+
inline_arr.size = count;
2169+
}
2170+
}
2171+
2172+
std::size_t size() const {
2173+
if (std::holds_alternative<inline_array>(repr_)) {
2174+
return std::get<inline_array>(repr_).size;
2175+
} else if (std::holds_alternative<heap_array>(repr_)) {
2176+
return std::get<heap_array>(repr_).size();
2177+
} else {
2178+
PYBIND11_UNREACHABLE();
2179+
}
2180+
}
2181+
2182+
void reserve(std::size_t sz) {
2183+
if (std::holds_alternative<inline_array>(repr_)) {
2184+
if (sz > kInlineSize) {
2185+
move_to_vector_with_reserved_size(sz);
2186+
}
2187+
} else if (std::holds_alternative<heap_array>(repr_)) {
2188+
std::get<heap_array>(repr_).reserve(sz);
2189+
} else {
2190+
PYBIND11_UNREACHABLE();
2191+
}
2192+
}
2193+
2194+
bool operator[](std::size_t idx) const {
2195+
if (std::holds_alternative<inline_array>(repr_)) {
2196+
return inline_index(idx);
2197+
} else if (std::holds_alternative<heap_array>(repr_)) {
2198+
assert(idx < std::get<heap_array>(repr_).size());
2199+
return std::get<heap_array>(repr_)[idx];
2200+
} else {
2201+
PYBIND11_UNREACHABLE();
2202+
}
2203+
}
2204+
2205+
void push_back(bool b) {
2206+
if (std::holds_alternative<inline_array>(repr_)) {
2207+
auto &ha = std::get<inline_array>(repr_);
2208+
if (ha.size == kInlineSize) {
2209+
move_to_vector_with_reserved_size(kInlineSize + 1);
2210+
std::get<heap_array>(repr_).push_back(b);
2211+
} else {
2212+
assert(ha.size < kInlineSize);
2213+
const auto wbi = word_and_bit_index(ha.size++);
2214+
assert(wbi.word < kWords);
2215+
assert(wbi.bit < kBitsPerWord);
2216+
if (b) {
2217+
ha.arr[wbi.word] |= (std::size_t(1) << wbi.bit);
2218+
} else {
2219+
ha.arr[wbi.word] &= ~(std::size_t(1) << wbi.bit);
2220+
}
2221+
assert(operator[](ha.size - 1) == b);
2222+
}
2223+
} else if (std::holds_alternative<heap_array>(repr_)) {
2224+
std::get<heap_array>(repr_).push_back(b);
2225+
} else {
2226+
PYBIND11_UNREACHABLE();
2227+
}
2228+
}
2229+
2230+
void swap(args_convert_vector &rhs) { std::swap(repr_, rhs.repr_); }
2231+
2232+
private:
2233+
struct WordAndBitIndex {
2234+
std::size_t word;
2235+
std::size_t bit;
2236+
};
2237+
2238+
static auto word_and_bit_index(std::size_t idx) {
2239+
return WordAndBitIndex{idx / kBitsPerWord, idx % kBitsPerWord};
2240+
}
2241+
2242+
bool inline_index(std::size_t idx) const {
2243+
const auto wbi = word_and_bit_index(idx);
2244+
assert(wbi.word < kWords);
2245+
assert(wbi.bit < kBitsPerWord);
2246+
return std::get<inline_array>(repr_).arr[wbi.word] & (std::size_t(1) << wbi.bit);
2247+
}
2248+
2249+
void move_to_vector_with_reserved_size(std::size_t reserved_size) {
2250+
auto &inline_arr = std::get<inline_array>(repr_);
2251+
heap_array hp;
2252+
hp.reserve(reserved_size);
2253+
for (std::size_t ii = 0; ii < inline_arr.size; ++ii) {
2254+
hp.push_back(inline_index(ii));
2255+
}
2256+
repr_ = std::move(hp);
2257+
}
2258+
std::variant<inline_array, heap_array> repr_;
2259+
};
2260+
20402261
/// Internal data associated with a single function call
20412262
struct function_call {
20422263
function_call(const function_record &f, handle p); // Implementation in attr.h
@@ -2045,10 +2266,12 @@ struct function_call {
20452266
const function_record &func;
20462267

20472268
/// Arguments passed to the function:
2048-
std::vector<handle> args;
2269+
/// (Inline size chosen mostly arbitrarily; 5 should pad function_call out to two cache lines
2270+
/// (16 pointers) in size.)
2271+
argument_vector<5> args;
20492272

20502273
/// The `convert` value the arguments should be loaded with
2051-
std::vector<bool> args_convert;
2274+
args_convert_vector<5> args_convert;
20522275

20532276
/// Extra references for the optional `py::args` and/or `py::kwargs` arguments (which, if
20542277
/// present, are also in `args` but without a reference).

include/pybind11/detail/common.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,18 @@
181181
# define PYBIND11_MAYBE_UNUSED __attribute__((__unused__))
182182
#endif
183183

184+
// PYBIND11_UNREACHABLE(): marks a code path that must never execute.
// The path asserts in builds where assert() is active; on GCC-compatible
// compilers and MSVC it additionally hints the optimizer that the path is
// unreachable. Each variant expands to exactly one statement
// (`do { ... } while (0)`, without a trailing semicolon), so the macro can be
// used safely as the body of an unbraced `if`/`else`.
#if defined(__GNUC__)
#    define PYBIND11_UNREACHABLE()                                                                \
        do {                                                                                      \
            assert(false);                                                                        \
            __builtin_unreachable();                                                              \
        } while (0)
#elif defined(_MSC_VER)
#    define PYBIND11_UNREACHABLE()                                                                \
        do {                                                                                      \
            assert(false);                                                                        \
            __assume(0);                                                                          \
        } while (0)
#else
// No unreachable hint is used for other compilers; only the assertion remains.
#    define PYBIND11_UNREACHABLE()                                                                \
        do {                                                                                      \
            assert(false);                                                                        \
        } while (0)
#endif
195+
184196
// https://en.cppreference.com/w/c/chrono/localtime
185197
#if defined(__STDC_LIB_EXT1__) && !defined(__STDC_WANT_LIB_EXT1__)
186198
# define __STDC_WANT_LIB_EXT1__

include/pybind11/pybind11.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,12 +1048,12 @@ class cpp_function : public function {
10481048
}
10491049
#endif
10501050

1051-
std::vector<bool> second_pass_convert;
1051+
args_convert_vector<5> second_pass_convert;
10521052
if (overloaded) {
10531053
// We're in the first no-convert pass, so swap out the conversion flags for a
10541054
// set of all-false flags. If the call fails, we'll swap the flags back in for
10551055
// the conversion-allowed call below.
1056-
second_pass_convert.resize(func.nargs, false);
1056+
second_pass_convert = args_convert_vector<5>(func.nargs, false);
10571057
call.args_convert.swap(second_pass_convert);
10581058
}
10591059

tests/test_embed/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ if(PYBIND11_TEST_SMART_HOLDER)
3333
-DPYBIND11_RUN_TESTING_WITH_SMART_HOLDER_AS_DEFAULT_BUT_NEVER_USE_IN_PRODUCTION_PLEASE)
3434
endif()
3535

36-
add_executable(test_embed catch.cpp test_interpreter.cpp test_subinterpreter.cpp)
36+
add_executable(test_embed catch.cpp test_args_convert_vector.cpp test_argument_vector.cpp
37+
test_interpreter.cpp test_subinterpreter.cpp)
3738
pybind11_enable_warnings(test_embed)
3839

3940
target_link_libraries(test_embed PRIVATE pybind11::embed Catch2::Catch2 Threads::Threads)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
#include "pybind11/pybind11.h"
2+
#include "catch.hpp"
3+
4+
namespace py = pybind11;
5+
6+
// 5 is chosen to match the production value. It doesn't really matter
7+
// what we pick because the actual inline size is going to be 8 *
8+
// sizeof(size_t) given the implementation at the time of writing.
9+
using args_convert_vector = py::detail::args_convert_vector<5>;
10+
11+
namespace {
12+
template <typename Container>
13+
std::vector<Container> get_sample_vectors() {
14+
std::vector<Container> result;
15+
result.emplace_back();
16+
for (const auto sz : {0, 4, 5, 6, 31, 32, 33, 63, 64, 65}) {
17+
for (const bool b : {false, true}) {
18+
result.emplace_back(static_cast<std::size_t>(sz), b);
19+
}
20+
}
21+
return result;
22+
}
23+
24+
void require_vector_matches_sample(const args_convert_vector &actual,
25+
const std::vector<bool> &expected) {
26+
REQUIRE(actual.size() == expected.size());
27+
for (size_t ii = 0; ii < actual.size(); ++ii) {
28+
REQUIRE(actual[ii] == expected[ii]);
29+
}
30+
}
31+
32+
template <typename MutationFunc>
33+
void mutation_test_with_samples(MutationFunc mutation_func) {
34+
auto sample_contents = get_sample_vectors<std::vector<bool>>();
35+
auto samples = get_sample_vectors<args_convert_vector>();
36+
for (size_t ii = 0; ii < samples.size(); ++ii) {
37+
auto &actual = samples[ii];
38+
auto &expected = sample_contents[ii];
39+
40+
mutation_func(actual);
41+
mutation_func(expected);
42+
require_vector_matches_sample(actual, expected);
43+
}
44+
}
45+
} // namespace
46+
47+
// Baseline: with a mutation that does nothing, every args_convert_vector
// sample must already match its std::vector<bool> counterpart.
TEST_CASE("check sample args_convert_vector contents") {
    const auto leave_untouched = [](auto &) {};
    mutation_test_with_samples(leave_untouched);
}
50+
51+
// Appending a single flag (false, then true) to every sample must behave the
// same as appending it to the reference std::vector<bool>.
TEST_CASE("args_convert_vector push_back") {
    for (const bool flag : {false, true}) {
        const auto append_flag = [flag](auto &vec) { vec.push_back(flag); };
        mutation_test_with_samples(append_flag);
    }
}
56+
57+
// reserve() must never change the visible contents. The capacities include
// values above 8 * sizeof(size_t), the inline capacity of
// args_convert_vector<5>, so that reserve() on a small sample also exercises
// the switch from inline to heap storage, not only the no-op path.
TEST_CASE("args_convert_vector reserve") {
    for (const std::size_t capacity : {0U, 1U, 2U, 3U, 5U, 63U, 64U, 65U, 128U}) {
        mutation_test_with_samples([capacity](auto &vec) { vec.reserve(capacity); });
    }
}
62+
63+
// Mirrors the production "reserve-then-fill" pattern: reserve, then append a
// flag. The capacities include values above 8 * sizeof(size_t), the inline
// capacity of args_convert_vector<5>, so that push_back is also checked
// right after reserve() has switched a small sample to heap storage.
TEST_CASE("args_convert_vector reserve then push_back") {
    for (const std::size_t capacity : {0U, 1U, 2U, 3U, 5U, 63U, 64U, 65U, 128U}) {
        for (const bool flag : {false, true}) {
            mutation_test_with_samples([capacity, flag](auto &vec) {
                vec.reserve(capacity);
                vec.push_back(flag);
            });
        }
    }
}

0 commit comments

Comments
 (0)