@@ -647,7 +647,9 @@ static __global__ void flash_attn_stream_k_fixup(
647
647
}
648
648
649
649
template <int D> // D == head size
650
+ #if !defined(GGML_USE_HIP)
650
651
__launch_bounds__ (D, 1 )
652
+ #endif // !defined(GGML_USE_HIP)
651
653
static __global__ void flash_attn_combine_results (
652
654
const float * __restrict__ VKQ_parts,
653
655
const float2 * __restrict__ VKQ_meta,
@@ -690,7 +692,10 @@ static __global__ void flash_attn_combine_results(
690
692
float VKQ_numerator = 0 .0f ;
691
693
float VKQ_denominator = 0 .0f ;
692
694
for (int l = 0 ; l < parallel_blocks; ++l) {
693
- const float KQ_max_scale = expf (meta[l].x - kqmax);
695
+ const float diff = meta[l].x - kqmax;
696
+ float KQ_max_scale = expf (diff);
697
+ const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
698
+ *((uint32_t *) &KQ_max_scale) &= ftz_mask;
694
699
695
700
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
696
701
VKQ_denominator += KQ_max_scale * meta[l].y ;
@@ -831,10 +836,11 @@ void launch_fattn(
831
836
CUDA_CHECK (cudaGetLastError ());
832
837
}
833
838
839
+ int parallel_blocks = 1 ;
840
+
834
841
const dim3 block_dim (warp_size, nwarps, 1 );
835
842
int max_blocks_per_sm = 1 ; // Max. number of active blocks limited by occupancy.
836
843
CUDA_CHECK (cudaOccupancyMaxActiveBlocksPerMultiprocessor (&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z , nbytes_shared));
837
- int parallel_blocks = max_blocks_per_sm;
838
844
839
845
dim3 blocks_num;
840
846
if (stream_k) {
@@ -856,6 +862,9 @@ void launch_fattn(
856
862
GGML_ASSERT (K->ne [1 ] % KQ_row_granularity == 0 );
857
863
const int ntiles_KQ = K->ne [1 ] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
858
864
865
+ // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
866
+ parallel_blocks = std::max ((nsm * max_blocks_per_sm) / ntiles_total, 1 );
867
+
859
868
// parallel_blocks must not be larger than what the tensor size allows:
860
869
parallel_blocks = std::min (parallel_blocks, ntiles_KQ);
861
870
0 commit comments