Skip to content

Commit c164bac

Browse files
committed
[Draft] OpenBC SIMD
1 parent 82eeea5 commit c164bac

File tree

4 files changed

+33
-15
lines changed

4 files changed

+33
-15
lines changed

Src/Base/AMReX_GpuLaunch.H

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,14 @@ void ParallelForOMP (Box const& box, L const& f) noexcept
298298
#pragma omp parallel for collapse(2)
299299
for (int k = lo.z; k <= hi.z; ++k) {
300300
for (int j = lo.y; j <= hi.y; ++j) {
301-
AMREX_PRAGMA_SIMD
302-
for (int i = lo.x; i <= hi.x; ++i) {
303-
f(i,j,k);
301+
constexpr int WIDTH = amrex::simd::native_simd_size_real;
302+
int i = lo.x;
303+
for (; i + WIDTH <= hi.x; i+=WIDTH) {
304+
f(SIMDindex<WIDTH, int>{i}, j, k);
305+
}
306+
for (; i <= hi.x; ++i) {
307+
// TODO: template, etc.
308+
//f(SIMDindex<1, int>{i}, j, k);
304309
}
305310
}
306311
}

Src/Base/AMReX_GpuLaunchFunctsC.H

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define AMREX_GPU_LAUNCH_FUNCTS_C_H_
33
#include <AMReX_Config.H>
44

5+
#include <AMReX_SIMD.H>
6+
57
namespace amrex {
68

79
/** Helper type to store/access the SIMD width in ParallelForSIMD lambdas
@@ -12,7 +14,7 @@ namespace amrex {
1214
* @tparam WIDTH SIMD width in elements
1315
* @tparam N index type (integer)
1416
*/
15-
template<int WIDTH, class N=int>
17+
template<int WIDTH=simd::native_simd_size_real, class N=int>
1618
struct SIMDindex
1719
{
1820
/** SIMD width in elements */

Src/Base/AMReX_SIMD.H

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# include <vir/simd.h> // includes SIMD TS2 header <experimental/simd>
1111
# if __cplusplus >= 202002L
1212
# include <vir/simd_cvt.h>
13+
# include <vir/simd_iota.h>
1314
# endif
1415
#endif
1516

@@ -26,6 +27,7 @@ namespace amrex::simd
2627
using namespace vir::stdx;
2728
# if __cplusplus >= 202002L
2829
using vir::cvt;
30+
using vir::iota_v;
2931
# endif
3032
#else
3133
// fallback implementations for functions that are commonly used in portable code paths

Src/FFT/AMReX_FFT_OpenBCSolver.H

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,23 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
9797
}
9898
AMREX_ASSERT(nimages[0] == 2);
9999
box.shift(-lo);
100-
amrex::ParallelForOMP(box, [=] AMREX_GPU_DEVICE (int i, int j, int k)
100+
amrex::ParallelForOMP(box, [=] AMREX_GPU_DEVICE (amrex::SIMDindex<> i, int j, int k)
101101
{
102-
T G;
103-
if (i == len[0] || j == len[1] || k == len[2]) {
104-
G = 0;
105-
} else {
106-
auto ii = i;
102+
using SIMD_T = simd::stdx::fixed_size_simd<T, i.width>;
103+
using SIMD_int = simd::stdx::rebind_simd_t<int, SIMD_T>;
104+
105+
SIMD_T G = 0;
106+
if (j != len[1] && k != len[2])
107+
{
108+
SIMD_int ii = simd::stdx::iota_v<SIMD_int> + i.index;
107109
auto jj = (j > len[1]) ? 2*len[1]-j : j;
108110
auto kk = (k > len[2]) ? 2*len[2]-k : k;
109-
G = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z);
111+
112+
auto const i_bound = ii == len[0];
113+
simd::stdx::where(simd::stdx::cvt(i_bound), G) = 0.0;
114+
simd::stdx::where(simd::stdx::cvt(!i_bound), G) = greens_function(ii+lo3.x,jj+lo3.y,kk+lo3.z);
110115
}
116+
111117
for (int koff = 0; koff < nimages[2]; ++koff) {
112118
int k2 = (koff == 0) ? k : 2*len[2]-k;
113119
if ((k2 == 2*len[2]) || (koff == 1 && k == len[2])) {
@@ -119,11 +125,14 @@ void OpenBCSolver<T>::setGreensFunction (F const& greens_function)
119125
continue;
120126
}
121127
for (int ioff = 0; ioff < nimages[0]; ++ioff) {
122-
int i2 = (ioff == 0) ? i : 2*len[0]-i;
123-
if ((i2 == 2*len[0]) || (ioff == 1 && i == len[0])) {
124-
continue;
128+
for (int iw = i.index; iw < i.index+i.width; ++iw) {
129+
int i2 = (ioff == 0) ? iw : 2*len[0]-iw;
130+
if ((i2 == 2*len[0]) || (ioff == 1 && iw == len[0])) {
131+
continue;
132+
}
133+
// TODO: SIMD-assign N values
134+
a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G[iw];
125135
}
126-
a(i2+lo3.x,j2+lo3.y,k2+lo3.z) = G;
127136
}
128137
}
129138
}

0 commit comments

Comments
 (0)