diff --git a/include/ddc/kernels/splines/spline_builder.hpp b/include/ddc/kernels/splines/spline_builder.hpp index b295e1baf..0e0473a51 100644 --- a/include/ddc/kernels/splines/spline_builder.hpp +++ b/include/ddc/kernels/splines/spline_builder.hpp @@ -468,7 +468,7 @@ class SplineBuilder * @param[in] derivs_xmax The values of the derivatives at the upper boundary * (used only with BoundCond::HERMITE upper boundary condition). */ - template + template void operator()( ddc::ChunkSpan< double, @@ -479,13 +479,13 @@ class SplineBuilder std::optional, - Layout, + LayoutDeriv, memory_space>> derivs_xmin = std::nullopt, std::optional, - Layout, + LayoutDeriv, memory_space>> derivs_xmax = std::nullopt) const; @@ -793,7 +793,7 @@ template < ddc::BoundCond BcLower, ddc::BoundCond BcUpper, SplineSolver Solver> -template +template void SplineBuilder:: operator()( ddc::ChunkSpan< @@ -805,12 +805,12 @@ operator()( std::optional, - Layout, + LayoutDeriv, memory_space>> const derivs_xmin, std::optional, - Layout, + LayoutDeriv, memory_space>> const derivs_xmax) const { auto const batched_interpolation_domain = vals.domain(); diff --git a/include/ddc/kernels/splines/spline_builder_2d.hpp b/include/ddc/kernels/splines/spline_builder_2d.hpp index 03dc45e10..9bfe22454 100644 --- a/include/ddc/kernels/splines/spline_builder_2d.hpp +++ b/include/ddc/kernels/splines/spline_builder_2d.hpp @@ -405,7 +405,7 @@ class SplineBuilder2D * The values of the the cross-derivatives at the upper boundary in the first dimension * and the upper boundary in the second dimension. */ - template + template void operator()( ddc::ChunkSpan< double, @@ -416,49 +416,49 @@ class SplineBuilder2D std::optional, - Layout, + LayoutDeriv, memory_space>> derivs_min1 = std::nullopt, std::optional, - Layout, + LayoutDeriv, memory_space>> derivs_max1 = std::nullopt, std::optional, - Layout, + LayoutDeriv, memory_space>> derivs_min2 = std::nullopt, std::optional, - Layout, + LayoutDeriv, memory_space>> derivs_max2 = std::nullopt, std::optional, - Layout, + LayoutDeriv, memory_space>> mixed_derivs_min1_min2 = std::nullopt, std::optional, - Layout, + LayoutDeriv, memory_space>> mixed_derivs_max1_min2 = std::nullopt, std::optional, - Layout, + LayoutDeriv, memory_space>> mixed_derivs_min1_max2 = std::nullopt, std::optional, - Layout, + LayoutDeriv, memory_space>> mixed_derivs_max1_max2 = std::nullopt) const; }; @@ -476,7 +476,7 @@ template < ddc::BoundCond BcLower2, ddc::BoundCond BcUpper2, ddc::SplineSolver Solver> -template +template void SplineBuilder2D< ExecSpace, MemorySpace, @@ -499,42 +499,42 @@ operator()( std::optional, - Layout, + LayoutDeriv, memory_space>> const derivs_min1, std::optional, - Layout, + LayoutDeriv, memory_space>> const derivs_max1, std::optional, - Layout, + LayoutDeriv, memory_space>> const derivs_min2, std::optional, - Layout, + LayoutDeriv, memory_space>> const derivs_max2, std::optional, - Layout, + LayoutDeriv, memory_space>> const mixed_derivs_min1_min2, std::optional, - Layout, + LayoutDeriv, memory_space>> const mixed_derivs_max1_min2, std::optional, - Layout, + LayoutDeriv, memory_space>> const mixed_derivs_min1_max2, std::optional, - Layout, + LayoutDeriv, memory_space>> const mixed_derivs_max1_max2) const { auto const batched_interpolation_domain = vals.domain(); diff --git a/include/ddc/strided_discrete_domain.hpp b/include/ddc/strided_discrete_domain.hpp index 8db6dc7fb..8cd653be2 100644 --- a/include/ddc/strided_discrete_domain.hpp +++ b/include/ddc/strided_discrete_domain.hpp @@ -14,6 +14,7 @@ #include "detail/type_seq.hpp" +#include "discrete_domain.hpp" #include "discrete_element.hpp" #include "discrete_vector.hpp" @@ -95,7 +96,7 @@ class StridedDiscreteDomain /// Construct a StridedDiscreteDomain by copies and merge of domains template < class... DDoms, - class = std::enable_if_t<(is_strided_discrete_domain_v && ...)>> + std::enable_if_t<(is_strided_discrete_domain_v && ...), bool> = true> KOKKOS_FUNCTION constexpr explicit StridedDiscreteDomain(DDoms const&... domains) : m_element_begin(domains.front()...) , m_extents(domains.extents()...) @@ -103,6 +104,14 @@ class StridedDiscreteDomain { } + /// Construct a StridedDiscreteDomain from a DiscreteDomain + KOKKOS_FUNCTION constexpr explicit StridedDiscreteDomain(DiscreteDomain const& domain) + : m_element_begin(domain.front()) + , m_extents(domain.extents()) + , m_strides((DiscreteVector{1})...) + { + } + /** Construct a StridedDiscreteDomain starting from element_begin with size points. * @param element_begin the lower bound in each direction * @param extents the number of points in each direction @@ -340,6 +349,13 @@ class StridedDiscreteDomain<> { } + // Construct a StridedDiscreteDomain from a reordered copy of `domain` + template + KOKKOS_FUNCTION constexpr explicit StridedDiscreteDomain( + [[maybe_unused]] DiscreteDomain const& domain) + { + } + /** Construct a StridedDiscreteDomain starting from element_begin with size points. * @param element_begin the lower bound in each direction * @param size the number of points in each direction @@ -387,6 +403,11 @@ class StridedDiscreteDomain<> return {}; } + static KOKKOS_FUNCTION constexpr discrete_vector_type strides() noexcept + { + return {}; + } + static KOKKOS_FUNCTION constexpr discrete_element_type front() noexcept { return {}; diff --git a/tests/splines/batched_spline_builder.cpp b/tests/splines/batched_spline_builder.cpp index d98c21d69..51cb83bde 100644 --- a/tests/splines/batched_spline_builder.cpp +++ b/tests/splines/batched_spline_builder.cpp @@ -165,6 +165,13 @@ void BatchedSplineTest() ddc::DiscreteDomain> const derivs_domain(DElem>(1), DVect>(s_degree_x / 2)); auto const dom_derivs = ddc::replace_dim_of>(dom_vals, derivs_domain); + // Create the derivs domain + ddc::StridedDiscreteDomain> const + derivs_domain_strided(DElem>(interpolation_domain.front(), derivs_domain.front()), + DVect>(DVect(1), derivs_domain.extents()), + DVect>(interpolation_domain.extents() - 1, DVect>(1))); + ddc::remove_dims_of_t, DDimI> const dom_vals_tmp_strided(dom_vals_tmp); + ddc::StridedDiscreteDomain> const dom_derivs_strided(dom_vals_tmp_strided, derivs_domain_strided); #endif // Create a SplineBuilder over BSplines and batched along other dimensions using some boundary conditions @@ -200,6 +207,8 @@ void BatchedSplineTest() #if defined(BC_HERMITE) // Allocate and fill a chunk containing derivs to be passed as input to spline_builder. int const shift = s_degree_x % 2; // shift = 0 for even order, 1 for odd order + ddc::Chunk derivs_strided_alloc(dom_derivs_strided, ddc::KokkosAllocator()); + ddc::ChunkSpan const derivs_strided = derivs_strided_alloc.span_view(); ddc::Chunk derivs_lhs_alloc(dom_derivs, ddc::KokkosAllocator()); ddc::ChunkSpan const derivs_lhs = derivs_lhs_alloc.span_view(); if (s_bcl == ddc::BoundCond::HERMITE) { @@ -220,6 +229,8 @@ void BatchedSplineTest() typename decltype(derivs_lhs.domain())::discrete_element_type const e) { derivs_lhs(e) = derivs_lhs1(DElem>(e)); }); + + Kokkos::deep_copy(derivs_strided[interpolation_domain.front()].allocation_kokkos_view(), derivs_lhs.allocation_kokkos_view()); } ddc::Chunk derivs_rhs_alloc(dom_derivs, ddc::KokkosAllocator()); @@ -243,6 +254,7 @@ void BatchedSplineTest() typename decltype(derivs_rhs.domain())::discrete_element_type const e) { derivs_rhs(e) = derivs_rhs1(DElem>(e)); }); + Kokkos::deep_copy(derivs_strided[interpolation_domain.back()].allocation_kokkos_view(), derivs_rhs.allocation_kokkos_view()); } #endif @@ -257,6 +269,25 @@ void BatchedSplineTest() vals.span_cview(), std::optional(derivs_lhs.span_cview()), std::optional(derivs_rhs.span_cview())); + + { + ddc::Chunk coef_2_alloc(dom_spline, ddc::KokkosAllocator()); + ddc::ChunkSpan const coef_2 = coef_2_alloc.span_view(); + spline_builder( + coef_2, + vals.span_cview(), + std::optional(derivs_strided[interpolation_domain.front()].span_cview()), + std::optional(derivs_strided[interpolation_domain.back()].span_cview())); + double const max_norm_error = ddc::parallel_transform_reduce( + exec_space, + coef.domain(), + 0., + ddc::reducer::max(), + KOKKOS_LAMBDA(DElem const e) { + return Kokkos::abs(coef(e) - coef_2(e)); + }); + EXPECT_LE(max_norm_error, 1e-14); + } #else spline_builder(coef, vals.span_cview()); #endif