1
- #include < experimental /mdspan>
1
+ #include < mdspan /mdspan.hpp >
2
2
#include < array>
3
3
#include < iostream>
4
4
#include < tuple>
5
5
#include < type_traits>
6
6
7
- namespace stdex = std::experimental;
8
-
9
7
// There's no separate feature test macro for the C++20 feature
10
8
// of lambdas with named template parameters (P0428R2).
11
9
#if __cplusplus >= 202002L
@@ -80,11 +78,11 @@ auto print_pack = []<class ... InputTypes>(InputTypes&& ... input) {
80
78
// This example shows that you can do
81
79
// index arithmetic on an index sequence.
82
80
template <class IndexType , std::size_t ... Extents>
83
- auto right_extents ( stdex ::extents<IndexType, Extents...> e )
81
+ auto right_extents ( Kokkos ::extents<IndexType, Extents...> e )
84
82
{
85
83
static_assert (sizeof ...(Extents) != 0 );
86
84
return [&]<std::size_t ... Indices>( std::index_sequence<Indices...> ) {
87
- return stdex ::extents<IndexType, e.static_extent (Indices + 1 )...>{
85
+ return Kokkos ::extents<IndexType, e.static_extent (Indices + 1 )...>{
88
86
e.extent (Indices + 1 )...
89
87
};
90
88
}( std::make_index_sequence<sizeof ...(Extents) - 1 >() );
@@ -101,10 +99,10 @@ auto right_extents( stdex::extents<IndexType, Extents...> e )
101
99
// This needs to be a lambda or function object,
102
100
// not a templated function.
103
101
auto split_extents_at_leftmost =
104
- []<class IndexType , std::size_t ... Extents>(stdex ::extents<IndexType, Extents...> e)
102
+ []<class IndexType , std::size_t ... Extents>(Kokkos ::extents<IndexType, Extents...> e)
105
103
{
106
104
static_assert (sizeof ...(Extents) != 0 );
107
- stdex ::extents<IndexType, e.static_extent (0 )> left_ext (
105
+ Kokkos ::extents<IndexType, e.static_extent (0 )> left_ext (
108
106
e.extent (0 ));
109
107
return std::tuple{left_ext, right_extents (e)};
110
108
};
@@ -116,22 +114,22 @@ auto split_extents_at_leftmost =
116
114
// Returns a new extents object representing
117
115
// all but the rightmost extent of e.
118
116
template <class IndexType , std::size_t ... Extents>
119
- auto left_extents ( stdex ::extents<IndexType, Extents...> e )
117
+ auto left_extents ( Kokkos ::extents<IndexType, Extents...> e )
120
118
{
121
119
static_assert (sizeof ...(Extents) != 0 );
122
120
return [&]<std::size_t ... Indices>( std::index_sequence<Indices...> ) {
123
- return stdex ::extents<IndexType, e.static_extent (Indices)...>{
121
+ return Kokkos ::extents<IndexType, e.static_extent (Indices)...>{
124
122
e.extent (Indices)...
125
123
};
126
124
}( std::make_index_sequence<sizeof ...(Extents) - 1 >() );
127
125
}
128
126
129
127
// This needs to be a lambda or function object, not a templated function.
130
128
auto split_extents_at_rightmost =
131
- []<class IndexType , std::size_t ... Extents>(stdex ::extents<IndexType, Extents...> e)
129
+ []<class IndexType , std::size_t ... Extents>(Kokkos ::extents<IndexType, Extents...> e)
132
130
{
133
131
static_assert (sizeof ...(Extents) != 0 );
134
- stdex ::extents<IndexType, e.static_extent (e.rank () - 1 )> right_ext (
132
+ Kokkos ::extents<IndexType, e.static_extent (e.rank () - 1 )> right_ext (
135
133
e.extent (e.rank () - 1 ));
136
134
return std::tuple{left_extents (e), right_ext};
137
135
};
@@ -149,10 +147,10 @@ auto split_extents_at_rightmost =
149
147
// optimization information -- e.g., whether we want
150
148
// to apply "#pragma omp simd" to a particular extent.
151
149
template <class Callable , class IndexType , std::size_t Extent>
152
- void for_each_one_extent (Callable&& callable, stdex ::extents<IndexType, Extent> ext)
150
+ void for_each_one_extent (Callable&& callable, Kokkos ::extents<IndexType, Extent> ext)
153
151
{
154
152
// If it's a run-time extent, do a run-time loop.
155
- if constexpr (ext.static_extent (0 ) == stdex ::dynamic_extent) {
153
+ if constexpr (ext.static_extent (0 ) == Kokkos ::dynamic_extent) {
156
154
for (IndexType index = 0 ; index < ext.extent (0 ); ++index) {
157
155
std::forward<Callable>(callable)(index);
158
156
}
@@ -176,7 +174,7 @@ void for_each_one_extent(Callable&& callable, stdex::extents<IndexType, Extent>
176
174
template <class Callable , class IndexType , std::size_t ... Extents>
177
175
void for_each_in_extents_row_major (
178
176
Callable&& callable,
179
- stdex ::extents<IndexType, Extents...> ext)
177
+ Kokkos ::extents<IndexType, Extents...> ext)
180
178
{
181
179
if constexpr (ext.rank () == 0 ) {
182
180
return ;
@@ -203,12 +201,12 @@ void for_each_in_extents_row_major(
203
201
// The implementation differs in only two places from the row-major version.
204
202
// This suggests a way to generalize.
205
203
//
206
- // Overloading on stdex:: extents<IndexType, LeftExtents..., RightExtent>
204
+ // Overloading on extents<IndexType, LeftExtents..., RightExtent>
207
205
// works fine for the row major case, but not for the column major case.
208
206
template <class Callable , class IndexType , std::size_t ... Extents>
209
207
void for_each_in_extents_col_major (
210
208
Callable&& callable,
211
- stdex ::extents<IndexType, Extents...> ext)
209
+ Kokkos ::extents<IndexType, Extents...> ext)
212
210
{
213
211
if constexpr (ext.rank () == 0 ) {
214
212
return ;
@@ -242,7 +240,7 @@ void for_each_in_extents_col_major(
242
240
template <class Callable , class IndexType , std::size_t ... Extents,
243
241
class ExtentsReorderer , class ExtentsSplitter , class IndicesReorderer >
244
242
void for_each_in_extents_impl (Callable&& callable,
245
- stdex ::extents<IndexType, Extents...> ext,
243
+ Kokkos ::extents<IndexType, Extents...> ext,
246
244
ExtentsReorderer reorder_extents,
247
245
ExtentsSplitter split_extents,
248
246
IndicesReorderer reorder_indices)
@@ -280,18 +278,18 @@ void for_each_in_extents_impl(Callable&& callable,
280
278
}
281
279
282
280
auto extents_identity = []<class IndexType , std::size_t ... Extents>(
283
- stdex ::extents<IndexType, Extents...> ext)
281
+ Kokkos ::extents<IndexType, Extents...> ext)
284
282
{
285
283
return ext;
286
284
};
287
285
288
286
auto extents_reverse = []<class IndexType , std::size_t ... Extents>(
289
- stdex ::extents<IndexType, Extents...> ext)
287
+ Kokkos ::extents<IndexType, Extents...> ext)
290
288
{
291
289
constexpr std::size_t N = ext.rank ();
292
290
293
291
return [&]<std::size_t ... Indices>( std::index_sequence<Indices...> ) {
294
- return stdex ::extents<
292
+ return Kokkos ::extents<
295
293
IndexType,
296
294
ext.static_extent (N - 1 - Indices)...
297
295
>{ ext.extent (N - 1 - Indices)... };
@@ -325,8 +323,8 @@ auto indices_reverse = [](auto... args) {
325
323
// Row-major iteration
326
324
template <class Callable , class IndexType , std::size_t ... Extents>
327
325
void for_each_in_extents (Callable&& callable,
328
- stdex ::extents<IndexType, Extents...> ext,
329
- stdex ::layout_right)
326
+ Kokkos ::extents<IndexType, Extents...> ext,
327
+ Kokkos ::layout_right)
330
328
{
331
329
for_each_in_extents_impl (std::forward<Callable>(callable), ext,
332
330
extents_identity, split_extents_at_leftmost, indices_identity);
@@ -335,8 +333,8 @@ void for_each_in_extents(Callable&& callable,
335
333
// Column-major iteration
336
334
template <class Callable , class IndexType , std::size_t ... Extents>
337
335
void for_each_in_extents (Callable&& callable,
338
- stdex ::extents<IndexType, Extents...> ext,
339
- stdex ::layout_left)
336
+ Kokkos ::extents<IndexType, Extents...> ext,
337
+ Kokkos ::layout_left)
340
338
{
341
339
for_each_in_extents_impl (std::forward<Callable>(callable), ext,
342
340
extents_reverse, split_extents_at_rightmost, indices_reverse);
@@ -349,7 +347,7 @@ int main() {
349
347
#if ! defined(__clang__) && defined(MDSPAN_EXAMPLE_CAN_USE_LAMBDA_TEMPLATE_PARAM_LIST)
350
348
// The functions work for any combination
351
349
// of compile-time or run-time extents.
352
- stdex ::extents<int , 3 , stdex ::dynamic_extent, 5 > e{4 };
350
+ Kokkos ::extents<int , 3 , Kokkos ::dynamic_extent, 5 > e{4 };
353
351
354
352
std::cout << " \n Row major:\n " ;
355
353
for_each_in_extents_row_major (print_pack, e);
@@ -358,10 +356,10 @@ int main() {
358
356
for_each_in_extents_col_major (print_pack, e);
359
357
360
358
std::cout << " \n for_each_in_extents: row major:\n " ;
361
- for_each_in_extents (print_pack, e, stdex ::layout_right{});
359
+ for_each_in_extents (print_pack, e, Kokkos ::layout_right{});
362
360
363
361
std::cout << " \n for_each_in_extents: column major:\n " ;
364
- for_each_in_extents (print_pack, e, stdex ::layout_left{});
362
+ for_each_in_extents (print_pack, e, Kokkos ::layout_left{});
365
363
#endif // defined(MDSPAN_EXAMPLE_CAN_USE_LAMBDA_TEMPLATE_PARAM_LIST)
366
364
367
365
return 0 ;
0 commit comments