295 changes: 292 additions & 3 deletions src/ATen/native/xpu/sycl/DilatedMaxPool2d.cpp
@@ -11,6 +11,7 @@

#include <ATen/native/xpu/sycl/Atomics.h>
#include <ATen/native/xpu/sycl/BatchKernel.h>
+#include <ATen/native/xpu/sycl/MemoryAccess.h>
#include <ATen/native/xpu/sycl/NumericLimits.h>
#include <comm/Runtime.h>
#include <comm/SYCLHelpers.h>
@@ -251,30 +252,33 @@ struct MaxPool2dBackwardDeterministicKernelFunctor {
        int pwstart =
            p_start(inputW, pad_w_, kernel_w_, dilation_w_, stride_w_);
        int pwend = p_end(inputW, pad_w_, gradOutputSizeW_, stride_w_);
+       scalar_t grad = 0;
        if constexpr (is_channels_last) {
          int offset = batch * out_n_stride_ + plane;
          for (int ph = phstart; ph < phend; ++ph) {
            for (int pw = pwstart; pw < pwend; ++pw) {
              if (indices_[offset + (ph * gradOutputSizeW_ + pw) * numPlane_] ==
                  input_hw_index) {
-               gradInput_[inputIndex] += static_cast<scalar_t>(
+               grad += static_cast<scalar_t>(
                    gradOutput_
                        [offset + (ph * gradOutputSizeW_ + pw) * numPlane_]);
              }
            }
          }
-       } else {
+       }
+       else {
          int offset = batch * out_n_stride_ + plane * out_cf_c_stride_;
          for (int ph = phstart; ph < phend; ++ph) {
            for (int pw = pwstart; pw < pwend; ++pw) {
              if (indices_[offset + ph * gradOutputSizeW_ + pw] ==
                  input_hw_index) {
-               gradInput_[inputIndex] += static_cast<scalar_t>(
+               grad += static_cast<scalar_t>(
                    gradOutput_[offset + ph * gradOutputSizeW_ + pw]);
              }
            }
          }
        }
+       gradInput_[inputIndex] = grad;
      }
    } while (cfg_.next(item, desc));
  }
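The hunk above replaces the per-hit read-modify-write on gradInput_ with a register accumulator that is written back once per input element, so global memory is touched once instead of once per matching output cell. A minimal standalone sketch of the pattern (illustrative only, not taken from the file):

#include <cstdint>

// Accumulate-in-register pattern: sum all matching contributions locally,
// then let the caller perform a single store to grad_input[target].
float sum_window(const float* grad_out, const int64_t* indices,
                 int64_t target, int begin, int end) {
  float grad = 0.f; // lives in a register, not in global memory
  for (int i = begin; i < end; ++i) {
    if (indices[i] == target) { // output cell i pooled from `target`
      grad += grad_out[i];
    }
  }
  return grad; // caller then does: grad_input[target] = grad;
}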
@@ -349,6 +353,122 @@ struct MaxPool2dBackwardDeterministicKernelFunctor {
BatchKernelConfig cfg_;
};

template <typename scalar_t, typename vec_t, int vec_size>
struct MaxPool2dBackwardChannelLastVec {
void operator()(sycl::nd_item<1> item) const {
for (auto inputIndex = item.get_global_linear_id();
inputIndex < gradInputSize_ / vec_size;
inputIndex += item.get_local_range(0) * item.get_group_range(0)) {
      int batch = inputIndex / (in_n_stride_ / vec_size);
      int plane = inputIndex % (numPlane_ / vec_size);
      // spatial index (h * gradInputSizeW_ + w) of this element's pixel;
      // note the modulo is taken in vector units, matching `batch` above
      int64_t input_hw_index =
          ((inputIndex % (in_n_stride_ / vec_size)) - plane) /
          (numPlane_ / vec_size);

int inputW = input_hw_index % gradInputSizeW_;
int inputH = input_hw_index / gradInputSizeW_;
int phstart = p_start(inputH, pad_h_, kernel_h_, dilation_h_, stride_h_);
int phend = p_end(inputH, pad_h_, gradOutputSizeH_, stride_h_);
int pwstart = p_start(inputW, pad_w_, kernel_w_, dilation_w_, stride_w_);
int pwend = p_end(inputW, pad_w_, gradOutputSizeW_, stride_w_);
      int64_t load_offset;
      const int64_t store_offset = inputIndex; // one vectorized store per element
      vec_t grad_vec;
#pragma unroll
for (int i = 0; i < vec_size; i++) {
grad_vec[i] = 0;
}

int offset = batch * (out_n_stride_ / vec_size) + plane;
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
load_offset =
offset + (ph * gradOutputSizeW_ + pw) * (numPlane_ / vec_size);
vec_t gout_val_vec = gradOutput_[load_offset];
#pragma unroll
for (int i = 0; i < vec_size; i++) {
if (indices_[load_offset * vec_size + i] == input_hw_index) {
grad_vec[i] = static_cast<scalar_t>(grad_vec[i]) +
static_cast<scalar_t>(gout_val_vec[i]);
}
}
}
}

gradInput_[store_offset] = grad_vec;
}
}
MaxPool2dBackwardChannelLastVec(
vec_t* gradInput,
const vec_t* gradOutput,
const int64_t* indices,
int numPlane,
int gradInputSizeH,
int gradInputSizeW,
int gradOutputSizeH,
int gradOutputSizeW,
int64_t gradInputSize,
int out_cf_c_stride,
int in_cf_c_stride,
int out_n_stride,
int in_n_stride,
int kernel_h,
int kernel_w,
int stride_h,
int stride_w,
int pad_h,
int pad_w,
int dilation_h,
int dilation_w)
: gradInput_(gradInput),
gradOutput_(gradOutput),
indices_(indices),
numPlane_(numPlane),
gradInputSizeH_(gradInputSizeH),
gradInputSizeW_(gradInputSizeW),
gradOutputSizeH_(gradOutputSizeH),
gradOutputSizeW_(gradOutputSizeW),
gradInputSize_(gradInputSize),
out_cf_c_stride_(out_cf_c_stride),
in_cf_c_stride_(in_cf_c_stride),
out_n_stride_(out_n_stride),
in_n_stride_(in_n_stride),
kernel_h_(kernel_h),
kernel_w_(kernel_w),
stride_h_(stride_h),
stride_w_(stride_w),
pad_h_(pad_h),
pad_w_(pad_w),
dilation_h_(dilation_h),
dilation_w_(dilation_w) {}

private:
vec_t* gradInput_;
const vec_t* gradOutput_;
const int64_t* indices_;
int numPlane_;
int gradInputSizeH_;
int gradInputSizeW_;
int gradOutputSizeH_;
int gradOutputSizeW_;
int64_t gradInputSize_;
int out_cf_c_stride_;
int in_cf_c_stride_;
int out_n_stride_;
int in_n_stride_;
int kernel_h_;
int kernel_w_;
int stride_h_;
int stride_w_;
int pad_h_;
int pad_w_;
int dilation_h_;
int dilation_w_;
};
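To make the vectorized channels-last indexing concrete, here is a small self-contained check with hypothetical sizes (N=2, H=W=2, C=8, vec_size=4 are assumptions, not values from the patch); each vec_t spans vec_size consecutive channels of one pixel:

#include <cassert>

int main() {
  const int vec_size = 4;
  const int numPlane = 8;            // C
  const int in_n_stride = 2 * 2 * 8; // H * W * C scalar elements per image
  const int inputIndex = 13;         // a vector-unit linear index

  int batch = inputIndex / (in_n_stride / vec_size); // 13 / 8 = 1
  int plane = inputIndex % (numPlane / vec_size);    // 13 % 2 = 1
  int hw = ((inputIndex % (in_n_stride / vec_size)) - plane) /
      (numPlane / vec_size);                         // (5 - 1) / 2 = 2
  assert(batch == 1 && plane == 1 && hw == 2);       // pixel h = 1, w = 0
  return 0;
}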

template <typename scalar_t, bool is_channels_last>
void launch_max_pool2d_kernel(
scalar_t* output,
@@ -397,6 +517,62 @@ void launch_max_pool2d_kernel(
sycl_kernel_submit(cfg.global_size(), cfg.group_size(), queue, kfn);
}
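sycl_kernel_submit comes from comm/SYCLHelpers.h; a hedged sketch of the shape it is assumed to have (an nd_range launch of the functor; the name and signature below are approximations, not the real helper):

#include <sycl/sycl.hpp>

// Assumed-equivalent launcher: enqueue `kfn` over global_range work-items
// grouped into work-groups of local_range work-items each.
template <typename KernelFn>
void submit_sketch(
    int64_t global_range, int64_t local_range, sycl::queue& q, KernelFn kfn) {
  q.submit([&](sycl::handler& cgh) {
    cgh.parallel_for(
        sycl::nd_range<1>(
            sycl::range<1>(global_range), sycl::range<1>(local_range)),
        kfn);
  });
}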

#define LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC( \
scalar_t, \
vec_size, \
num_wg, \
wg_size, \
queue, \
gradInput, \
gradOutput, \
indices, \
numPlane, \
gradInputSizeH, \
gradInputSizeW, \
gradOutputSizeH, \
gradOutputSizeW, \
gradInputSize, \
out_cf_c_stride, \
in_cf_c_stride, \
out_n_stride, \
in_n_stride, \
kernel_h, \
kernel_w, \
stride_h, \
stride_w, \
pad_h, \
pad_w, \
dilation_h, \
dilation_w) \
{ \
using vec_t = memory::aligned_vector<scalar_t, vec_size>; \
const vec_t* grad_output_vec = reinterpret_cast<const vec_t*>(gradOutput); \
vec_t* grad_input_vec = reinterpret_cast<vec_t*>(gradInput); \
auto kfn = MaxPool2dBackwardChannelLastVec<scalar_t, vec_t, vec_size>( \
grad_input_vec, \
grad_output_vec, \
indices, \
numPlane, \
gradInputSizeH, \
gradInputSizeW, \
gradOutputSizeH, \
gradOutputSizeW, \
gradInputSize, \
out_cf_c_stride, \
in_cf_c_stride, \
out_n_stride, \
in_n_stride, \
kernel_h, \
kernel_w, \
stride_h, \
stride_w, \
pad_h, \
pad_w, \
dilation_h, \
dilation_w); \
sycl_kernel_submit(num_wg* wg_size, wg_size, queue, kfn); \
}
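The macro reinterprets the scalar gradInput/gradOutput pointers as aligned_vector<scalar_t, vec_size>*, which is only sound when the allocations are aligned to sizeof(scalar_t) * vec_size and numPlane is divisible by vec_size (the dispatcher below checks the divisibility before choosing a width). The real type lives in MemoryAccess.h; this is only a sketch of the shape the code assumes:

// Approximate shape of memory::aligned_vector (see MemoryAccess.h for the
// authoritative definition): vec_size scalars with vector-width alignment.
template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector_sketch {
  scalar_t val[vec_size];
  scalar_t& operator[](int i) { return val[i]; }
  const scalar_t& operator[](int i) const { return val[i]; }
};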

template <typename scalar_t, bool is_channels_last>
void launch_max_pool2d_backward_kernel(
scalar_t* gradInput,
@@ -435,6 +611,119 @@ void launch_max_pool2d_backward_kernel(
  // with CUDA in alexnet. To avoid future problems, we decided to always use
  // the deterministic path.

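Floating-point addition is not associative, so an accumulation whose commit order varies between runs (e.g. via atomic adds) need not reproduce bit-exact results; that is what the deterministic path avoids. A tiny standalone demonstration:

#include <cstdio>

int main() {
  float a = 1e8f, b = -1e8f, c = 1.0f;
  // Same three addends, different association: prints "1 vs 0".
  std::printf("%g vs %g\n", (a + b) + c, a + (b + c));
  return 0;
}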

// int vec_size = 1;
// int thread_slots = syclGpuEuCount() * syclGpuHWThreadsPerEU();
// int num_sub_wg;
// auto wg_size = syclDeviceMaxWorkGroupSize();
// int64_t num_wg;
// if constexpr (is_channels_last) {
// for (vec_size = std::min(
// 8, memory::can_vectorize_up_to<scalar_t>((char*)gradOutput));
// vec_size >= 1;
// vec_size /= 2) {
// if (numPlane % vec_size != 0) {
// continue;
// }
// num_sub_wg = gradInputSize / vec_size / syclMaxSubGroupSize();
// if (2 * num_sub_wg > thread_slots) {
// int total_thread = gradInputSize / vec_size;
// num_wg = (total_thread + wg_size - 1) / wg_size;
// break;
// }
// }
// switch (vec_size) {
// case 8:
// LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC(
// scalar_t,
// 8,
// num_wg,
// wg_size,
// queue,
// gradInput,
// gradOutput,
// indices,
// numPlane,
// gradInputSizeH,
// gradInputSizeW,
// gradOutputSizeH,
// gradOutputSizeW,
// gradInputSize,
// out_cf_c_stride,
// in_cf_c_stride,
// out_n_stride,
// in_n_stride,
// kernel_h,
// kernel_w,
// stride_h,
// stride_w,
// pad_h,
// pad_w,
// dilation_h,
// dilation_w);
// return;
// case 4:
// LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC(
// scalar_t,
// 4,
// num_wg,
// wg_size,
// queue,
// gradInput,
// gradOutput,
// indices,
// numPlane,
// gradInputSizeH,
// gradInputSizeW,
// gradOutputSizeH,
// gradOutputSizeW,
// gradInputSize,
// out_cf_c_stride,
// in_cf_c_stride,
// out_n_stride,
// in_n_stride,
// kernel_h,
// kernel_w,
// stride_h,
// stride_w,
// pad_h,
// pad_w,
// dilation_h,
// dilation_w);
// return;
// case 2:
// LAUNCH_MAXPOOL_BACKWARD_CHANNEL_LAST_VEC(
// scalar_t,
// 2,
// num_wg,
// wg_size,
// queue,
// gradInput,
// gradOutput,
// indices,
// numPlane,
// gradInputSizeH,
// gradInputSizeW,
// gradOutputSizeH,
// gradOutputSizeW,
// gradInputSize,
// out_cf_c_stride,
// in_cf_c_stride,
// out_n_stride,
// in_n_stride,
// kernel_h,
// kernel_w,
// stride_h,
// stride_w,
// pad_h,
// pad_w,
// dilation_h,
// dilation_w);
// return;
// default:
// break;
// };
// }
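The commented-out dispatcher above keeps the widest vector width (at most 8, and at most what can_vectorize_up_to allows for the gradOutput pointer) that divides the channel count while still leaving more than half of the device's hardware thread slots occupied. Restated as a hypothetical standalone helper (names and the ~50% occupancy threshold mirror the commented code; nothing here is part of the patch):

#include <cstdint>

// Pick the widest vec_size that tiles the channel dim and keeps occupancy
// above ~50% of the hardware thread slots; 1 means fall back to scalars.
inline int pick_vec_size(int max_vec, int num_channels, int64_t total_elems,
                         int sub_group_size, int64_t thread_slots) {
  for (int v = max_vec; v >= 1; v /= 2) {
    if (num_channels % v != 0)
      continue; // vector elements must stay within one pixel's channels
    int64_t num_sub_wg = total_elems / v / sub_group_size;
    if (2 * num_sub_wg > thread_slots)
      return v; // enough parallelism remains at this width
  }
  return 1; // scalar fallback
}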
using KernelClass =
MaxPool2dBackwardDeterministicKernelFunctor<scalar_t, is_channels_last>;
BatchKernelConfig cfg = BatchKernelConfig::make_config<KernelClass>(