Hi, I'm working on vikunja, a platform-independent primitives library (e.g. vikunja::transform, vikunja::reduce) for different accelerators, based on alpaka.
I want to use mdspan to support N-dimensional memory. In the presentation by Bryce Adelstein Lelbach (youtube link), there is the idea of building a recursive function that creates n nested loops, which the compiler can optimize.
I implemented a prototype with submdspan for an iota function on an mdspan:
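#include <cstddef>
#include <experimental/mdspan> // header of the mdspan reference implementation

namespace stdex = std::experimental;

// Applies a functor to every element of an mdspan of rank TDim.
template<int TDim>
struct Iterate_mdspan_impl;

// Rank 1: end of the recursion, visit each element directly.
template<>
struct Iterate_mdspan_impl<1>
{
    template<typename TSpan, typename TFunc>
    void operator()(TSpan span, TFunc& functor)
    {
        // std::size_t avoids a signed/unsigned comparison with extent()
        for(std::size_t i = 0; i < span.extent(0); ++i)
        {
            span(i) = functor(span(i));
        }
    }
};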

// Rank 2: fix the first index and recurse into the remaining rank-1 slice.
template<>
struct Iterate_mdspan_impl<2>
{
    template<typename TSpan, typename TFunc>
    void operator()(TSpan span, TFunc& functor)
    {
        for(std::size_t i = 0; i < span.extent(0); ++i)
        {
            auto submdspan = stdex::submdspan(span, i, stdex::full_extent);
            Iterate_mdspan_impl<TSpan::rank() - 1>{}(submdspan, functor);
        }
    }
};

// Rank 3: same pattern, but one more full_extent argument is needed.
template<>
struct Iterate_mdspan_impl<3>
{
    template<typename TSpan, typename TFunc>
    void operator()(TSpan span, TFunc& functor)
    {
        for(std::size_t i = 0; i < span.extent(0); ++i)
        {
            auto submdspan = stdex::submdspan(span, i, stdex::full_extent, stdex::full_extent);
            Iterate_mdspan_impl<TSpan::rank() - 1>{}(submdspan, functor);
        }
    }
};
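
// Fills the mdspan with index, index + 1, ... in iteration order.
template<typename TSpan, typename TData>
void iota_span(TSpan span, TData index)
{
    static_assert(TSpan::rank() > 0);
    static_assert(TSpan::rank() <= 3);
    // The old element value is ignored; only the running counter is written.
    auto functor = [&index](TData /* input */) { return index++; };
    Iterate_mdspan_impl<TSpan::rank()>{}(span, functor);
}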
My problem is that I need to specialize each dimension by hand, and at the moment I have stopped at dim 3. Does anybody have an idea for writing a generic function that iterates over all elements of an mdspan?
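One possible direction, as a rough sketch rather than a definitive answer: instead of slicing with submdspan, the recursion could carry the indices collected so far in a parameter pack and call the span with all of them once the pack is as long as the rank. Grid3 below is a hypothetical stand-in for a rank-3 row-major mdspan, there only so that the sketch compiles without the mdspan headers; for_each_element is also a made-up name. The sketch assumes C++17 (if constexpr) and a span type that offers rank(), extent(r) and operator()(indices...), as the prototype above already uses.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a rank-3 row-major mdspan. It mimics only the
// parts of the interface used below: rank(), extent(r) and operator().
struct Grid3
{
    std::array<std::size_t, 3> ext;
    int* data;

    static constexpr std::size_t rank() { return 3; }
    std::size_t extent(std::size_t r) const { return ext[r]; }
    int& operator()(std::size_t i, std::size_t j, std::size_t k) const
    {
        assert(i < ext[0] && j < ext[1] && k < ext[2]);
        return data[(i * ext[1] + j) * ext[2] + k];
    }
};

// Visits every element of span. The indices fixed so far travel in idx...,
// so one template covers every rank and no submdspan call is needed.
template<typename TSpan, typename TFunc, typename... TIdx>
void for_each_element(TSpan span, TFunc& functor, TIdx... idx)
{
    if constexpr(sizeof...(TIdx) == TSpan::rank())
    {
        // All indices are fixed: apply the functor to a single element.
        span(idx...) = functor(span(idx...));
    }
    else
    {
        // Loop over the next dimension and recurse with one more index.
        constexpr std::size_t dim = sizeof...(TIdx);
        for(std::size_t i = 0; i < span.extent(dim); ++i)
        {
            for_each_element(span, functor, idx..., i);
        }
    }
}

template<typename TSpan, typename TData>
void iota_span(TSpan span, TData index)
{
    static_assert(TSpan::rank() > 0);
    auto functor = [&index](TData /* input */) { return index++; };
    for_each_element(span, functor);
}
```

The upper limit on the rank disappears because the number of nested loops follows from TSpan::rank() at compile time. Whether the compiler optimizes this as well as the submdspan variant would have to be measured.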
Maybe it is possible to set the n stdex::full_extent arguments in the call stdex::submdspan(span, i, stdex::full_extent, ...) depending on the template parameter TDim. I'm not sure whether this is possible with a variadic template or a type trait.
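To illustrate the variadic idea, here is a minimal sketch of how a std::index_sequence could be expanded into TDim - 1 copies of a full_extent argument. Everything in it is a stand-in so that it compiles on its own: full_extent_t and full_extent mimic stdex::full_extent, fake_submdspan takes the place of stdex::submdspan and only reports how many slice arguments it received, and fix_first_index is a hypothetical helper name. With the real library, the call inside fix_first_index_impl would presumably become stdex::submdspan(span, i, ((void) Is, stdex::full_extent)...), but I have not verified that here.

```cpp
#include <cstddef>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for stdex::full_extent.
struct full_extent_t
{
};
inline constexpr full_extent_t full_extent{};

// Hypothetical stand-in for stdex::submdspan: it accepts the same argument
// shape (span, fixed index, slices...) and returns the number of slices.
template<typename TSpan, typename... TSlices>
constexpr std::size_t fake_submdspan(TSpan const&, std::size_t, TSlices...)
{
    static_assert((std::is_same_v<TSlices, full_extent_t> && ...));
    return sizeof...(TSlices);
}

template<typename TSpan, std::size_t... Is>
constexpr auto fix_first_index_impl(TSpan const& span, std::size_t i, std::index_sequence<Is...>)
{
    // ((void) Is, full_extent) discards Is and yields full_extent, so the
    // pack expansion produces one full_extent argument per element of Is.
    return fake_submdspan(span, i, ((void) Is, full_extent)...);
}

// Fixes the first index of a rank-TDim span and keeps the other TDim - 1
// dimensions by passing that many full_extent arguments.
template<std::size_t TDim, typename TSpan>
constexpr auto fix_first_index(TSpan const& span, std::size_t i)
{
    static_assert(TDim > 0);
    return fix_first_index_impl(span, i, std::make_index_sequence<TDim - 1>{});
}
```

With a helper of this shape, the hand-written specializations for dim 2 and dim 3 could collapse into one primary template, leaving only the dim 1 case as the end of the recursion.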