@@ -39,6 +39,7 @@ limitations under the License.
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+#include "contrib_ops/cuda/bert/lean_attention/lean_api.h"
#include "contrib_ops/cuda/bert/attention_impl.h"

using namespace onnxruntime::cuda;
@@ -108,6 +109,7 @@ size_t GetAttentionWorkspaceSize(
    size_t total_sequence_length,
    void* fused_runner,
    bool use_flash_attention,
+    bool use_lean_attention,
    bool use_fused_cross_attention,
    bool use_memory_efficient_attention,
    bool use_cudnn_flash_attention,
@@ -119,12 +121,20 @@ size_t GetAttentionWorkspaceSize(

#if USE_FLASH_ATTENTION
  if (use_flash_attention) {
-    return qkv_bytes + onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads);
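+    // Softmax LSE is no longer carried in this workspace; it is provided through AttentionData (data.softmax_lse).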
+    return qkv_bytes;
  }
#else
  ORT_UNUSED_PARAMETER(use_flash_attention);
#endif

+#if USE_LEAN_ATTENTION
+  if (use_lean_attention) {
+    return qkv_bytes;
+  }
+#else
+  ORT_UNUSED_PARAMETER(use_lean_attention);
+#endif
+
#if USE_MEMORY_EFFICIENT_ATTENTION
  if (use_memory_efficient_attention) {
    size_t fmha_buffer_bytes = 0;
@@ -301,10 +311,10 @@ Status FlashAttention(

  constexpr bool is_bf16 = false;
  ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
-      device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.scratch),
+      device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.softmax_lse),
      parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
      parameters.sequence_length, parameters.total_sequence_length, scale, 0.0, parameters.is_unidirectional, is_bf16,
-      false, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
+      false, data.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
      reinterpret_cast<void*>(data.out_accum), data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH));

  return Status::OK();
@@ -326,6 +336,81 @@ Status FlashAttention(
}
#endif

+#if USE_LEAN_ATTENTION
+template <typename T>
+Status LeanAttention(
+    const cudaDeviceProp& device_prop,
+    cudaStream_t stream,
+    contrib::AttentionParameters& parameters,
+    AttentionData<T>& data,
+    float scale) {
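+  // Lean attention forward pass. K and V are already appended to the KV cache, so new_k/new_v are passed as nullptr below.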
+  assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
+         data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
+  assert(nullptr == data.mask_index);
+  assert(nullptr == data.attention_bias);
+  assert(parameters.head_size == parameters.v_head_size);
+
+  constexpr bool is_bf16 = false;
+
+  ORT_RETURN_IF_ERROR(onnxruntime::lean::mha_fwd_kvcache(
+      device_prop, stream,
+      data.q,
+      data.k,                // k_cache
+      data.v,                // v_cache
+      nullptr,               // new_k (we have appended new_k to k_cache)
+      nullptr,               // new_v (we have appended new_v to k_cache)
+      data.output,
+      reinterpret_cast<void*>(data.softmax_lse),
+      nullptr,               // seqlens_k
+      nullptr,               // cos_cache
+      nullptr,               // sin_cache
+      nullptr,               // block_table
+      parameters.batch_size,
+      parameters.num_heads,
+      parameters.num_heads,  // num_heads_k
+      parameters.head_size,
+      parameters.sequence_length,        // seqlen_q
+      parameters.total_sequence_length,  // seqlen_k
+      0,                     // seqlen_k_new
+      0,                     // rotary_dim
+      scale,                 // softmax_scale
+      parameters.is_unidirectional,
+      is_bf16,
+      false,                 // past_bsnh
+      data.num_splits,
+      data.grid_dim_z,
+      data.max_tiles_per_tb,
+      data.high_load_tbs,
+      data.tiles_per_head,
+      reinterpret_cast<void*>(data.softmax_lse_accum),
+      reinterpret_cast<void*>(data.out_accum),
+      data.lean_sync_flag,
+      -1,                    // local_window_size
+      false,                 // is_rotary_interleaved
+      false                  // is_packed_qkv
+      ));
+
+  return Status::OK();
+}
+
+template <>
+Status LeanAttention(
+    const cudaDeviceProp& device_prop,
+    cudaStream_t stream,
+    contrib::AttentionParameters& parameters,
+    AttentionData<float>& data,
+    float scale) {
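+  // The lean attention kernel is not instantiated for float inputs; this specialization only reports NOT_IMPLEMENTED.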
+  ORT_UNUSED_PARAMETER(device_prop);
+  ORT_UNUSED_PARAMETER(stream);
+  ORT_UNUSED_PARAMETER(parameters);
+  ORT_UNUSED_PARAMETER(data);
+  ORT_UNUSED_PARAMETER(scale);
+  return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "lean attention does not support float tensor");
+}
+#endif
+
+
+
template <typename T>
Status CudnnFlashAttention(
    cudnnHandle_t cudnn_handle,
@@ -641,6 +726,11 @@ Status QkvToContext(
  // For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
  const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
                                               : parameters.scale;
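+  // When lean attention is enabled for this node, dispatch to it before the other fused attention paths below.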
+#if USE_LEAN_ATTENTION
+  if (data.use_lean_attention) {
+    return LeanAttention(device_prop, stream, parameters, data, scale);
+  }
+#endif

#if USE_FLASH_ATTENTION
  if (data.use_flash_attention) {