From f201162bec0e1de862d9399c54acad8ae978976c Mon Sep 17 00:00:00 2001 From: Jonathan Clohessy Date: Tue, 14 Oct 2025 10:30:22 +0100 Subject: [PATCH 1/3] Implement FP32 kleidiai Gemv Signed-off-by: Jonathan Clohessy --- .../core/mlas/lib/kai_ukernel_interface.cpp | 74 +++++ .../core/mlas/lib/kai_ukernel_interface.h | 7 + .../core/mlas/lib/kleidiai/mlasi_kleidiai.h | 28 +- .../core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | 284 ++++++++++++++---- onnxruntime/core/mlas/lib/platform.cpp | 4 +- onnxruntime/core/mlas/lib/qgemm.cpp | 6 +- .../test/mlas/unittest/test_fgemm_fixture.h | 7 + 7 files changed, 343 insertions(+), 67 deletions(-) diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp index fdada83cc6582..1d9d905c5d0d2 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp @@ -6,12 +6,20 @@ #include "kai_ukernel_interface.h" #include "mlasi.h" +#include "kleidiai/mlasi_kleidiai.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h" + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod = {kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, @@ -64,6 +72,56 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm}; +const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv_sme = + {kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla, + kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla}; + +const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv_sme2 = + {kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, + kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, + kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, + kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, + kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, + kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, + kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, + kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, + kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla, + kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla}; + +const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme = + {kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa, + kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa, + kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa, + kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa, + kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa, + kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa, + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa, + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa, + kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa, + kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa, + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa}; + +const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme2 = + {kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa}; + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() { if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm; @@ -79,3 +137,19 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() { return kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod; } } + +const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel() { + if (ArmKleidiAI::SMEInfo::CanUseSME2) { + return sgemm_gemm_sme2; + } else { + return sgemm_gemm_sme; + } +} + +const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel() { + if (ArmKleidiAI::SMEInfo::CanUseSME2) { + return sgemm_gemv_sme2; + } else { + return sgemm_gemv_sme; + } +} diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h index 1a6f111d1c794..e69c72329d64b 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h @@ -8,5 +8,12 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h" + +#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h" + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel(); const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel(); + +const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel(); +const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel(); diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index 2e9c4574fd057..9c61c887675fe 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -15,8 +15,21 @@ #define RESTRICT __restrict__ #endif namespace ArmKleidiAI { + +struct SMEInfo { + static const bool CanUseSME2; + static const bool CanUseSME; + static const bool IsSMEAvailable; +}; + +// Boolean condition to determine if we can use SME2 // By default we should try for SME2 first before falling back to SME. -inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2(); +inline const bool SMEInfo::CanUseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2(); +// Boolean condition to determine if we can use SME +inline const bool SMEInfo::CanUseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME(); +// Boolean condition to tell us if SME is enabled on this system +inline const bool SMEInfo::IsSMEAvailable = SMEInfo::CanUseSME2 || SMEInfo::CanUseSME; + // // Buffer packing routines. @@ -43,6 +56,19 @@ MlasGemmPackB( void* PackedB ); +bool +MLASCALL +MlasFp32Gemv( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize + ); + + bool MLASCALL MlasGemmBatch( diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp index 435ff1fb10017..fd60cbfa88ed7 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp @@ -12,7 +12,9 @@ #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h" +#include "mlas.h" #include "mlasi_kleidiai.h" +#include "kai_ukernel_interface.h" // Thread-local reusable buffers to reduce allocation overhead across tiles. @@ -21,9 +23,200 @@ struct KaiTlsBuffers { std::vector bias_zero; std::vector rhs_packed; std::vector lhs_packed; + std::vector gemv_lhs_row_tmp; }; static thread_local KaiTlsBuffers g_kai_tls; +kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm = GetKleidiAISGemmUKernel(); +kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv = GetKleidiAISGemvUKernel(); + + +// Helpers for GEMV +/*++ +Routine Description: + Apply alpha/beta scaling to a 1-D vector with arbitrary destination stride. + +Arguments: + src - Pointer to the temporary A*B results (length L). + num_elements - Number of elements. + alpha - Scale for the computed product (A*B). + beta - Scale for the existing C values. + dst - Pointer to the destination in C. + dst_stride - Stride, in elements, between successive outputs in C. + allow_memcpy - If true, allows memcpy path when alpha==1, beta==0, and dst_stride==1. + +Notes: + Uses a memcpy path when alpha==1, beta==0, allow_memcpy is true, and dst_stride==1. +--*/ +static inline void ApplyAlphaBetaStrided(const float* src, size_t num_elements, float alpha, float beta, float* dst, size_t dst_stride, bool allow_memcpy) { + if (alpha == 1.0f && beta == 0.0f && allow_memcpy && dst_stride == 1) { + std::memcpy(dst, src, num_elements * sizeof(float)); + return; + } + for (size_t i = 0; i < num_elements; ++i) { + const float ab = src[i]; + float& d = dst[i * dst_stride]; + const float c_orig = d; + if (alpha == 1.0f && beta == 0.0f) { + d = ab; + } else if (alpha == 1.0f) { + d = ab + beta * c_orig; + } else if (beta == 0.0f) { + d = alpha * ab; + } else { + d = alpha * ab + beta * c_orig; + } + } +} + +/*++ +Routine Description: + Apply alpha/beta scaling to a 2-D tile (rows x cols). + +Arguments: + src - Pointer to the temporary A*B results (row-major, rows x cols). + rows - Number of rows in the tile. + cols - Number of columns in the tile. + alpha - Scale for the computed product (A*B). + beta - Scale for the existing C values. + dst - Pointer to the destination tile in C (row-major with leading dimension ldc). + ldc - Leading dimension of C (in elements). + +Notes: + Uses a memcpy path when alpha==1, beta==0, ldc==cols, and rows/cols are non-zero. + Otherwise applies per-row scaling via ApplyAlphaBetaStrided. +--*/ +static inline void ApplyAlphaBeta2D(const float* src, size_t rows, size_t cols, + float alpha, float beta, + float* dst, size_t ldc) { + if (alpha == 1.0f && beta == 0.0f && ldc == cols && rows != 0 && cols != 0) { + std::memcpy(dst, src, rows * cols * sizeof(float)); + return; + } + for (size_t i = 0; i < rows; ++i) { + const float* src_row = src + i * cols; + float* dst_row = dst + i * ldc; + ApplyAlphaBetaStrided(src_row, cols, alpha, beta, dst_row, 1, /*allow_memcpy*/ (ldc == cols)); + } +} + +/*++ +Routine Description: + Execute GEMV using the SME/SME2 1xN microkernel for degenerate GEMM shapes: + - M == 1 (row-vector times matrix) + - N == 1 (matrix times column-vector) + +N == 1 mapping (y = A(MxK) * b(Kx1)): + The 1xN microkernel computes a single LHS row against multiple RHS columns. + To reuse it for N == 1, we present A as the "RHS" by transpose-packing it + so that each of A's M rows becomes a "column" for the kernel: + - rhsBase := A, rhsShape := M, ldl := lda, tb := CblasTrans + - lhsBase := B (the vector b), length K + The kernel expects the LHS vector to be a contiguous K-length row: + - If TransB == CblasNoTrans, b is stored as a Kx1 column with stride ldb. + We gather it into a thread-local contiguous buffer when ldb != 1. + - If TransB == CblasTrans, b is a 1xK row and is already contiguous. + +Unsupported: + When N == 1 and Data->BIsPacked is true (except M == N == 1), this path is + disabled because we need to pack A (as RHS) and pass B as an unpacked vector. + +Post-processing: + The kernel produces M outputs into a temporary buffer. We apply alpha/beta + and write to C using ldc as the destination stride. + +Return Value: + true - A GEMV path was executed (M == 1 or N == 1). + false - Fall back to the general GEMM path. +--*/ + +bool +MLASCALL +ArmKleidiAI::MlasFp32Gemv( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize +) { + // Edge case where M and N are both one so we can take the simpler M == 1 Path + bool m_n_both_one = (M == 1 && N == 1); + if (N == 1 && Data->BIsPacked && !m_n_both_one) + { + // Exit early because we cannot support cases where N is 1 and B is already packed + return false; + } + + for (size_t b = 0; b < BatchSize; ++b) { + // Depending on the value of M or N we might transpose when packing + CBLAS_TRANSPOSE tb = TransB; + size_t rhs_shape = N; + size_t ldl = Data[b].ldb; + const float* rhs_base = (M == 1) ? reinterpret_cast(Data[b].B) : reinterpret_cast(Data[b].A); + const float* lhs_base = (M == 1) ? static_cast(Data[b].A) : static_cast(Data[b].B); + if (N == 1 && !m_n_both_one) + { + tb = CblasTrans; + rhs_shape = M; + ldl = Data[b].lda; + } + // Prepare packed RHS if needed + const void* rhs_packed_ptr = nullptr; + if (Data[b].BIsPacked ) { + rhs_packed_ptr = Data[b].B; + } else { + const size_t rhs_size = ArmKleidiAI::MlasGemmPackBSize(TransA, tb, rhs_shape, K); + if (rhs_size == 0) { + return false; + } + g_kai_tls.rhs_packed.resize(rhs_size); + + ArmKleidiAI::MlasGemmPackB( + TransA, tb, rhs_shape, K, + rhs_base, + ldl, + g_kai_tls.rhs_packed.data()); + rhs_packed_ptr = g_kai_tls.rhs_packed.data(); + } + // We have to handle the case where we transpose the data to the correct format as we used a traspose packing kernel + if (N == 1 && TransB == CblasNoTrans) + { + g_kai_tls.gemv_lhs_row_tmp.resize(K); + + for (size_t k = 0; k < K; ++k) { + g_kai_tls.gemv_lhs_row_tmp[k] = lhs_base[k * Data[b].ldb]; + } + lhs_base = g_kai_tls.gemv_lhs_row_tmp.data(); + } + + // Temporary buffer for output row + g_kai_tls.output_tile.resize(rhs_shape); + std::fill_n(g_kai_tls.output_tile.data(), rhs_shape, 0.0f); + + // Run specialized 1xN-by-K kernel + sgemm_gemv.run_matmul( + 1, // Value of 1 for M == 1 and this value represents N when N == 1 case + rhs_shape, // Value of N for M == 1 and this value is M when N == 1 + K, // K + lhs_base, // lhs + K * sizeof(float), // lhs stride (bytes) + rhs_packed_ptr, // packed rhs + g_kai_tls.output_tile.data(), // output + rhs_shape * sizeof(float), // dst row stride (bytes) + sizeof(float), // dst col stride (bytes) + -std::numeric_limits::max(), + std::numeric_limits::max() + ); + // Apply alpha/beta to destination C row + bool allowMemCopy = (M == 1) ? (Data[b].ldc == N) : (Data[b].ldc == 1); + size_t destStride = (M == 1) ? 1 : Data[b].ldc; + ApplyAlphaBetaStrided(g_kai_tls.output_tile.data(), rhs_shape, Data[b].alpha, Data[b].beta, Data[b].C, destStride, allowMemCopy); + } + return true; +} + size_t MLASCALL ArmKleidiAI::MlasGemmPackBSize( @@ -127,12 +320,10 @@ Return Value: } if (TransA == CblasNoTrans) { - const size_t nr = UseSME2 ? kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() - : kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); - const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() - : kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); - const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() - : kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + + const size_t nr = sgemm_gemm.get_nr(); + const size_t kr = sgemm_gemm.get_kr(); + const size_t sr = sgemm_gemm.get_sr(); // Ensure size and zero the used span. g_kai_tls.bias_zero.resize(N, 0.0f); @@ -170,7 +361,7 @@ ArmKleidiAI::MlasGemmBatch( Routine Description: - This routine performs a batched matrix multiplication (GEMM) operation using KleidiAI kernels. + This routine performs a batched matrix multiplication (GEMM or GemV) operation using KleidiAI kernels. It handles both packed and unpacked inputs and manages tiling and kernel selection depending on SME2 availability. If packing is needed, it prepares the required buffers and invokes the appropriate left-hand side (LHS) and right-hand side (RHS) pack functions. @@ -222,23 +413,26 @@ Return Value: return true; } - const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() - : kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); - const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() - : kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); - const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() - : kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + // Attempt GEMV (M==1 or N==1) + if (M == 1 || N == 1) + { + if (ArmKleidiAI::MlasFp32Gemv(TransA, TransB, M, N, K, Data, BatchSize)) { + return true; + } + } - size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() - : kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); - size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() - : kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + size_t m_step = sgemm_gemm.get_m_step(); + size_t n_step = sgemm_gemm.get_n_step(); if ((M < m_step || N < n_step) && !Data->BIsPacked) { // Fallback to MLAS return false; } + const size_t mr = sgemm_gemm.get_mr(); + const size_t kr = sgemm_gemm.get_kr(); + const size_t sr = sgemm_gemm.get_sr(); + size_t LhsPackedStride = 0; std::byte* LhsPackedData = nullptr; @@ -329,9 +523,7 @@ Return Value: ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; // Get rhs tile, B - const size_t rhs_packed_offset = - UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K) - : kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, K); + const size_t rhs_packed_offset = sgemm_gemm.get_rhs_packed_offset(NIdx * n_step, K); const std::byte* B_base = Data[0].BIsPacked ? reinterpret_cast(Data[BIdx].B) @@ -339,9 +531,7 @@ Return Value: auto BTile = reinterpret_cast(B_base + rhs_packed_offset); // Get lhs tile, A - const size_t lhs_packed_offset = - UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K) - : kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, K); + const size_t lhs_packed_offset = sgemm_gemm.get_lhs_packed_offset(MIdx * m_step, K); const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx; auto ATile = reinterpret_cast(A_base + lhs_packed_offset); @@ -364,25 +554,15 @@ Return Value: float* temp_tile = g_kai_tls.output_tile.data(); std::fill_n(temp_tile, tile_elems, 0.0f); - if (UseSME2) { - kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa( - TileSizeM, - TileSizeN, - K, - ATile, BTile, temp_tile, - TileSizeN * sizeof(float), sizeof(float), - -std::numeric_limits::max(), std::numeric_limits::max() - ); - } else { - kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa( - TileSizeM, - TileSizeN, - K, - ATile, BTile, temp_tile, - TileSizeN * sizeof(float), sizeof(float), - -std::numeric_limits::max(), std::numeric_limits::max() - ); - } + + sgemm_gemm.run_matmul( + TileSizeM, + TileSizeN, + K, + ATile, BTile, temp_tile, + TileSizeN * sizeof(float), sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); // Final output tile pointer float* dst_tile = reinterpret_cast(CTile); @@ -407,25 +587,7 @@ Return Value: float beta = Data[BIdx].beta; size_t ldc = Data[BIdx].ldc; - for (size_t i = 0; i < TileSizeM; ++i) { - for (size_t j = 0; j < TileSizeN; ++j) { - const size_t temp_idx = i * TileSizeN + j; - const size_t dst_idx = i * ldc + j; - - float ab = temp_tile[temp_idx]; - float c_orig = dst_tile[dst_idx]; - - if (alpha == 1.0f && beta == 0.0f) { - dst_tile[dst_idx] = ab; - } else if (alpha == 1.0f) { - dst_tile[dst_idx] = ab + beta * c_orig; - } else if (beta == 0.0f) { - dst_tile[dst_idx] = alpha * ab; - } else { - dst_tile[dst_idx] = alpha * ab + beta * c_orig; - } - } - } + ApplyAlphaBeta2D(temp_tile, TileSizeM, TileSizeN, alpha, beta, dst_tile, ldc); return; }); return true; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 46fa150395d75..21369459c201b 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -596,7 +596,7 @@ Return Value: } #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) - if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ + if (ArmKleidiAI::SMEInfo::IsSMEAvailable) { this->MlasGemmBatchOverride = ArmKleidiAI::MlasGemmBatch; this->MlasGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize; this->MlasGemmPackBOverride = ArmKleidiAI::MlasGemmPackB; @@ -799,4 +799,4 @@ thread_local size_t ThreadedBufSize = 0; thread_local std::unique_ptr ThreadedBufHolder(nullptr, &_aligned_free); #else thread_local std::unique_ptr ThreadedBufHolder(nullptr, &free); -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index 4e9a0e27099dc..1aa90dc48f802 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -211,8 +211,8 @@ MlasDynamicQGemmBatch ( ) { #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) //No fallback and putting in guards - if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ - ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool); + if(ArmKleidiAI::SMEInfo::CanUseSME2){ + ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool); } #endif @@ -336,7 +336,7 @@ MlasDynamicQgemmPackBSize( #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) //No fallback available //TODO: Insert Override - if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override + if(ArmKleidiAI::SMEInfo::CanUseSME2){//Still require this since no override bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K); } #endif diff --git a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h index c832ca69dbb31..17a511b42e92e 100644 --- a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h +++ b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h @@ -71,6 +71,13 @@ class FgemmShortExecuteTest : public MlasTestFixture Date: Wed, 22 Oct 2025 13:13:14 +0100 Subject: [PATCH 2/3] update kleidi version to fix missing header file Signed-off-by: Jonathan Clohessy --- cmake/deps.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index bf76753c1b3c0..98ef5cffae9a6 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96 dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a -kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a +kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.15.0.tar.gz;62ccd24ab60bcef68766440fb42d79071ac2a5d2 duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794 From e8ab1b1074a69eb2f141339270b2a1a9cec7cb64 Mon Sep 17 00:00:00 2001 From: Jonathan Clohessy Date: Fri, 24 Oct 2025 15:28:52 +0100 Subject: [PATCH 3/3] Update const for kernel interface and sme checks Signed-off-by: Jonathan Clohessy --- .../contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc | 6 +++++- onnxruntime/core/mlas/lib/convolve.cpp | 6 +++--- onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | 4 ++-- onnxruntime/core/mlas/lib/qgemm.cpp | 2 +- onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp | 3 ++- 5 files changed, 13 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 36a6f70cc69d9..6988608234c65 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -15,6 +15,10 @@ #include #include +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h" +#endif + namespace onnxruntime { namespace contrib { @@ -215,7 +219,7 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. // We check that here too before attempting to use them. - if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) { + if (!ArmKleidiAI::SMEInfo::CanUseSME2) { can_use_dynamic_quant_mlas_ = false; } diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp index bc1221475fd90..ed81295609f52 100644 --- a/onnxruntime/core/mlas/lib/convolve.cpp +++ b/onnxruntime/core/mlas/lib/convolve.cpp @@ -862,9 +862,9 @@ Return Value: --*/ { // Override - if(GetMlasPlatform().MlasConvOverride != nullptr && + if(ArmKleidiAI::SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvOverride != nullptr && GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)){ - return; + return; } const size_t FilterCount = Parameters->FilterCount; @@ -1101,7 +1101,7 @@ Return Value: --*/ { // Override - if (GetMlasPlatform().MlasConvPrepareOverride != nullptr && + if (ArmKleidiAI::SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvPrepareOverride != nullptr && GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels, InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount, Activation, WorkingBufferSize, Beta, ThreadPool)){ diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp index fd60cbfa88ed7..363700da1bc66 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp @@ -27,8 +27,8 @@ struct KaiTlsBuffers { }; static thread_local KaiTlsBuffers g_kai_tls; -kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm = GetKleidiAISGemmUKernel(); -kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv = GetKleidiAISGemvUKernel(); +const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm = GetKleidiAISGemmUKernel(); +const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv = GetKleidiAISGemvUKernel(); // Helpers for GEMV diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index 1aa90dc48f802..daf24c8d000b3 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -423,7 +423,7 @@ MlasDynamicQgemmPackB( { #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) //No fallback - if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override + if (ArmKleidiAI::SMEInfo::CanUseSME2) {//Still require this since no override ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB); } #endif diff --git a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp index 6d05e93f517ae..6f3f4e3fe93d7 100644 --- a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp @@ -9,6 +9,7 @@ #include "test_util.h" #include "core/mlas/lib/mlasi.h" // for MLAS_CPUIDINFO +#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h" class MlasDynamicQgemmTest { private: @@ -21,7 +22,7 @@ class MlasDynamicQgemmTest { public: void Test(size_t M, size_t N, size_t K, size_t BatchSize) { // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. - if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) { + if (!ArmKleidiAI::SMEInfo::CanUseSME2) { GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME but it was not detected. Skipping test."; }