diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 86bafba4a4398..4b13733079ab5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -221,26 +221,29 @@ void main() { #endif /* Load kernel to A_block: (BS_K x BS_CRS)*/ + uint32_t B_lx = Ac; + uint32_t knl_idx_base = KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02; for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { uint32_t B_ly = r_offset + Ar; - uint32_t B_lx = Ac; uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); - float val = knl_data[knl_idx]; + uint32_t knl_idx = knl_idx_base + K_idx * p.nb03; + float val; if (K_idx >= K || CRS_idx_a >= CRS) { val = 0.0; + } else { + val = knl_data[knl_idx]; } Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); } /* Load input to B_block: (BS_CRS x BS_NPQ) */ + B_lx = Bc; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW; UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { uint32_t B_ly = r_offset + Br; /* Row index of B block */ - uint32_t B_lx = Bc; - uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ - uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; - uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW; - uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW; - uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW; uint32_t CRS_idx_b; uint32_t Cin_idx_b; @@ -269,11 +272,12 @@ void main() { uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1; uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0; - uint32_t src_idx = - min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); - float val = src_data[src_idx]; - if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) { + uint32_t src_idx = W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13; + float val; + if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx >= p.H || W_idx >= p.W) { val = 0.0; + } else { + val = src_data[src_idx]; } Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); }