@@ -2127,6 +2127,30 @@ inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor *
2127
2127
sum_rows_f32_sycl (src0_dd, dst_dd, ncols, nrows, main_stream);
2128
2128
}
2129
2129
2130
+ inline void ggml_sycl_op_mean (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2131
+ GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
2132
+ GGML_ASSERT (dst->type == GGML_TYPE_F32);
2133
+
2134
+ dpct::queue_ptr main_stream = ctx.stream ();
2135
+ SYCL_CHECK (ggml_sycl_set_device (ctx.device ));
2136
+
2137
+ const float * src0_dd = static_cast <const float *>(dst->src [0 ]->data );
2138
+ float * dst_dd = static_cast <float *>(dst->data );
2139
+
2140
+ const int64_t ncols = dst->src [0 ]->ne [0 ];
2141
+ const int64_t nrows = ggml_nrows (dst->src [0 ]);
2142
+
2143
+ sum_rows_f32_sycl (src0_dd, dst_dd, ncols, nrows, main_stream);
2144
+
2145
+ main_stream->parallel_for (
2146
+ sycl::range<1 >(nrows),
2147
+ [=](sycl::id<1 > row) {
2148
+ dst_dd[row] /= ncols;
2149
+ }
2150
+ );
2151
+ }
2152
+
2153
+
2130
2154
inline void ggml_sycl_op_argsort (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
2131
2155
GGML_ASSERT (dst->src [0 ]->type == GGML_TYPE_F32);
2132
2156
GGML_ASSERT (dst->type == GGML_TYPE_I32);
@@ -3510,6 +3534,12 @@ static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, ggml_tensor * ds
3510
3534
ggml_sycl_op_sum_rows (ctx, dst);
3511
3535
}
3512
3536
3537
+ static void ggml_sycl_mean (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3538
+ scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 1 );
3539
+ GGML_ASSERT (ggml_is_contiguous (dst->src [0 ]));
3540
+ ggml_sycl_op_mean (ctx, dst);
3541
+ }
3542
+
3513
3543
static void ggml_sycl_argsort (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
3514
3544
scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 1 );
3515
3545
GGML_ASSERT (ggml_is_contiguous (dst->src [0 ]));
@@ -3753,6 +3783,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
3753
3783
case GGML_OP_SUM_ROWS:
3754
3784
ggml_sycl_sum_rows (ctx, dst);
3755
3785
break ;
3786
+ case GGML_OP_MEAN:
3787
+ ggml_sycl_mean (ctx, dst);
3788
+ break ;
3756
3789
case GGML_OP_ARGSORT:
3757
3790
ggml_sycl_argsort (ctx, dst);
3758
3791
break ;
@@ -4402,6 +4435,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
4402
4435
return op->src [0 ]->type == GGML_TYPE_F32 && op->op_params [0 ] == GGML_SCALE_MODE_NEAREST;
4403
4436
case GGML_OP_SUM:
4404
4437
case GGML_OP_SUM_ROWS:
4438
+ case GGML_OP_MEAN:
4405
4439
case GGML_OP_ARGSORT:
4406
4440
return ggml_is_contiguous (op->src [0 ]);
4407
4441
case GGML_OP_POOL_2D:
0 commit comments