Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions Src/Base/AMReX_GpuLaunch.H
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,14 @@ void ParallelForOMP (Box const& box, L const& f) noexcept
#pragma omp parallel for collapse(2)
for (int k = lo.z; k <= hi.z; ++k) {
for (int j = lo.y; j <= hi.y; ++j) {
AMREX_PRAGMA_SIMD
for (int i = lo.x; i <= hi.x; ++i) {
f(i,j,k);
constexpr int WIDTH = amrex::simd::native_simd_size_real;
int i = lo.x;
for (; i + WIDTH <= hi.x; i+=WIDTH) {
f(SIMDindex<WIDTH, int>{i}, j, k);
}
for (; i <= hi.x; ++i) {
// TODO: template, etc.
f(SIMDindex<1, int>{i}, j, k);
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion Src/Base/AMReX_GpuLaunchFunctsC.H
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define AMREX_GPU_LAUNCH_FUNCTS_C_H_
#include <AMReX_Config.H>

#include <AMReX_SIMD.H>

namespace amrex {

/** Helper type to store/access the SIMD width in ParallelForSIMD lambdas
Expand All @@ -12,7 +14,7 @@ namespace amrex {
* @tparam WIDTH SIMD width in elements
* @tparam N index type (integer)
*/
template<int WIDTH, class N=int>
template<int WIDTH=simd::native_simd_size_real, class N=int>
struct SIMDindex
{
/** SIMD width in elements */
Expand Down
2 changes: 2 additions & 0 deletions Src/Base/AMReX_SIMD.H
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# include <vir/simd.h> // includes SIMD TS2 header <experimental/simd>
# if __cplusplus >= 202002L
# include <vir/simd_cvt.h>
# include <vir/simd_iota.h>
# endif
#endif

Expand All @@ -26,6 +27,7 @@ namespace amrex::simd
using namespace vir::stdx;
# if __cplusplus >= 202002L
using vir::cvt;
using vir::iota_v;
# endif
#else
// fallback implementations for functions that are commonly used in portable code paths
Expand Down
31 changes: 20 additions & 11 deletions Src/FFT/AMReX_FFT_OpenBCSolver.H
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,23 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
}
AMREX_ASSERT(nimages[0] == 2);
box.shift(-lo);
amrex::ParallelForOMP(box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
amrex::ParallelForOMP(box, [=]<int WIDTH> AMREX_GPU_DEVICE (amrex::SIMDindex<WIDTH> i, int j, int k)
{
T G;
if (i == len[0] || j == len[1] || k == len[2]) {
G = 0;
} else {
auto ii = i;
using SIMD_T = simd::stdx::fixed_size_simd<T, i.width>;
using SIMD_int = simd::stdx::fixed_size_simd<int, i.width>; // simd::stdx::rebind_simd_t<int, SIMD_T>;

SIMD_T G = 0;
if (j != len[1] && k != len[2])
{
SIMD_int ii = simd::stdx::iota_v<SIMD_int> + i.index;
auto jj = (j > len[1]) ? 2*len[1]-j : j;
auto kk = (k > len[2]) ? 2*len[2]-k : k;
G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z);

auto const i_bound = ii == len[0];
simd::stdx::where(simd::stdx::cvt(i_bound), G) = 0.0;
simd::stdx::where(simd::stdx::cvt(!i_bound), G) = greens_function.template operator()<WIDTH>(ii+lo3.x,jj+lo3.y,kk+lo3.z);
}

for (int koff = 0; koff < nimages[2]; ++koff) {
int k2 = (koff == 0) ? k : 2*len[2]-k;
if ((k2 == 2*len[2]) || (koff == 1 && k == len[2])) {
Expand All @@ -119,11 +125,14 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
continue;
}
for (int ioff = 0; ioff < nimages[0]; ++ioff) {
int i2 = (ioff == 0) ? i : 2*len[0]-i;
if ((i2 == 2*len[0]) || (ioff == 1 && i == len[0])) {
continue;
for (int iw = i.index; iw < i.index+i.width; ++iw) {
int i2 = (ioff == 0) ? iw : 2*len[0]-iw;
if ((i2 == 2*len[0]) || (ioff == 1 && iw == len[0])) {
continue;
}
// TODO: SIMD-assign N values
a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G[iw];
}
a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G;
}
}
}
Expand Down
Loading