Skip to content

Commit 4abec97

Browse files
[mlas] add loongarch lsx and lasx optimize code (microsoft#17937)
### Description Hello, we (@lixing-star) are developers on the Loongson team. We added 128-bit (LSX) and 256-bit (LASX) vector-optimized code for the LoongArch architecture. [100% tests passed, 0 tests failed out of 7](https://cloud.a-boat.cn:2021/api/public/dl/6831z1Bi?inline=true) ### Development Environment ``` CPU: Loongson-3C5000L uname -a: Linux localhost.localdomain 4.19.190-6.4.lns8.loongarch64 #1 SMP Thu Jul 14 12:08:04 CST 2022 loongarch64 loongarch64 loongarch64 GNU/Linux ``` ### LoongArch Documents - [LoongArch Reference Manual - Volume 1: Basic Architecture: This manual describes the basic part of the LoongArch architecture.](https://loongson.github.io/LoongArch-Documentation/LoongArch-Vol1-EN.html) - [LoongArch ELF psABI: This manual describes the LoongArch ELF psABI.](https://loongson.github.io/LoongArch-Documentation/LoongArch-ELF-ABI-EN.html) - [more](https://loongson.github.io/LoongArch-Documentation/README-EN.html)
1 parent a045be3 commit 4abec97

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+7696
-34
lines changed

cmake/onnxruntime_mlas.cmake

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,8 @@ else()
284284
set(X86 TRUE)
285285
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$")
286286
set(X86_64 TRUE)
287+
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*")
288+
set(LOONGARCH64 TRUE)
287289
endif()
288290
endif()
289291

@@ -575,6 +577,26 @@ else()
575577
set(MLAS_SOURCE_IS_NOT_SET 0)
576578
endif()
577579
endif()
580+
if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET)
581+
set(mlas_platform_srcs
582+
${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp
583+
${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S
584+
${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S
585+
${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S
586+
${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S
587+
${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S
588+
${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S
589+
${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S
590+
${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S
591+
${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S
592+
${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S
593+
${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S
594+
)
595+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mlsx -mlasx")
596+
if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH)
597+
set(MLAS_SOURCE_IS_NOT_SET 0)
598+
endif()
599+
endif()
578600
if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET)
579601
file(GLOB_RECURSE mlas_platform_srcs
580602
"${MLAS_SRC_DIR}/scalar/*.cpp")

onnxruntime/core/mlas/inc/mlas.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ Module Name:
6969
#endif
7070
#endif
7171

72+
#if defined(__loongarch64)
73+
#define MLAS_TARGET_LARCH64
74+
#endif
7275
//
7376
// Define the support levels for the target architecture.
7477
//
@@ -87,7 +90,7 @@ Module Name:
8790

8891
#define MLAS_F16VEC_INTRINSICS_SUPPORTED
8992

90-
#endif //
93+
#endif //
9194
#endif // ARM64
9295
#endif // Visual Studio 16 or earlier does not support fp16 intrinsic
9396

@@ -1619,7 +1622,7 @@ MlasHalfGemmConvertPackB(
16191622
* @param Channels # of input channels
16201623
* @param OutputCount # of output pixels
16211624
* @param KernelSize # kernel size
1622-
* @return
1625+
* @return
16231626
*/
16241627
void
16251628
MLASCALL
@@ -1657,7 +1660,7 @@ MlasTranspose(
16571660
* @param Channels C in NHWC
16581661
* @param OutputCount Number of output pixels
16591662
* @param KernelSize Size of the kernel
1660-
* @return
1663+
* @return
16611664
*/
16621665
void
16631666
MLASCALL
@@ -1676,7 +1679,7 @@ MlasNhwcMaxPool(
16761679
* @param Channels C in NHWC
16771680
* @param OutputCount Number of output pixels
16781681
* @param KernelSize size of the kernel
1679-
* @return
1682+
* @return
16801683
*/
16811684
void
16821685
MLASCALL

onnxruntime/core/mlas/lib/activate.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ struct MLAS_ACTIVATION_FUNCTION<MlasLeakyReluActivation>
143143
return MlasBlendFloat32x4(ValueTimesAlpha, Value, _mm_cmple_ps(ZeroFloat32x4, Value));
144144
#elif defined(MLAS_VSX_INTRINSICS)
145145
return vec_sel(ValueTimesAlpha, Value, vec_cmple(ZeroFloat32x4, Value));
146+
#elif defined(MLAS_LSX_INTRINSICS)
147+
return MlasBlendFloat32x4(ValueTimesAlpha, Value, (__m128)__lsx_vfcmp_cle_s(ZeroFloat32x4, Value));
146148
#else
147149
return MlasBlendFloat32x4(ValueTimesAlpha, Value, ZeroFloat32x4 < Value);
148150
#endif

onnxruntime/core/mlas/lib/compute.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ Return Value:
148148
// instead.
149149
normal = _mm_min_epi16(normal, MaximumExponent);
150150
normal = _mm_max_epi16(normal, MinimumExponent);
151+
#elif defined(MLAS_LSX_INTRINSICS)
152+
normal = __lsx_vmin_h(normal, MaximumExponent);
153+
normal = __lsx_vmax_h(normal, MinimumExponent);
151154
#else
152155
normal = MlasMinimumInt32x4(normal, MaximumExponent);
153156
normal = MlasMaximumInt32x4(normal, MinimumExponent);
@@ -215,6 +218,8 @@ Return Value:
215218
// N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle
216219
// and use zeroes for the upper elements.
217220
Vector = _mm_load_ss(Input);
221+
#elif defined(MLAS_LSX_INTRINSICS)
222+
Vector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input, 0);
218223
#else
219224
Vector = MlasBroadcastFloat32x4(Input);
220225
#endif
@@ -467,6 +472,8 @@ Return Value:
467472
// N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle and
468473
// use zeroes for the upper elements.
469474
MLAS_FLOAT32X4 Vector = _mm_load_ss(Input);
475+
#elif defined(MLAS_LSX_INTRINSICS)
476+
MLAS_FLOAT32X4 Vector = (MLAS_FLOAT32X4)__lsx_vldrepl_w(Input, 0);
470477
#else
471478
MLAS_FLOAT32X4 Vector = MlasBroadcastFloat32x4(Input);
472479
#endif
@@ -849,7 +856,7 @@ Return Value:
849856
// Find the maximum value for the row.
850857
//
851858

852-
#if defined(MLAS_TARGET_AMD64)
859+
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
853860
float Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D);
854861
#else
855862
float Maximum = MlasReduceMaximumF32Kernel(Input, D);
@@ -874,7 +881,7 @@ Return Value:
874881

875882
float Parameters[] = { NegativeMaximum, std::log(Accumulation)};
876883

877-
#if defined(MLAS_TARGET_AMD64)
884+
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
878885
GetMlasPlatform().ComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters);
879886
#else
880887
MlasComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters);
@@ -899,7 +906,7 @@ Return Value:
899906

900907
float Parameters[] = { 1.0f / Accumulation };
901908

902-
#if defined(MLAS_TARGET_AMD64)
909+
#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64)
903910
GetMlasPlatform().ComputeSoftmaxOutputF32Kernel(Output, D, Parameters);
904911
#else
905912
MlasComputeSoftmaxOutputF32Kernel(Output, D, Parameters);

onnxruntime/core/mlas/lib/dgemm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ Return Value:
530530

531531
size_t RowsHandled;
532532

533-
#if defined(MLAS_TARGET_AMD64_IX86) || defined (MLAS_TARGET_POWER)
533+
#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)
534534
RowsHandled = GetMlasPlatform().GemmDoubleKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode);
535535
#else
536536
if (ZeroMode) {
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
/*++
2+
3+
Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved.
4+
5+
Licensed under the MIT License.
6+
7+
Module Name:
8+
9+
DgemmKernelCommon.h
10+
11+
Abstract:
12+
13+
This module contains common kernel macros and structures for the double
14+
precision matrix/matrix multiply operation (DGEMM).
15+
16+
--*/
17+
18+
#define LFgemmElementShift 3
19+
#define LFgemmElementSize (1 << LFgemmElementShift)
20+
#define LFgemmYmmElementCount (32/LFgemmElementSize)
21+
22+
#include "FgemmKernelCommon.h"
23+
24+
FGEMM_TYPED_INSTRUCTION(xvfadd, xvfadd.d)
25+
FGEMM_TYPED_INSTRUCTION(xvfmadd, xvfmadd.d)
26+
FGEMM_TYPED_INSTRUCTION(xvldrepl, xvldrepl.d)
27+
FGEMM_TYPED_INSTRUCTION(xvfmul, xvfmul.d)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*++
2+
3+
Copyright (C) 2023 Loongson Technology Corporation Limited. All rights reserved.
4+
5+
Licensed under the MIT License.
6+
7+
Module Name:
8+
9+
DgemmKernelLasx.S
10+
11+
Abstract:
12+
13+
This module implements the kernels for the double precision matrix/matrix
14+
multiply operation (DGEMM).
15+
16+
This implementation uses Lasx instructions.
17+
18+
--*/
19+
20+
#include "asmmacro.h"
21+
#include "DgemmKernelCommon.h"
22+
#include "FgemmKernelLasxCommon.h"
23+
24+
.text
25+
26+
//
27+
// Generate the GEMM kernel.
28+
//
29+
30+
FgemmKernelLasxFunction MlasGemmDoubleKernelLasx
31+
32+
.end

0 commit comments

Comments
 (0)