
Generic iteration function for N-dimensional mdspan #202

@SimeonEhrig

Description


Hi, I'm working on vikunja, a platform-independent primitives library (e.g. vikunja::transform, vikunja::reduce) for different accelerators, based on alpaka.
I want to use mdspan to support N-dimensional memory. In the presentation by Bryce Adelstein Lelbach (youtube link), there is the idea of building a recursive function that creates N nested loops, which the compiler can optimize.

I implemented a prototype of an iota function for an mdspan, using submdspan:

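#include <experimental/mdspan> // mdspan reference implementation (kokkos/mdspan)

namespace stdex = std::experimental;

// Primary template: only the rank specializations below are defined.
template<int TDim>
struct Iterate_mdspan_impl;

// Rank 1: apply the functor to every element.
template<>
struct Iterate_mdspan_impl<1>
{
    template<typename TSpan, typename TFunc>
    void operator()(TSpan span, TFunc& functor)
    {
        // Use the span's own index type to avoid a signed/unsigned comparison.
        for(decltype(span.extent(0)) i = 0; i < span.extent(0); ++i)
        {
            span(i) = functor(span(i));
        }
    }
};

// Rank 2: fix the first index with submdspan and recurse on the rank-1 slice.
template<>
struct Iterate_mdspan_impl<2>
{
    template<typename TSpan, typename TFunc>
    void operator()(TSpan span, TFunc& functor)
    {
        for(decltype(span.extent(0)) i = 0; i < span.extent(0); ++i)
        {
            auto sub = stdex::submdspan(span, i, stdex::full_extent);
            Iterate_mdspan_impl<TSpan::rank() - 1>{}(sub, functor);
        }
    }
};

// Rank 3: same as rank 2, but one more full_extent is needed for the slice.
template<>
struct Iterate_mdspan_impl<3>
{
    template<typename TSpan, typename TFunc>
    void operator()(TSpan span, TFunc& functor)
    {
        for(decltype(span.extent(0)) i = 0; i < span.extent(0); ++i)
        {
            auto sub = stdex::submdspan(span, i, stdex::full_extent, stdex::full_extent);
            Iterate_mdspan_impl<TSpan::rank() - 1>{}(sub, functor);
        }
    }
};

// Writes index, index + 1, ... into the elements of span in row-major order.
template<typename TSpan, typename TData>
void iota_span(TSpan span, TData index)
{
    static_assert(TSpan::rank() > 0);
    static_assert(TSpan::rank() <= 3);
    // The current element value is ignored; only the running index is written.
    auto functor = [&index](TData /*input*/) { return index++; };
    Iterate_mdspan_impl<TSpan::rank()>{}(span, functor);
}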

My problem is that I need to specialize each dimension by hand; for now, I stopped at dimension 3. Does anybody have an idea how to write a generic function that iterates over all elements of an mdspan?
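A minimal sketch of one possible generic approach, assuming the span type provides `rank()`, `extent(r)` and the call-operator element access `span(i0, ..., iN-1)` that the prototype above already uses. Instead of slicing with `submdspan`, it emits one loop per dimension through recursion and carries the indices of the enclosing loops as a parameter pack, touching an element only once all `rank()` indices are known. The names `iterate_mdspan` and `iterate_mdspan_impl` are hypothetical; with C++23 `std::mdspan` the element access would be spelled `span[idx...]` instead.

```cpp
#include <cstddef>

// Hypothetical generic replacement for the per-rank Iterate_mdspan_impl
// specializations. Dim is the dimension the current loop runs over; the
// indices chosen by the enclosing loops are passed along in idx...
// Assumes TSpan has rank(), extent(r) and operator()(indices...).
template <std::size_t Dim, typename TSpan, typename TFunc, typename... TIdx>
void iterate_mdspan_impl(TSpan& span, TFunc& functor, TIdx... idx)
{
    if constexpr (Dim == TSpan::rank())
    {
        // Innermost level: all rank() indices are known, visit one element.
        span(idx...) = functor(span(idx...));
    }
    else
    {
        // Loop over dimension Dim using the span's own index type.
        using index_t = decltype(span.extent(Dim));
        for (index_t i = 0; i < span.extent(Dim); ++i)
        {
            iterate_mdspan_impl<Dim + 1>(span, functor, idx..., i);
        }
    }
}

// Entry point: visits every element in row-major order (last index fastest).
template <typename TSpan, typename TFunc>
void iterate_mdspan(TSpan span, TFunc functor)
{
    iterate_mdspan_impl<0>(span, functor);
}
```

Because the recursion stops at `Dim == TSpan::rank()`, no per-rank specialization is needed. Under the same assumptions, the body of `iota_span` could become `iterate_mdspan(span, [&index](TData) { return index++; });` without the `rank() <= 3` limit.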

Maybe it is possible to generate the n stdex::full_extent arguments of stdex::submdspan(span, i, stdex::full_extent, ...) from the template parameter TDim. I'm not sure whether this is possible with a variadic template or a type trait.
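For the idea of generating the `stdex::full_extent` arguments, one common C++17 technique is to expand a `std::make_index_sequence<N>` and map every index in the pack to the same value. The helpers `repeat`, `call_with_repeated_impl` and `call_with_repeated` below are hypothetical names for this sketch. How they might be plugged into a single primary `Iterate_mdspan_impl<TDim>` template is only shown in a comment; that part is an untested assumption that relies on the `stdex` alias and the rank-1 specialization from the prototype.

```cpp
#include <cstddef>
#include <utility>

// Ignores its index parameter and returns value; the index only exists so
// that repeat<Is>(value)... expands to one copy of value per element of Is.
template <std::size_t, typename T>
constexpr const T& repeat(const T& value)
{
    return value;
}

template <typename F, typename T, std::size_t... Is>
constexpr decltype(auto) call_with_repeated_impl(F&& f, const T& value,
                                                 std::index_sequence<Is...>)
{
    return std::forward<F>(f)(repeat<Is>(value)...);
}

// Calls f(value, value, ..., value) with exactly N copies of value.
template <std::size_t N, typename F, typename T>
constexpr decltype(auto) call_with_repeated(F&& f, const T& value)
{
    return call_with_repeated_impl(std::forward<F>(f), value,
                                   std::make_index_sequence<N>{});
}

// Hypothetical use inside one generic Iterate_mdspan_impl<TDim>::operator()
// (not compiled here):
//
//   for (decltype(span.extent(0)) i = 0; i < span.extent(0); ++i)
//   {
//       auto sub = call_with_repeated<TSpan::rank() - 1>(
//           [&](auto... all) { return stdex::submdspan(span, i, all...); },
//           stdex::full_extent);
//       Iterate_mdspan_impl<TSpan::rank() - 1>{}(sub, functor);
//   }
```

With this, only the rank-1 case would still need its own specialization as the recursion's base case.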
