Skip to content

Commit 83fee78

Browse files
committed
Add pitched allocation example
This is the example I wrote for the user question here: #117 Fixes #248.
1 parent b31a635 commit 83fee78

File tree

3 files changed

+249
-0
lines changed

3 files changed

+249
-0
lines changed

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ add_subdirectory(dot_product)
1616
add_subdirectory(tiled_layout)
1717
add_subdirectory(restrict_accessor)
1818
add_subdirectory(aligned_accessor)
19+
add_subdirectory(pitched_allocation)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
mdspan_add_example(pitched_allocation)
Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
#include <experimental/mdspan>
2+
#include <cassert>
3+
#include <cstring>
4+
#include <cstdint>
5+
#include <memory>
6+
7+
// This example shows how to deal with "pitched" allocations. These
8+
// are multidimensional array allocations where the size of each
9+
// element might not necessarily evenly divide the number of bytes per
10+
// "row" of the contiguous dimension. The commented-out example below
11+
// uses cudaMallocPitch to allocate a 4 x 5 two-dimensional array of
12+
// T, where sizeof(T) is 12. Each row (the contiguous dimension) has
13+
// 64 bytes. The last 4 bytes of each row are padding that do not
14+
// participate in an element.
15+
16+
// void* ptr = nullptr;
17+
// size_t pitch = 0;
18+
//
19+
// constexpr size_t num_cols = 5;
20+
// constexpr size_t num_rows = 4;
21+
//
22+
// cudaMallocPitch(&ptr, &pitch, sizeof(T) * num_cols, num_rows);
23+
// extents<int, num_rows, num_cols> exts{};
24+
// layout_stride::mapping mapping{exts, std::array{pitch, sizeof(T)}};
25+
// mdspan m{reinterpret_cast<char*>(ptr), mapping, aligned_byte_accessor<T>{}};
26+
27+
namespace stdex = std::experimental;
28+
29+
// This is the element type. "tbs" stands for Twelve-Byte Struct.
30+
// In this example, the struct includes a mixture of float and int,
31+
// just to make aliasing more interesting.
32+
struct tbs {
33+
float f0 = 0.0f;
34+
std::int32_t i = 0;
35+
float f1 = 0.0f;
36+
};
37+
38+
// Use of the proxy reference types is only required
39+
// if access to each element is not aligned.
40+
// That should not be the case here.
41+
42+
class const_tbs_proxy;
43+
class nonconst_tbs_proxy;
44+
45+
template<class T>
46+
class const_proxy {
47+
private:
48+
friend class const_tbs_proxy;
49+
constexpr const_proxy(const char* p) noexcept
50+
: p_(p)
51+
{}
52+
53+
public:
54+
// Not constexpr because of reinterpret_cast or memcpy
55+
operator T () const noexcept {
56+
// We can't do the commented-out reinterpret_cast
57+
// in Standard C++, because p_ might not have correct
58+
// alignment to point to a T.
59+
//
60+
//return *reinterpret_cast<const T*>(p_);
61+
62+
T f;
63+
std::memcpy(&f, p_, sizeof(T));
64+
return f;
65+
}
66+
67+
private:
68+
const char* p_ = nullptr;
69+
};
70+
71+
template<class T>
72+
class nonconst_proxy {
73+
private:
74+
friend class nonconst_tbs_proxy;
75+
constexpr nonconst_proxy(char* p) noexcept
76+
: p_(p)
77+
{}
78+
79+
public:
80+
// Not constexpr because of memcpy
81+
nonconst_proxy& operator=(const T& f) noexcept {
82+
std::memcpy(p_, &f, sizeof(T));
83+
return *this;
84+
}
85+
86+
// Not constexpr because of memcpy
87+
operator T () const noexcept {
88+
T f;
89+
std::memcpy(&f, p_, sizeof(T));
90+
return f;
91+
}
92+
93+
private:
94+
char* p_ = nullptr;
95+
};
96+
97+
class nonconst_tbs_proxy {
98+
private:
99+
char* p_ = nullptr;
100+
101+
public:
102+
nonconst_tbs_proxy(char* p) noexcept
103+
: p_(p), f0(p), i(p + sizeof(float)), f1(p + sizeof(float) + sizeof(int))
104+
{}
105+
106+
nonconst_tbs_proxy& operator=(const tbs& s) noexcept {
107+
this->f0 = s.f0;
108+
this->i = s.i;
109+
this->f1 = s.f1;
110+
return *this;
111+
}
112+
113+
operator tbs() const noexcept {
114+
return {float(f0), std::int32_t(i), float(f1)};
115+
};
116+
117+
nonconst_proxy<float> f0;
118+
nonconst_proxy<std::int32_t> i;
119+
nonconst_proxy<float> f1;
120+
};
121+
122+
// tbs is a struct, so users want to access its fields
123+
// with the usual dot notation. The two proxy reference types,
124+
// const_tbs_proxy and nonconst_tbs_proxy, preserve this behavior.
125+
126+
class const_tbs_proxy {
127+
private:
128+
const char* p_ = nullptr;
129+
130+
public:
131+
constexpr const_tbs_proxy(const char* p) noexcept
132+
: p_(p), f0(p), i(p + sizeof(float)), f1(p + sizeof(float) + sizeof(int))
133+
{}
134+
135+
operator tbs() const noexcept {
136+
return {float(f0), std::int32_t(i), float(f1)};
137+
};
138+
139+
const_proxy<float> f0;
140+
const_proxy<std::int32_t> i;
141+
const_proxy<float> f1;
142+
};
143+
144+
145+
struct const_tbs_accessor {
146+
using offset_policy = const_tbs_accessor;
147+
148+
using data_handle_type = const char*;
149+
using element_type = const tbs;
150+
// In the const reference case, we can use
151+
// either const_tbs_proxy or tbs (a value).
152+
//using reference = const_tbs_proxy;
153+
using reference = tbs;
154+
155+
constexpr const_tbs_accessor() noexcept = default;
156+
157+
// Not constexpr because of memcpy
158+
reference
159+
access(data_handle_type p, size_t i) const noexcept {
160+
//return {p + i * sizeof(tbs)}; // for const_tbs_proxy
161+
tbs t;
162+
std::memcpy(&t, p + i * sizeof(tbs), sizeof(tbs));
163+
return t;
164+
}
165+
166+
constexpr typename offset_policy::data_handle_type
167+
offset(data_handle_type p, size_t i) const noexcept {
168+
return p + i * sizeof(tbs);
169+
}
170+
};
171+
172+
struct nonconst_tbs_accessor {
173+
using offset_policy = nonconst_tbs_accessor;
174+
175+
using data_handle_type = char*;
176+
using element_type = tbs;
177+
using reference = nonconst_tbs_proxy;
178+
179+
constexpr nonconst_tbs_accessor() noexcept = default;
180+
181+
reference
182+
access(data_handle_type p, size_t i) const noexcept {
183+
return {p + i * sizeof(tbs)};
184+
}
185+
186+
constexpr typename offset_policy::data_handle_type
187+
offset(data_handle_type p, size_t i) const noexcept {
188+
return p + i * sizeof(tbs);
189+
}
190+
};
191+
192+
int main() {
193+
constexpr std::size_t num_elements = 5;
194+
195+
std::array<char, num_elements * sizeof(tbs)> data;
196+
auto* ptr = reinterpret_cast<tbs*>(data.data());
197+
198+
std::uninitialized_fill_n(ptr, num_elements, tbs{1.0, 2, 3.0});
199+
200+
for(std::size_t k = 0; k < num_elements; ++k) {
201+
assert(ptr[k].f0 == 1.0);
202+
assert(ptr[k].i == 2);
203+
assert(ptr[k].f1 == 3.0);
204+
}
205+
206+
const tbs* ptr_c = ptr;
207+
stdex::mdspan<const tbs, stdex::extents<int, num_elements>,
208+
stdex::layout_right, const_tbs_accessor> m{data.data()};
209+
for (std::size_t k = 0; k < num_elements; ++k) {
210+
assert(m[k].f0 == 1.0f);
211+
assert(m[k].i == 2);
212+
assert(m[k].f1 == 3.0f);
213+
}
214+
215+
stdex::mdspan<tbs, stdex::extents<int, num_elements>,
216+
stdex::layout_right, nonconst_tbs_accessor> m_nc{data.data()};
217+
for (std::size_t k = 0; k < num_elements; ++k) {
218+
m_nc[k].f0 = 4.0f;
219+
m_nc[k].i = 5;
220+
m_nc[k].f1 = 6.0f;
221+
}
222+
223+
for (std::size_t k = 0; k < num_elements; ++k) {
224+
// Be careful using auto with proxy references. It's fine here,
225+
// because we're not letting the proxy reference escape the scope.
226+
auto m_k = m[k];
227+
assert(m_k.f0 == 4.0f);
228+
assert(m_k.i == 5);
229+
assert(m_k.f1 == 6.0f);
230+
}
231+
232+
for (std::size_t k = 0; k < num_elements; ++k) {
233+
auto m_nc_k = m_nc[k];
234+
m_nc_k.f0 = 7.0f;
235+
m_nc_k.i = 8;
236+
m_nc_k.f1 = 9.0f;
237+
}
238+
239+
for (std::size_t k = 0; k < num_elements; ++k) {
240+
auto m_k = m[k];
241+
assert(m_k.f0 == 7.0f);
242+
assert(m_k.i == 8);
243+
assert(m_k.f1 == 9.0f);
244+
}
245+
246+
return 0;
247+
}

0 commit comments

Comments
 (0)