Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96
dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a
kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.15.0.tar.gz;62ccd24ab60bcef68766440fb42d79071ac2a5d2
duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
#include <algorithm>
#include <vector>

#if defined(USE_KLEIDIAI) && !defined(_MSC_VER)
#include "core/mlas/lib/kleidiai/mlasi_kleidiai.h"
#endif

namespace onnxruntime {
namespace contrib {

Expand Down Expand Up @@ -215,7 +219,7 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase {

// Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops.
// We check that here too before attempting to use them.
if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) {
if (!ArmKleidiAI::SMEInfo::CanUseSME2) {
can_use_dynamic_quant_mlas_ = false;
}

Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/mlas/lib/convolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -938,9 +938,9 @@
--*/
{
// Override
if(GetMlasPlatform().MlasConvOverride != nullptr &&
if(ArmKleidiAI::SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvOverride != nullptr &&

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_ep_generic_interface

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_ep_generic_interface

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_vitisai

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_vitisai

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_debug

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_debug

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x86_release

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x86_release

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_xnnpack

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_xnnpack

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU DML CI Pipeline

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU DML CI Pipeline

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU TensorRT CI Pipeline

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU TensorRT CI Pipeline

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU CUDA CI Pipeline

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU CUDA CI Pipeline

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, dynamic)

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, dynamic)

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, static)

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, static)

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, dynamic)

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, dynamic)

'ArmKleidiAI': is not a class or namespace name

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, static)

'CanUseSME2': undeclared identifier

Check failure on line 941 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, static)

'ArmKleidiAI': is not a class or namespace name
GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)){
return;
return;
}

const size_t FilterCount = Parameters->FilterCount;
Expand Down Expand Up @@ -1201,7 +1201,7 @@
--*/
{
// Override
if (GetMlasPlatform().MlasConvPrepareOverride != nullptr &&
if (ArmKleidiAI::SMEInfo::CanUseSME2 && GetMlasPlatform().MlasConvPrepareOverride != nullptr &&

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_ep_generic_interface

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_ep_generic_interface

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_vitisai

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_vitisai

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_debug

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_debug

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x86_release

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x86_release

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_xnnpack

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / build_x64_release_xnnpack

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU DML CI Pipeline

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU DML CI Pipeline

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU TensorRT CI Pipeline

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU TensorRT CI Pipeline

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU CUDA CI Pipeline

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / Windows GPU CUDA CI Pipeline

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, dynamic)

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, dynamic)

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, static)

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (novcpkg, static)

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, dynamic)

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, dynamic)

'ArmKleidiAI': is not a class or namespace name

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, static)

'CanUseSME2': undeclared identifier

Check failure on line 1204 in onnxruntime/core/mlas/lib/convolve.cpp

View workflow job for this annotation

GitHub Actions / webgpu_build_x64_RelWithDebInfo (vcpkg, static)

'ArmKleidiAI': is not a class or namespace name
GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels,
InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount,
Activation, WorkingBufferSize, Beta, ThreadPool)){
Expand Down
74 changes: 74 additions & 0 deletions onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,20 @@

#include "kai_ukernel_interface.h"
#include "mlasi.h"
#include "kleidiai/mlasi_kleidiai.h"

#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"

#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"

const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod =
{kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
Expand Down Expand Up @@ -64,6 +72,56 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm};

// KleidiAI fp32 GEMV micro-kernel table, SME variant.
// Every entry is the matching accessor / run function of the
// `matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla` kernel.
// GetKleidiAISGemvUKernel() returns this table when SME2 cannot be used.
// NOTE(review): initializer order must match the member order of
// kai_matmul_clamp_f32_f32_f32p_ukernel (declared in the KleidiAI interface header) - confirm on version bumps.
const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv_sme =
{kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla};

// KleidiAI fp32 GEMV micro-kernel table, SME2 variant.
// Every entry is the matching accessor / run function of the
// `matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla` kernel.
// GetKleidiAISGemvUKernel() returns this table when SME2 can be used.
// NOTE(review): initializer order must match the member order of
// kai_matmul_clamp_f32_f32_f32p_ukernel (declared in the KleidiAI interface header) - confirm on version bumps.
const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv_sme2 =
{kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla};

// KleidiAI fp32 GEMM micro-kernel table, SME variant.
// Every entry is the matching accessor / run function of the
// `matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa` kernel.
// GetKleidiAISGemmUKernel() returns this table when SME2 cannot be used.
// NOTE(review): initializer order must match the member order of
// kai_matmul_clamp_f32_f32p_f32p_ukernel (declared in the KleidiAI interface header) - confirm on version bumps.
const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme =
{kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa};

// KleidiAI fp32 GEMM micro-kernel table, SME2 variant.
// Every entry is the matching accessor / run function of the
// `matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa` kernel.
// GetKleidiAISGemmUKernel() returns this table when SME2 can be used.
// NOTE(review): initializer order must match the member order of
// kai_matmul_clamp_f32_f32p_f32p_ukernel (declared in the KleidiAI interface header) - confirm on version bumps.
const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme2 =
{kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa};

const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() {
if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) {
return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm;
Expand All @@ -79,3 +137,19 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() {
return kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod;
}
}

// Returns the fp32 GEMM micro-kernel table to use: the SME2 table when
// ArmKleidiAI::SMEInfo::CanUseSME2 is true, otherwise the SME table.
// The reference is to a file-scope constant, so it stays valid for the
// lifetime of the program.
const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel() {
    return ArmKleidiAI::SMEInfo::CanUseSME2 ? sgemm_gemm_sme2 : sgemm_gemm_sme;
}

// Returns the fp32 GEMV micro-kernel table to use: the SME2 table when
// ArmKleidiAI::SMEInfo::CanUseSME2 is true, otherwise the SME table.
// The reference is to a file-scope constant, so it stays valid for the
// lifetime of the program.
const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel() {
    return ArmKleidiAI::SMEInfo::CanUseSME2 ? sgemm_gemv_sme2 : sgemm_gemv_sme;
}
7 changes: 7 additions & 0 deletions onnxruntime/core/mlas/lib/kai_ukernel_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,12 @@

#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"

#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h"

#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h"

// Accessors for the quantized (qai8dxp LHS / qsi4c32p RHS) micro-kernel tables.
// The implementation picks a variant from the CPU features reported by MLAS_CPUIDINFO.
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel();
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel();

// Accessors for the fp32 GEMM / GEMV micro-kernel tables.
// Each returns the SME2 table when ArmKleidiAI::SMEInfo::CanUseSME2 is true, else the SME table.
const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel();
const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel();
28 changes: 27 additions & 1 deletion onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,21 @@
#define RESTRICT __restrict__
#endif
namespace ArmKleidiAI {

// Groups the Arm SME capability flags that KleidiAI code paths test before
// selecting a kernel. Only declarations live here; the values come from the
// out-of-class inline definitions that follow in this header.
struct SMEInfo {
// True when MLAS_CPUIDINFO reports SME2 support.
static const bool CanUseSME2;
// True when MLAS_CPUIDINFO reports SME support.
static const bool CanUseSME;
// True when either CanUseSME2 or CanUseSME is true.
static const bool IsSMEAvailable;
};

// Boolean condition to determine if we can use SME2
// By default we should try for SME2 first before falling back to SME.
inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
inline const bool SMEInfo::CanUseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();
// Boolean condition to determine if we can use SME
inline const bool SMEInfo::CanUseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME();
// Boolean condition to tell us if SME is enabled on this system
inline const bool SMEInfo::IsSMEAvailable = SMEInfo::CanUseSME2 || SMEInfo::CanUseSME;


//
// Buffer packing routines.
Expand All @@ -43,6 +56,19 @@ MlasGemmPackB(
void* PackedB
);

// KleidiAI fp32 GEMV entry point (declaration only; the definition is not in this header).
// Takes the transpose flags for A and B, the M/N/K problem dimensions, a pointer to
// MLAS_SGEMM_DATA_PARAMS and a BatchSize.
// NOTE(review): Data presumably points to BatchSize parameter blocks, and the bool
// result presumably reports whether this path handled the request so the caller can
// fall back otherwise - confirm both against the definition.
bool
MLASCALL
MlasFp32Gemv(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB,
size_t M,
size_t N,
size_t K,
const MLAS_SGEMM_DATA_PARAMS* Data,
size_t BatchSize
);


bool
MLASCALL
MlasGemmBatch(
Expand Down
Loading
Loading