diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index b0feea362380b..045c6d3006b2e 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -75,6 +75,8 @@ #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2) +#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3) #define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads @@ -325,6 +327,20 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() { #endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) } +// Maximum number of bytes that can be copied in a single instruction. +static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() { +#ifdef GGML_USE_HIP + return 16; +#else +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + return 16; +#else + return 8; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // GGML_USE_HIP +} + + [[noreturn]] static __device__ void no_device_code( const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index b69f57d659a26..142a3a88d1d7c 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -647,9 +647,7 @@ static __global__ void flash_attn_stream_k_fixup( } template // D == head size -#if !defined(GGML_USE_HIP) __launch_bounds__(D, 1) -#endif // !(defined(GGML_USE_HIP) static __global__ void flash_attn_combine_results( const float * __restrict__ VKQ_parts, const float2 * __restrict__ VKQ_meta, @@ -692,10 +690,7 @@ static __global__ void flash_attn_combine_results( float VKQ_numerator = 0.0f; float VKQ_denominator = 0.0f; for (int l = 0; l < parallel_blocks; ++l) { - const float diff = meta[l].x - kqmax; - float KQ_max_scale = expf(diff); - const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD); - *((uint32_t *) &KQ_max_scale) &= ftz_mask; + const float KQ_max_scale = expf(meta[l].x - kqmax); VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid]; VKQ_denominator += KQ_max_scale * meta[l].y; @@ -836,11 +831,10 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); } - int parallel_blocks = 1; - const dim3 block_dim(warp_size, nwarps, 1); int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + int parallel_blocks = max_blocks_per_sm; dim3 blocks_num; if (stream_k) { @@ -862,9 +856,6 @@ void launch_fattn( GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. - // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave: - parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1); - // parallel_blocks must not be larger than what the tensor size allows: parallel_blocks = std::min(parallel_blocks, ntiles_KQ); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index c6a399ce5d791..a2d9951ea56a7 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -2,20 +2,30 @@ #include "fattn-common.cuh" #include "fattn-tile.cuh" -#define FATTN_TILE_NTHREADS 256 +// kq_stride == number of KQ rows to process per iteration +// kq_nbatch == number of K columns to load in parallel for KQ calculation static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) { if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA(cc)) { + switch (D) { + case 64: + return 128; + case 128: + case 256: + return ncols <= 16 ? 128 : 64; + default: + GGML_ABORT("fatal error"); + return -1; + } + } switch (D) { case 64: - return 64; + return ncols == 32 ? 128 : 64; case 128: + return ncols == 32 ? 64 : 32; case 256: - if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { - return ncols <= 16 ? 64 : 32; - } else { - return 64; - } + return 32; default: GGML_ABORT("fatal error"); return -1; @@ -49,24 +59,28 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) { #ifdef GGML_USE_HIP +#ifdef RDNA switch (D) { case 64: - return 64; + return 128; case 128: -#if defined(GCN) || defined(CDNA) - return ncols <= 16 ? 64 : 32; -#else - return 64; -#endif // defined(GCN) || defined(CDNA) case 256: -#if defined(GCN) || defined(CDNA) - return ncols <= 16 ? 64 : 32; + return ncols <= 16 ? 128 : 64; + default: + return -1; + } #else - return 64; -#endif // defined(GCN) || defined(CDNA) + switch (D) { + case 64: + return ncols == 32 ? 128 : 64; + case 128: + return ncols == 32 ? 64 : 32; + case 256: + return 32; default: return -1; } +#endif // RDNA #else #ifdef FAST_FP16_AVAILABLE switch (D) { @@ -100,17 +114,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols case 64: return 64; case 128: -#if defined(GCN) || defined(CDNA) - return ncols <= 16 ? 64 : 128; -#else - return 64; -#endif // defined(GCN) || defined(CDNA) case 256: -#if defined(GCN) || defined(CDNA) - return ncols <= 16 ? 64 : 128; -#else - return ncols <= 16 ? 64 : 256; -#endif // defined(GCN) || defined(CDNA) + return 128; default: return -1; } @@ -120,9 +125,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols case 64: return 64; case 128: - return ncols <= 16 ? 128 : 64; case 256: - return ncols <= 16 ? 64 : 128; + return 128; default: return -1; } @@ -142,12 +146,27 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols GGML_UNUSED_VARS(ncols, warp_size); } -template // D == head size -#ifdef GGML_USE_HIP -__launch_bounds__(FATTN_TILE_NTHREADS, 1) +static int fattn_tile_get_nthreads_host(const int cc, const int ncols) { + return 256; + GGML_UNUSED_VARS(cc, ncols); +} + +static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) { + return 256; + GGML_UNUSED(ncols); +} + +static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) { +#ifdef RDNA + return 3; #else -__launch_bounds__(FATTN_TILE_NTHREADS, 2) -#endif // GGML_USE_HIP + return ncols <= 16 ? 3 : 2; +#endif // RDNA + GGML_UNUSED(ncols); +} + +template // D == head size +__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols)) static __global__ void flash_attn_tile( const char * __restrict__ Q, const char * __restrict__ K, @@ -193,7 +212,7 @@ static __global__ void flash_attn_tile( } constexpr int warp_size = 32; - constexpr int nwarps = FATTN_TILE_NTHREADS / warp_size; + constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size; constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size); static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size."); constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size); @@ -206,90 +225,126 @@ static __global__ void flash_attn_tile( const int sequence = blockIdx.z / ne02; const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); + const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const float * sinksf = (const float *) (sinks); const int stride_KV2 = nb11 / sizeof(half2); const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); -#if defined(GGML_USE_HIP) - constexpr int cpy_nb = 16; -#else - constexpr int cpy_nb = 8; -#endif // defined(GGML_USE_HIP) && defined(GCN) + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; - __shared__ float KQ[ncols][kq_stride]; + constexpr int cpw = ncols/nwarps; // cols per warp + + // softmax_iter_j == number of KQ columns for which to calculate softmax in parallel. + // KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes. #ifdef FAST_FP16_AVAILABLE + constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; + + __shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; __shared__ half2 Q_tmp[ncols][D/2]; - __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts. - half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; + __shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts. + half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; #else + constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; + + __shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; __shared__ float Q_tmp[ncols][D]; - __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts. - float2 * KV_tmp_f2 = (float2 *) KV_tmp_f; - float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; + __shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts. + float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; #endif // FAST_FP16_AVAILABLE + static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j"); - - float kqmax[ncols/nwarps]; + float KQ_max[cpw]; #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { - kqmax[j0/nwarps] = -FLT_MAX/2.0f; + KQ_max[j0/nwarps] = -FLT_MAX/2.0f; } - float kqsum[ncols/nwarps] = {0.0f}; + float KQ_sum[cpw] = {0.0f}; + // Load Q data, convert to FP16 if fast. #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; + for (int j0 = 0; j0 < cpw; ++j0) { + const int j = j0 + threadIdx.y*cpw; + + constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f); + for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { + float tmp_f[cpy_ne_D] = {0.0f}; + if (ic0 + j < ne01) { + ggml_cuda_memcpy_1(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]); + } + +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp_f[i1] *= scale; + } + #ifdef FAST_FP16_AVAILABLE - Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale); + half2 tmp_h2[cpy_ne_D/2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { + tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); + } + ggml_cuda_memcpy_1(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2); #else - Q_tmp[j][2*i0 + threadIdx.x] = tmp.x * scale; - Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale; + ggml_cuda_memcpy_1 (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f); #endif // FAST_FP16_AVAILABLE } } __syncthreads(); + // Main loop over KV cache: const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) { // Calculate KQ tile and keep track of new maximum KQ values: - float kqmax_new[ncols/nwarps]; + float KQ_max_new[cpw]; #pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - kqmax_new[j] = kqmax[j]; + for (int j = 0; j < cpw; ++j) { + KQ_max_new[j] = KQ_max[j]; } - float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}}; + float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication. + // KQ = K @ Q matrix multiplication: #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) { #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) { const int i_KQ = i_KQ_0 + threadIdx.y; -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) { - const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x]; #ifdef FAST_FP16_AVAILABLE - KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2; + constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) { + ggml_cuda_memcpy_1( + &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], + &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]); + } #else - const float2 tmp_f2 = __half22float2(tmp_h2); - KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x; - KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y; -#endif // FAST_FP16_AVAILABLE + constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size; +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) { + half2 tmp_h2[cpy_ne_kqnb/2]; + ggml_cuda_memcpy_1( + tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]); + + float2 tmp_f2[cpy_ne_kqnb/2]; +#pragma unroll + for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) { + tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]); + } + ggml_cuda_memcpy_1( + &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2); } +#endif // FAST_FP16_AVAILABLE } __syncthreads(); @@ -298,12 +353,12 @@ static __global__ void flash_attn_tile( #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) { half2 K_k[kq_stride/warp_size][cpy_ne]; - half2 Q_k[ncols/nwarps][cpy_ne]; + half2 Q_k[cpw][cpy_ne]; #else #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) { float K_k[kq_stride/warp_size][cpy_ne]; - float Q_k[ncols/nwarps][cpy_ne]; + float Q_k[cpw][cpy_ne]; #endif // FAST_FP16_AVAILABLE #pragma unroll @@ -311,29 +366,29 @@ static __global__ void flash_attn_tile( const int i_KQ = i_KQ_0 + threadIdx.x; #ifdef FAST_FP16_AVAILABLE - ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]); + ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]); #else - ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]); + ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]); #endif // FAST_FP16_AVAILABLE } #pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; + for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { + const int j_KQ = j_KQ_0 + threadIdx.y*cpw; #ifdef FAST_FP16_AVAILABLE - ggml_cuda_memcpy_1(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]); + ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]); #else - ggml_cuda_memcpy_1(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]); + ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]); #endif // FAST_FP16_AVAILABLE } #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { #pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { + for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { #pragma unroll for (int k = 0; k < cpy_ne; ++k) { - ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]); + ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]); } } } @@ -344,104 +399,77 @@ static __global__ void flash_attn_tile( } } + // Apply logit softcap, mask, update KQ_max: #pragma unroll for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { const int i_KQ = i_KQ_0 + threadIdx.x; #pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; + for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { + const int j_KQ = j_KQ_0 + threadIdx.y*cpw; if (use_logit_softcap) { - sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]); + KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]); } - sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - - kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]); + KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps]; + KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]); } } __syncthreads(); + // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); - const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]); - kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; - - float kqsum_add = 0.0f; - if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) { -#pragma unroll - for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) { - const int i = i0 + 4*threadIdx.x; - - float4 val = *(const float4 *) &KQ[j][i]; - val.x = expf(val.x - kqmax[j0/nwarps]); - val.y = expf(val.y - kqmax[j0/nwarps]); - val.z = expf(val.z - kqmax[j0/nwarps]); - val.w = expf(val.w - kqmax[j0/nwarps]); - kqsum_add += val.x + val.y + val.z + val.w; - + for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { #ifdef FAST_FP16_AVAILABLE - const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)}; - ggml_cuda_memcpy_1(&KQ[j][i/2], &tmp); + half tmp[kq_stride/warp_size][softmax_iter_j]; #else - ggml_cuda_memcpy_1(&KQ[j][i], &val); + float tmp[kq_stride/warp_size][softmax_iter_j]; #endif // FAST_FP16_AVAILABLE - } - } else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) { + #pragma unroll - for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) { - const int i = i0 + 2*threadIdx.x; + for (int j1 = 0; j1 < softmax_iter_j; ++j1) { + KQ_max_new[j0+j1] = warp_reduce_max(KQ_max_new[j0+j1]); + const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]); + KQ_max[j0+j1] = KQ_max_new[j0+j1]; - float2 val = *(const float2 *) &KQ[j][i]; - val.x = expf(val.x - kqmax[j0/nwarps]); - val.y = expf(val.y - kqmax[j0/nwarps]); - kqsum_add += val.x + val.y; -#ifdef FAST_FP16_AVAILABLE - const half2 tmp = make_half2(val.x, val.y); - ggml_cuda_memcpy_1(&KQ[j][i/2], &tmp); -#else - ggml_cuda_memcpy_1(&KQ[j][i], &val); -#endif // FAST_FP16_AVAILABLE - } - } else { + float KQ_sum_add = 0.0f; +#pragma unroll for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { - const int i = i0 + threadIdx.x; + const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]); + KQ_sum_add += val; + tmp[i0/warp_size][j1] = val; + } + KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add; - const float diff = KQ[j][i] - kqmax[j0/nwarps]; - const float val = expf(diff); - kqsum_add += val; #ifdef FAST_FP16_AVAILABLE - ((half *) KQ[j])[i] = val; + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2; + } #else - KQ[j][i] = val; -#endif // FAST_FP16_AVAILABLE +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale; + VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale; } +#endif // FAST_FP16_AVAILABLE } - kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2; - } -#else -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale; - VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale; + for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + ggml_cuda_memcpy_1( + KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]); } -#endif // FAST_FP16_AVAILABLE } - constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; + // VKQ = V @ KQ matrix multiplication: + constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K. static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter"); #pragma unroll for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) { @@ -449,65 +477,96 @@ static __global__ void flash_attn_tile( for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) { const int k_tile = k1 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i]; #ifdef FAST_FP16_AVAILABLE - KV_tmp_h2[k_tile*(D/2) + i] = tmp; + constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1( + &KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D], + &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]); + } #else - KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp); -#endif // FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { + half2 tmp_h2[cpy_ne_D/2]; + ggml_cuda_memcpy_1( + tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]); + + float2 tmp_f2[cpy_ne_D/2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { + tmp_f2[i1] = __half22float2(tmp_h2[i1]); + } + ggml_cuda_memcpy_1( + &KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2); } +#endif // FAST_FP16_AVAILABLE } __syncthreads(); +#ifdef FAST_FP16_AVAILABLE #pragma unroll for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { -#ifdef FAST_FP16_AVAILABLE half2 V_k[(D/2)/warp_size]; - half2 KQ_k[ncols/nwarps]; -#else - float2 V_k[(D/2)/warp_size]; - float KQ_k[ncols/nwarps]; -#endif // FAST_FP16_AVAILABLE + half2 KQ_k[cpw]; + constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; #pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - const int i = i0 + threadIdx.x; + for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { + const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); -#ifdef FAST_FP16_AVAILABLE - V_k[i0/warp_size] = KV_tmp_h2[k1*(D/2) + i]; + half tmp[softmax_iter_j]; + ggml_cuda_memcpy_1( + &tmp, KQ[j][k0 + k1]); +#pragma unroll + for (int j1 = 0; j1 < softmax_iter_j; ++j1) { + KQ_k[j0+j1] = __half2half2(tmp[j1]); + } + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { +#pragma unroll + for (int j0 = 0; j0 < cpw; ++j0) { + VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0]; + } + } + } #else - V_k[i0/warp_size] = KV_tmp_f2[k1*(D/2) + i]; -#endif // FAST_FP16_AVAILABLE +#pragma unroll + for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { + float2 V_k[(D/2)/warp_size]; + float KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]); } #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; + for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { + const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); -#ifdef FAST_FP16_AVAILABLE - KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]); -#else - KQ_k[j0/nwarps] = KQ[j][k0 + k1]; -#endif // FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1( + &KQ_k[j0], KQ[j][k0 + k1]); } #pragma unroll for (int i0 = 0; i0 < D/2; i0 += warp_size) { #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { -#ifdef FAST_FP16_AVAILABLE - VKQ[j0/nwarps][i0/warp_size] += V_k[i0/warp_size] *KQ_k[j0/nwarps]; -#else - VKQ[j0/nwarps][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0/nwarps]; - VKQ[j0/nwarps][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0/nwarps]; -#endif // FAST_FP16_AVAILABLE + for (int j0 = 0; j0 < cpw; ++j0) { + VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0]; + VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0]; } } } +#endif // FAST_FP16_AVAILABLE __syncthreads(); } @@ -519,69 +578,92 @@ static __global__ void flash_attn_tile( const float sink = sinksf[head]; #pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); - kqmax_new_j = warp_reduce_max(kqmax_new_j); + for (int j0 = 0; j0 < cpw; ++j0) { + float KQ_max_new_j = fmaxf(KQ_max[j0], sink); + KQ_max_new_j = warp_reduce_max(KQ_max_new_j); - const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j); - kqmax[j0/nwarps] = kqmax_new_j; + const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j); + KQ_max[j0] = KQ_max_new_j; - const float val = expf(sink - kqmax[j0/nwarps]); - kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; + const float val = expf(sink - KQ_max[j0]); + KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale; if (threadIdx.x == 0) { - kqsum[j0/nwarps] += val; + KQ_sum[j0] += val; } #ifdef FAST_FP16_AVAILABLE const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); #pragma unroll for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2; + VKQ[j0][i0/warp_size] *= KQ_max_scale_h2; } #else #pragma unroll for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale; - VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale; + VKQ[j0][i0/warp_size].x *= KQ_max_scale; + VKQ[j0][i0/warp_size].y *= KQ_max_scale; } #endif // FAST_FP16_AVAILABLE } } - float2 * dst2 = (float2 *) dst; +#pragma unroll + for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { + KQ_sum[j_VKQ_0] = warp_reduce_sum(KQ_sum[j_VKQ_0]); + } + if (gridDim.y == 1) { +#pragma unroll + for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]); +#pragma unroll + for (int i = 0; i < (D/2)/warp_size; ++i) { + VKQ[j_VKQ_0][i] *= KQ_sum_j_inv; + } +#else + const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0]; +#pragma unroll + for (int i = 0; i < (D/2)/warp_size; ++i) { + VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv; + VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv; + } +#endif // FAST_FP16_AVAILABLE + } + } + // Write back results: #pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { - const int j_VKQ = j_VKQ_0 + threadIdx.y; + for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { + const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw; if (ic0 + j_VKQ >= ne01) { return; } - float kqsum_j = kqsum[j_VKQ_0/nwarps]; - kqsum_j = warp_reduce_sum(kqsum_j); - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; -#pragma unroll - for (int i00 = 0; i00 < D/2; i00 += warp_size) { - const int i0 = i00 + threadIdx.x; - #ifdef FAST_FP16_AVAILABLE - float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]); -#else - float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size]; -#endif // FAST_FP16_AVAILABLE - - if (gridDim.y == 1) { - dst_val.x /= kqsum_j; - dst_val.y /= kqsum_j; + constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { + float2 tmp[cpy_ne_D]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]); } - dst2[j_dst_unrolled*(D/2) + i0] = dst_val; + ggml_cuda_memcpy_1(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); } +#else + constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1( + &dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]); + } +#endif // FAST_FP16_AVAILABLE if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]); } } #else @@ -602,15 +684,29 @@ template static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * Q = dst->src[0]; - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const int warp_size = 32; - const int nwarps = FATTN_TILE_NTHREADS / warp_size; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = 32; constexpr size_t nbytes_shared = 0; +#ifdef GGML_USE_HIP + if constexpr (D <= 128) { + if (Q->ne[1] > 32) { + constexpr int cols_per_block = 64; + const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; + fattn_kernel_t fattn_kernel = flash_attn_tile; + const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); + return; + } + } +#endif // GGML_USE_HIP + if (Q->ne[1] > 16) { constexpr int cols_per_block = 32; + const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; fattn_kernel_t fattn_kernel = flash_attn_tile; const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); launch_fattn @@ -619,6 +715,7 @@ static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml } constexpr int cols_per_block = 16; + const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; fattn_kernel_t fattn_kernel = flash_attn_tile; const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); launch_fattn diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 12bbee45566de..37386afcd405b 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -158,41 +158,41 @@ #define __CUDA_ARCH__ 1300 -#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) -#define GCN -#endif - #if defined(__gfx900__) || defined(__gfx906__) #define GCN5 -#endif +#endif // defined(__gfx900__) || defined(__gfx906__) #if defined(__gfx803__) #define GCN4 -#endif +#endif // defined(__gfx803__) -#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) -#define CDNA // For the entire family -#endif +#if defined(GCN5) || defined(GCN4) +#define GCN +#endif // defined(GCN5) || defined(GCN4) #if defined(__gfx942__) #define CDNA3 -#endif +#endif // defined(__gfx942__) #if defined(__gfx90a__) #define CDNA2 -#endif +#endif // defined(__gfx90a__) #if defined(__gfx908__) #define CDNA1 -#endif +#endif // defined(__gfx908__) + +#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#define CDNA // For the entire family +#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #if defined(__GFX12__) #define RDNA4 -#endif +#endif // defined(__GFX12__) #if defined(__GFX11__) #define RDNA3 -#endif +#endif // defined(__GFX11__) #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) @@ -201,7 +201,11 @@ #if defined(__gfx1010__) || defined(__gfx1012__) #define RDNA1 -#endif +#endif // defined(__gfx1010__) || defined(__gfx1012__) + +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1) +#define RDNA // For the entire family +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1) #ifndef __has_builtin #define __has_builtin(x) 0