Skip to content

Commit e33b715

Browse files
committed
SYCL: Add GGML_OP_MEAN operator support
1 parent 28c39da commit e33b715

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2127,6 +2127,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
21272127
sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
21282128
}
21292129

2130+
// Row-wise mean of dst->src[0] written into dst; both tensors must be F32.
//
// Runs in two steps on the context's stream:
//   1. sum_rows_f32_sycl is handed the source/destination buffers together with
//      the row length and row count (expected to leave one total per row in dst),
//   2. a 1-D kernel over the rows divides each dst entry by the row length.
//
// NOTE(review): step 2 reads what step 1 wrote without an explicit event
// dependency, so this presumably relies on the stream being in-order — confirm.
inline void ggml_sycl_op_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    dpct::queue_ptr queue = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const ggml_tensor * src0 = dst->src[0];

    // Row geometry comes from the source tensor: ne[0] elements per row.
    const int64_t row_len   = src0->ne[0];
    const int64_t row_count = ggml_nrows(src0);

    const float * in_data  = static_cast<const float *>(src0->data);
    float *       out_data = static_cast<float *>(dst->data);

    // Step 1: per-row totals.
    sum_rows_f32_sycl(in_data, out_data, row_len, row_count, queue);

    // Step 2: one work-item per row turns the total into a mean.
    queue->parallel_for(sycl::range<1>(row_count), [=](sycl::id<1> r) {
        out_data[r] = out_data[r] / row_len;
    });
}
2152+
2153+
21302154
inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
21312155
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
21322156
GGML_ASSERT(dst->type == GGML_TYPE_I32);
@@ -3510,6 +3534,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
35103534
ggml_sycl_op_sum_rows(ctx, dst);
35113535
}
35123536

3537+
// Dispatch target for GGML_OP_MEAN: validates the source layout and forwards
// to ggml_sycl_op_mean, which does the actual computation.
static void ggml_sycl_mean(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    // Scoped debug helper; constructed first so it spans the whole call.
    scope_op_debug_print dbg_scope(__func__, dst, /*num_src=*/1);

    // ggml_sycl_op_mean walks the raw source buffer row by row, so a
    // non-contiguous source is rejected up front.
    GGML_ASSERT(ggml_is_contiguous(dst->src[0]));

    ggml_sycl_op_mean(ctx, dst);
}
3542+
35133543
static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
35143544
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
35153545
GGML_ASSERT(ggml_is_contiguous(dst->src[0]));
@@ -3753,6 +3783,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
37533783
case GGML_OP_SUM_ROWS:
37543784
ggml_sycl_sum_rows(ctx, dst);
37553785
break;
3786+
case GGML_OP_MEAN:
3787+
ggml_sycl_mean(ctx, dst);
3788+
break;
37563789
case GGML_OP_ARGSORT:
37573790
ggml_sycl_argsort(ctx, dst);
37583791
break;
@@ -4402,6 +4435,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
44024435
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
44034436
case GGML_OP_SUM:
44044437
case GGML_OP_SUM_ROWS:
4438+
case GGML_OP_MEAN:
44054439
case GGML_OP_ARGSORT:
44064440
return ggml_is_contiguous(op->src[0]);
44074441
case GGML_OP_POOL_2D:

0 commit comments

Comments
 (0)