From 950671dc3a179d84352098acb0a766c74799ba12 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Wed, 5 Nov 2025 14:04:40 +0000 Subject: [PATCH 1/6] ggml-cpu: handle 3d tensors in repack mul_mat --- ggml/src/ggml-cpu/repack.cpp | 127 ++++++++++++++++++++++++----------- 1 file changed, 87 insertions(+), 40 deletions(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 8421c84ce0942..eed5710d90d46 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -11,6 +11,7 @@ #include "arch-fallback.h" +#include #include #include #include @@ -1600,29 +1601,48 @@ template src[0]; const ggml_tensor * src1 = op->src[1]; ggml_tensor * dst = op; GGML_TENSOR_BINARY_OP_LOCALS - const void * src1_wdata = params->wdata; const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10); + GGML_ASSERT(ne03 == 1 && ne13 == 1); + GGML_ASSERT(ne12 % ne02 == 0); + const int64_t r2 = ne12 / ne02; + + const int64_t i12 = src1_start / ne1; + const int64_t i11 = src1_start - i12 * ne1; + + // Determine batch index + const int64_t i02 = i12 / r2; + + const int64_t i1 = i11; + const int64_t i2 = i12; + + const char *src0_ptr = (const char*)src0->data + i02 * nb02; + const char *src1_ptr = (const char*)params->wdata + (i11 + i12 * ne11) * src1_col_stride; + float *dst_ptr = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2)); + + const int64_t nrows = src1_end - src1_start; + const int64_t ncols = src0_end - src0_start; + // If there are more than three rows in src1, use gemm; otherwise, use gemv. - if (ne11 > 3) { + if (nrows > 3) { gemm(ne00, - (float *) ((char *) dst->data) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); + dst_ptr + src0_start, nb1 / nb0, + src0_ptr + src0_start * nb01, + src1_ptr, nrows - (nrows % 4), ncols); } - for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) { + for (int iter = nrows - (nrows % 4); iter < nrows; iter++) { gemv(ne00, - (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata + (src1_col_stride * iter), 1, - src0_end - src0_start); + dst_ptr + (iter * nb1) + src0_start, ne01, + src0_ptr + src0_start * nb01, + src1_ptr + (src1_col_stride * iter), 1 /* nrows */, + ncols); } } @@ -1647,54 +1667,72 @@ template type == GGML_TYPE_F32); GGML_ASSERT(ggml_n_dims(op->src[0]) == 2); // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2); char * wdata = static_cast(params->wdata); - const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10); + const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10); + const size_t nbw2 = nbw1 * ne11; - assert(params->wsize >= nbw1 * ne11); + assert(params->wsize >= nbw2 * ne12); const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float; - int64_t i11_processed = 0; - for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { - ggml_quantize_mat_t((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10); - } + for (int64_t i12 = 0; i12 < ne12; i12++) { + char * data_ptr = (char *) src1->data + i12 * nb12; + char * wdata_ptr = wdata + i12 * nbw2; + + for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { + ggml_quantize_mat_t((float *) (data_ptr + i11 * nb11), + (void *) (wdata_ptr + i11 * nbw1), 4, ne10); + } - i11_processed = ne11 - ne11 % 4; - for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { - from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10); + const int64_t i11_processed = ne11 - ne11 % 4; + for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { + from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10); + } } // disable for NUMA const bool disable_chunking = ggml_is_numa(); // 4x chunks per thread - int64_t nr = ggml_nrows(op->src[0]); - int nth_scaled = nth * 4; - int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; - int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + const int64_t nr0 = ggml_nrows(op->src[0]); + const int64_t nr1 = ne1 * ne2 * ne3; + + int nth_scaled = nth * 4; + int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled; + // avoid too small chunks for narrow src1 + int64_t chunk_size1 = std::max(16, (nr1 + nth - 1) / nth); + int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0; + int64_t nchunk1 = (nr1 + chunk_size1 - 1) / chunk_size1; // Ensure minimum chunk size to avoid alignment issues with high thread counts // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment const int64_t min_chunk_size = NB_COLS; - if (nchunk > 0 && (nr / nchunk) < min_chunk_size && nr >= min_chunk_size) { - nchunk = (nr + min_chunk_size - 1) / min_chunk_size; + if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) { + nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size; } - if (nth == 1 || nchunk < nth || disable_chunking) { - nchunk = nth; + + if (nth == 1 || nchunk0 * nchunk1 < nth || disable_chunking) { + nchunk0 = nr0 > nr1 ? nth : 1; + nchunk1 = nr0 > nr1 ? 1 : nth; } + const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; + // Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size // This prevents creating too many tiny chunks that could overlap after alignment - const int64_t max_nchunk = (nr + min_chunk_size - 1) / min_chunk_size; - if (nchunk > max_nchunk) { - nchunk = max_nchunk; - } + const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size; + nchunk0 = std::min(nchunk0, max_nchunk); if (ith == 0) { // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. @@ -1706,23 +1744,32 @@ template ne01) { - src0_end = ne01; - } + src0_end = std::min(src0_end, ne01); + // Make sure current plane is the last one before exiting if (src0_start >= src0_end) { - break; + if (nth >= nchunk0 * nchunk1) { + break; + } + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + continue; } - forward_mul_mat_one_chunk(params, dst, src0_start, src0_end); + forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end); current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } From 0b8665116c9417c95e54b8abea9165fa0ac219bc Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Wed, 5 Nov 2025 18:38:18 +0000 Subject: [PATCH 2/6] Removed unnecessary branch, removed need for --- ggml/src/ggml-cpu/repack.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index eed5710d90d46..5398f109b0009 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -11,7 +11,6 @@ #include "arch-fallback.h" -#include #include #include #include @@ -1709,7 +1708,7 @@ template (16, (nr1 + nth - 1) / nth); + int64_t chunk_size1 = MAX(16, (nr1 + nth - 1) / nth); int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0; int64_t nchunk1 = (nr1 + chunk_size1 - 1) / chunk_size1; @@ -1732,7 +1731,7 @@ template = src0_end) { - if (nth >= nchunk0 * nchunk1) { - break; - } current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); continue; } From 75c7fd5d4ed2277584ce58bd0fbb5dd078c75dd0 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Thu, 6 Nov 2025 12:10:30 +0000 Subject: [PATCH 3/6] Fixed dst_ptr pointer in chunk + clang_format --- ggml/src/ggml-cpu/repack.cpp | 49 +++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 5398f109b0009..60abfd531bcb1 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1600,7 +1600,12 @@ template src[0]; const ggml_tensor * src1 = op->src[1]; ggml_tensor * dst = op; @@ -1622,26 +1627,23 @@ template data + i02 * nb02; - const char *src1_ptr = (const char*)params->wdata + (i11 + i12 * ne11) * src1_col_stride; - float *dst_ptr = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2)); + const char * src0_ptr = (const char *) src0->data + i02 * nb02; + const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride; + char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2)); const int64_t nrows = src1_end - src1_start; const int64_t ncols = src0_end - src0_start; // If there are more than three rows in src1, use gemm; otherwise, use gemv. if (nrows > 3) { - gemm(ne00, - dst_ptr + src0_start, nb1 / nb0, - src0_ptr + src0_start * nb01, - src1_ptr, nrows - (nrows % 4), ncols); + gemm(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0, + src0_ptr + src0_start * nb01, src1_ptr, + nrows - (nrows % 4), ncols); } for (int iter = nrows - (nrows % 4); iter < nrows; iter++) { - gemv(ne00, - dst_ptr + (iter * nb1) + src0_start, ne01, - src0_ptr + src0_start * nb01, - src1_ptr + (src1_col_stride * iter), 1 /* nrows */, - ncols); + gemv(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start, + ne01, src0_ptr + src0_start * nb01, + src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols); } } @@ -1668,7 +1670,9 @@ template type == GGML_TYPE_F32); @@ -1676,16 +1680,16 @@ template src[1]) == 2); char * wdata = static_cast(params->wdata); - const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10); - const size_t nbw2 = nbw1 * ne11; + const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10); + const size_t nbw2 = nbw1 * ne11; assert(params->wsize >= nbw2 * ne12); const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float; for (int64_t i12 = 0; i12 < ne12; i12++) { - char * data_ptr = (char *) src1->data + i12 * nb12; - char * wdata_ptr = wdata + i12 * nbw2; + char * data_ptr = (char *) src1->data + i12 * nb12; + char * wdata_ptr = wdata + i12 * nbw2; for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { ggml_quantize_mat_t((float *) (data_ptr + i11 * nb11), @@ -1719,10 +1723,9 @@ template nr1 ? nth : 1; - nchunk1 = nr0 > nr1 ? 1 : nth; + nchunk1 = nr0 > nr1 ? 1 : nth; } const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; @@ -1731,7 +1734,7 @@ template Date: Mon, 10 Nov 2025 17:54:14 +0000 Subject: [PATCH 4/6] GGML_ASSERT to check wdata within bounds --- ggml/src/ggml-cpu/repack.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 60abfd531bcb1..7b4eddba6a5b5 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1,3 +1,4 @@ +#include "ggml.h" #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" @@ -1630,6 +1631,7 @@ template data + i02 * nb02; const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride; char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2)); + GGML_ASSERT(src1_ptr >= params->wdata && src1_ptr < ((const char *)params->wdata + params->wsize)); const int64_t nrows = src1_end - src1_start; const int64_t ncols = src0_end - src0_start; From b56d0acef9d3ef5fab98a75f5fdf75a946e55566 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Mon, 10 Nov 2025 18:08:22 +0000 Subject: [PATCH 5/6] Accidental ggml.h inclusion --- ggml/src/ggml-cpu/repack.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 7b4eddba6a5b5..e4a409237abd1 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1,4 +1,3 @@ -#include "ggml.h" #define GGML_COMMON_IMPL_CPP #define GGML_COMMON_DECL_CPP #include "ggml-common.h" From d1938adb33d9ead115393d725e63cbbd35167126 Mon Sep 17 00:00:00 2001 From: Alberto Cabrera Date: Mon, 10 Nov 2025 20:25:53 +0000 Subject: [PATCH 6/6] Improved GGML_ASSERT on wdata boundaries --- ggml/src/ggml-cpu/repack.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index e4a409237abd1..274be146dc5cf 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1630,11 +1630,12 @@ template data + i02 * nb02; const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride; char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2)); - GGML_ASSERT(src1_ptr >= params->wdata && src1_ptr < ((const char *)params->wdata + params->wsize)); const int64_t nrows = src1_end - src1_start; const int64_t ncols = src0_end - src0_start; + GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize); + // If there are more than three rows in src1, use gemm; otherwise, use gemv. if (nrows > 3) { gemm(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,