diff --git a/unified-runtime/source/adapters/native_cpu/enqueue.cpp b/unified-runtime/source/adapters/native_cpu/enqueue.cpp index 6818a915b9334..6c235b44b70ab 100644 --- a/unified-runtime/source/adapters/native_cpu/enqueue.cpp +++ b/unified-runtime/source/adapters/native_cpu/enqueue.cpp @@ -96,6 +96,113 @@ static inline native_cpu::state getState(const native_cpu::NDRDescT &ndr) { return resized_state; } +static inline void invoke_kernel(native_cpu::state &state, + const ur_kernel_handle_t_ &kernel, size_t g0, + size_t g1, size_t g2, + size_t numParallelThreads, size_t threadId, + const native_cpu::NDRDescT &ndr) { +#ifdef NATIVECPU_USE_OCK + state.update(g0, g1, g2); + kernel._subhandler(kernel.getArgs(numParallelThreads, threadId).data(), + &state); + (void)ndr; +#else + for (size_t local2 = 0; local2 < ndr.LocalSize[2]; ++local2) { + for (size_t local1 = 0; local1 < ndr.LocalSize[1]; ++local1) { + for (size_t local0 = 0; local0 < ndr.LocalSize[0]; ++local0) { + state.update(g0, g1, g2, local0, local1, local2); + kernel._subhandler(kernel.getArgs(numParallelThreads, threadId).data(), + &state); + } + } + } +#endif +} + +#ifdef NATIVECPU_WITH_ONETBB + +#define NATIVECPU_WITH_ONETBB_PARALLELFOR + +using IndexT = std::array; +using RangeT = native_cpu::NDRDescT::RangeT; + +static inline void execute_range(native_cpu::state &state, + const ur_kernel_handle_t_ &hKernel, + IndexT first, IndexT lastPlusOne, + size_t numParallelThreads, size_t threadId, + const native_cpu::NDRDescT &ndr) { + for (size_t g2 = first[2]; g2 < lastPlusOne[2]; g2++) { + for (size_t g1 = first[1]; g1 < lastPlusOne[1]; g1++) { + for (size_t g0 = first[0]; g0 < lastPlusOne[0]; g0 += 1) { + invoke_kernel(state, hKernel, g0, g1, g2, numParallelThreads, threadId, + ndr); + } + } + } +} + +namespace native_cpu { + +class nativecpu_tbb_executor { + const native_cpu::NDRDescT ndr; + +protected: + const ur_kernel_handle_t_ &hKernel; + const size_t numParallelThreads; + + void execute(IndexT first, IndexT last_plus_one) const { + auto state = getState(ndr); + auto threadId = native_cpu::getTBBThreadID(); + execute_range(state, hKernel, first, last_plus_one, numParallelThreads, + threadId, ndr); + } + +public: + void operator()(const tbb::blocked_range3d &r) const { + execute({r.pages().begin(), r.rows().begin(), r.cols().begin()}, + {r.pages().end(), r.rows().end(), r.cols().end()}); + } + void operator()(const tbb::blocked_range2d &r) const { + execute({r.rows().begin(), r.cols().begin(), 0}, + {r.rows().end(), r.cols().end(), 1}); + } + void operator()(const tbb::blocked_range &r) const { + execute({r.begin(), 0, 0}, {r.end(), 1, 1}); + } + nativecpu_tbb_executor(const native_cpu::NDRDescT &n, + const ur_kernel_handle_t_ &k, + const size_t numParallelThreads) + : ndr(n), hKernel(k), numParallelThreads(numParallelThreads) {} +}; + +using tbb_nd_executor = nativecpu_tbb_executor; + +template