diff --git a/metal-perf/int4mm.mm b/metal-perf/int4mm.mm index 76a073a..96bfd2b 100644 --- a/metal-perf/int4mm.mm +++ b/metal-perf/int4mm.mm @@ -306,22 +306,68 @@ void dispatchThreads(id encoder, using Int4MMOpDescriptor::N; void dispatchThreads(id encoder, unsigned maxThreadsPerGroup) const override { - constexpr auto blockSize = 8; - if (maxThreadsPerGroup < blockSize * blockSize) { - throw std::runtime_error("Can't dispatch!"); - } - [encoder dispatchThreads:MTLSizeMake(N/4, M/4, 1) - threadsPerThreadgroup:MTLSizeMake(blockSize, blockSize, 1)]; + [encoder dispatchThreads:MTLSizeMake(N / 4, (M + 3) / 4, 1) + threadsPerThreadgroup:MTLSizeMake(std::min(maxThreadsPerGroup, M), 1, 1)]; + } +}; + +template +struct Int4MM1OpDescriptor : public Int4MMOpDescriptor { + using Int4MMOpDescriptor::Int4MMOpDescriptor; + using Int4MMOpDescriptor::K; + using Int4MMOpDescriptor::N; + void dispatchThreads(id encoder, + unsigned maxThreadsPerGroup) const override { + [encoder dispatchThreads:MTLSizeMake(N / 2, 1, 1) + threadsPerThreadgroup:MTLSizeMake(16, 1, 1)]; + } +}; + +template +struct Int4MM1Vec4OpDescriptor : public Int4MMOpDescriptor { + using Int4MMOpDescriptor::Int4MMOpDescriptor; + using Int4MMOpDescriptor::K; + using Int4MMOpDescriptor::N; + void dispatchThreads(id encoder, + unsigned maxThreadsPerGroup) const override { + [encoder dispatchThreads:MTLSizeMake(N / 4, 1, 1) + threadsPerThreadgroup:MTLSizeMake(8, 1, 1)]; + } +}; + +template +struct Int4MM1BlockOpDescriptor : public Int4MMOpDescriptor { + using Int4MMOpDescriptor::Int4MMOpDescriptor; + using Int4MMOpDescriptor::K; + using Int4MMOpDescriptor::N; + void dispatchThreads(id encoder, + unsigned maxThreadsPerGroup) const override { + [encoder dispatchThreads:MTLSizeMake(N / 2, 4, 1) + threadsPerThreadgroup:MTLSizeMake(16, 4, 1)]; + } +}; + +template +struct Int4MM1Vec4BlockOpDescriptor : public Int4MMOpDescriptor { + using Int4MMOpDescriptor::Int4MMOpDescriptor; + using Int4MMOpDescriptor::K; + using Int4MMOpDescriptor::N; + void dispatchThreads(id encoder, + unsigned maxThreadsPerGroup) const override { + [encoder dispatchThreads:MTLSizeMake(N / 2, 4, 1) + threadsPerThreadgroup:MTLSizeMake(16, 4, 1)]; } }; int main() { unsigned M, N, K; - std::tie(M, N, K) = std::make_tuple(32, 4128, 4096); + //std::tie(M, N, K) = std::make_tuple(32, 4128, 4128); + std::tie(M, N, K) = std::make_tuple(1, 4128, 4128); constexpr unsigned groupSize = 32; @autoreleasepool { id device = getMetalDevice(); std::cout << "Using device " << device.name.UTF8String << std::endl; + std::cout << "Dimensions (M, N, K) = (" << M << ", " << N << ", " << K << ")" << std::endl; Int4MMOpDescriptor naive_int4mm(device, "naive_int4mm", M, N, K); Int4MMOpDescriptor reduce_vec4_int4mm(device, "reduce_vec4_int4mm", M, N, K); @@ -329,8 +375,18 @@ int main() { K); Int4MMMat4xMat4OpDescriptor reduce_mat4xmat4_int4mm(device, "reduce_mat4xmat4_int4mm", M, N, K); + Int4MM1OpDescriptor m1_int4mm(device, "m1_int4mm", M, N, K); + Int4MM1Vec4OpDescriptor m1vec4_int4mm(device, "m1vec4_int4mm", M, N, K); + Int4MM1BlockOpDescriptor m1block_int4mm(device, "m1block_int4mm", M, N, K); + Int4MM1Vec4BlockOpDescriptor m1vec4block_int4mm(device, "m1vec4block_int4mm", M, N, K); + Int4MM1Vec4BlockOpDescriptor m1vec4block2_int4mm(device, "m1vec4block2_int4mm", M, N, K); // Benchmarks + m1vec4block2_int4mm.benchmark(); + m1vec4block_int4mm.benchmark(); + m1block_int4mm.benchmark(); + m1vec4_int4mm.benchmark(); + m1_int4mm.benchmark(); reduce_mat4xmat4_int4mm.benchmark(); reduce_mat4_int4mm.benchmark(); reduce_vec4_int4mm.benchmark(); diff --git a/metal-perf/m1_int4mm.metal b/metal-perf/m1_int4mm.metal new file mode 100644 index 0000000..2349f7b --- /dev/null +++ b/metal-perf/m1_int4mm.metal @@ -0,0 +1,112 @@ +#include +using namespace metal; + +template struct Vec4Type {}; + +template <> struct Vec4Type { + using type = float4; +}; + +template <> struct Vec4Type { + using type = half4; +}; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec4Type { + using type = bfloat4; +}; +#endif + +template struct Vec2Type {}; + +template <> struct Vec2Type { + using type = float2; +}; + +template <> struct Vec2Type { + using type = half2; +}; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec2Type { + using type = bfloat2; +}; +#endif + +// [encoder dispatchThreads:MTLSizeMake(N / 2, 1, 1) +// threadsPerThreadgroup:MTLSizeMake(16, 1, 1)]; + +template +kernel void int4pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 group_index [[threadgroup_position_in_grid]], + uint2 threadgroup_index [[thread_position_in_threadgroup]]) { + + const uint K = sizes.y; + const uint N = sizes.z; + const uint nb = group_index.x; // 0..N/32-1 + const uint n2 = 16 * nb + threadgroup_index.x; // 0..N/2-1 + const uint ldb = min(32U, N - nb * 32); + const uint32_t k_block = (K + groupSize - 1) / groupSize; + + using vecT = typename Vec2Type::type; + + constant T *A_ptr = A; + constant uchar *B_ptr = B + (nb * 16 * K); + + float2 rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + float2 scales, zeros; + for (int i = 0; i < 2; ++i) { + scales[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 0]; + zeros[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 1] - scales[i] * T(8); + } + + for(uint idx = 0; idx < groupSize && k < K; idx++, k++) { + const auto a_val = float(A_ptr[k]); + uchar b_byte0 = B_ptr[(k * ldb + (2*n2 % 32))/2]; + //uchar b_byte1 = B_ptr[(k * ldb + (2*n2 % 16))/2 + 1]; + + float2 b_val = float2( + float(b_byte0 & 0x0f), + float(b_byte0 >> 4)); + + float2 b_vec = scales * b_val + zeros; + + rc += a_val * b_vec; + } + } + reinterpret_cast(outputData)[n2] = vecT(rc); +} + +#define INSTANTIATE_INT4MM(DTYPE, GSIZE) \ +template \ +[[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int4pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 group_index [[threadgroup_position_in_grid]], \ + uint2 threadgroup_index [[thread_position_in_threadgroup]]) + +INSTANTIATE_INT4MM(float, 32); +INSTANTIATE_INT4MM(half, 32); +INSTANTIATE_INT4MM(float, 64); +INSTANTIATE_INT4MM(half, 64); +INSTANTIATE_INT4MM(float, 128); +INSTANTIATE_INT4MM(half, 128); +INSTANTIATE_INT4MM(float, 256); +INSTANTIATE_INT4MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT4MM(bfloat, 32); +INSTANTIATE_INT4MM(bfloat, 64); +INSTANTIATE_INT4MM(bfloat, 128); +INSTANTIATE_INT4MM(bfloat, 256); +#endif diff --git a/metal-perf/m1block_int4mm.metal b/metal-perf/m1block_int4mm.metal new file mode 100644 index 0000000..7b9bf65 --- /dev/null +++ b/metal-perf/m1block_int4mm.metal @@ -0,0 +1,120 @@ +#include +using namespace metal; + +template struct Vec4Type {}; + +template <> struct Vec4Type { + using type = float4; +}; + +template <> struct Vec4Type { + using type = half4; +}; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec4Type { + using type = bfloat4; +}; +#endif + +template struct Vec2Type {}; + +template <> struct Vec2Type { + using type = float2; +}; + +template <> struct Vec2Type { + using type = half2; +}; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec2Type { + using type = bfloat2; +}; +#endif + +// [encoder dispatchThreads:MTLSizeMake(N / 2, 4, 1) +// threadsPerThreadgroup:MTLSizeMake(16, 4, 1)]; + +template +kernel void int4pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 group_index [[threadgroup_position_in_grid]], + uint2 threadgroup_index [[thread_position_in_threadgroup]]) { + + const uint K = sizes.y; + const uint N = sizes.z; + const uint nb = group_index.x; // 0..N/32-1 + const uint n2 = 16 * nb + threadgroup_index.x; // 0..N/2-1 + const uint ldb = min(32U, N - nb * 32); + const uint32_t k_block = (K + groupSize - 1) / groupSize; + + using vecT = typename Vec2Type::type; + + constant T *A_ptr = A; + constant uchar *B_ptr = B + (nb * 16 * K); + + float2 rc = 0.0; + for (uint k = threadgroup_index.y; k < K; k += 4) { + threadgroup_barrier(mem_flags::mem_none); + + const auto a_val = float(A_ptr[k]); + uchar b_byte = B_ptr[(k * ldb + (2*n2 % 32))/2]; + + float2 b_val = float2( + float(b_byte & 0x0f), + float(b_byte >> 4)); + + uint kb = k / groupSize; + + float2 scales, zeros; + for (int i = 0; i < 2; ++i) { + scales[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 0]; + zeros[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 1] - scales[i] * T(8); + } + + float2 b_vec = scales * b_val + zeros; + rc += a_val * b_vec; + } + + threadgroup float2 tgp_memory[16][4]; + tgp_memory[threadgroup_index.x][threadgroup_index.y] = rc; + threadgroup_barrier(mem_flags::mem_threadgroup); + if (threadgroup_index.y == 0) { + for (unsigned i = 1; i < 4; i++) { + rc += tgp_memory[threadgroup_index.x][i]; + } + reinterpret_cast(outputData)[n2] = vecT(rc); + } +} + +#define INSTANTIATE_INT4MM(DTYPE, GSIZE) \ +template \ +[[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int4pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 group_index [[threadgroup_position_in_grid]], \ + uint2 threadgroup_index [[thread_position_in_threadgroup]]) + +INSTANTIATE_INT4MM(float, 32); +INSTANTIATE_INT4MM(half, 32); +INSTANTIATE_INT4MM(float, 64); +INSTANTIATE_INT4MM(half, 64); +INSTANTIATE_INT4MM(float, 128); +INSTANTIATE_INT4MM(half, 128); +INSTANTIATE_INT4MM(float, 256); +INSTANTIATE_INT4MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT4MM(bfloat, 32); +INSTANTIATE_INT4MM(bfloat, 64); +INSTANTIATE_INT4MM(bfloat, 128); +INSTANTIATE_INT4MM(bfloat, 256); +#endif diff --git a/metal-perf/m1vec4_int4mm.metal b/metal-perf/m1vec4_int4mm.metal new file mode 100644 index 0000000..0b96f8e --- /dev/null +++ b/metal-perf/m1vec4_int4mm.metal @@ -0,0 +1,114 @@ +#include +using namespace metal; + +template struct Vec4Type {}; + +template <> struct Vec4Type { + using type = float4; +}; + +template <> struct Vec4Type { + using type = half4; +}; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec4Type { + using type = bfloat4; +}; +#endif + +template struct Vec2Type {}; + +template <> struct Vec2Type { + using type = float2; +}; + +template <> struct Vec2Type { + using type = half2; +}; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec2Type { + using type = bfloat2; +}; +#endif + +// [encoder dispatchThreads:MTLSizeMake(N / 4, 1, 1) +// threadsPerThreadgroup:MTLSizeMake(8, 1, 1)]; + +template +kernel void int4pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 group_index [[threadgroup_position_in_grid]], + uint2 threadgroup_index [[thread_position_in_threadgroup]]) { + + const uint K = sizes.y; + const uint N = sizes.z; + const uint nb = group_index.x; // 0..N/32-1 + const uint n4 = 8 * nb + threadgroup_index.x; // 0..N/4-1 + const uint ldb = min(32U, N - nb * 32); + const uint32_t k_block = (K + groupSize - 1) / groupSize; + + using vecT = typename Vec4Type::type; + + constant T *A_ptr = A; + constant uchar *B_ptr = B + (nb * 16 * K); + + float4 rc = 0.0; + uint k = 0; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + float4 scales, zeros; + for (int i = 0; i < 4; ++i) { + scales[i] = scalesAndZeros[(kb * N + 4*n4 + i) * 2 + 0]; + zeros[i] = scalesAndZeros[(kb * N + 4*n4 + i) * 2 + 1] - scales[i] * T(8); + } + + for(uint idx = 0; idx < groupSize && k < K; idx++, k++) { + const auto a_val = float(A_ptr[k]); + uchar b_byte0 = B_ptr[(k * ldb + (4*n4 % 32))/2]; + uchar b_byte1 = B_ptr[(k * ldb + (4*n4 % 32))/2 + 1]; + + float4 b_val = float4( + float(b_byte0 & 0x0f), + float(b_byte0 >> 4), + float(b_byte1 & 0x0f), + float(b_byte1 >> 4)); + + float4 b_vec = scales * b_val + zeros; + + rc += a_val * b_vec; + } + } + reinterpret_cast(outputData)[n4] = vecT(rc); +} + +#define INSTANTIATE_INT4MM(DTYPE, GSIZE) \ +template \ +[[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int4pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 group_index [[threadgroup_position_in_grid]], \ + uint2 threadgroup_index [[thread_position_in_threadgroup]]) + +INSTANTIATE_INT4MM(float, 32); +INSTANTIATE_INT4MM(half, 32); +INSTANTIATE_INT4MM(float, 64); +INSTANTIATE_INT4MM(half, 64); +INSTANTIATE_INT4MM(float, 128); +INSTANTIATE_INT4MM(half, 128); +INSTANTIATE_INT4MM(float, 256); +INSTANTIATE_INT4MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT4MM(bfloat, 32); +INSTANTIATE_INT4MM(bfloat, 64); +INSTANTIATE_INT4MM(bfloat, 128); +INSTANTIATE_INT4MM(bfloat, 256); +#endif diff --git a/metal-perf/m1vec4block2_int4mm.metal b/metal-perf/m1vec4block2_int4mm.metal new file mode 100644 index 0000000..2796c43 --- /dev/null +++ b/metal-perf/m1vec4block2_int4mm.metal @@ -0,0 +1,128 @@ +#include +using namespace metal; + +template struct Vec4Type {}; + +template <> struct Vec4Type { + using type = float4; +}; + +template <> struct Vec4Type { + using type = half4; +}; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec4Type { + using type = bfloat4; +}; +#endif + +template struct Vec2Type {}; + +template <> struct Vec2Type { + using type = float2; +}; + +template <> struct Vec2Type { + using type = half2; +}; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec2Type { + using type = bfloat2; +}; +#endif + +// [encoder dispatchThreads:MTLSizeMake(N / 2, 4, 1) +// threadsPerThreadgroup:MTLSizeMake(16, 4, 1)]; + +template +kernel void int4pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 group_index [[threadgroup_position_in_grid]], + uint2 threadgroup_index [[thread_position_in_threadgroup]]) { + + const uint K = sizes.y; + const uint N = sizes.z; + const uint nb = group_index.x; // 0..N/32-1 + const uint n2 = 16 * nb + threadgroup_index.x; // 0..N/2-1 + const uint ldb = min(32U, N - nb * 32); + const uint32_t k_block = (K + groupSize - 1) / groupSize; + + using vec2T = typename Vec2Type::type; + using vec4T = typename Vec4Type::type; + + constant vec4T *A_ptr = reinterpret_cast(A); + constant uchar *B_ptr = B + (nb * 16 * K); + + float2 rc = 0.0; + uint k = threadgroup_index.y * 4; + for (uint32_t kb = 0; kb < k_block ; kb ++) { + float2 scales, zeros; + for (int i = 0; i < 2; ++i) { + scales[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 0]; + zeros[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 1] - scales[i] * T(8); + } + + for(uint idx = k % groupSize; idx < groupSize && k < K; idx += 16, k += 16) { + threadgroup_barrier(mem_flags::mem_none); + + const auto a_vec = float4(A_ptr[k/4]); + uchar4 b_byte; + for (int i = 0; i < 4; i++) { + b_byte[i] = B_ptr[((k + i) * ldb + (2*n2 % 32))/2]; + } + + float4x2 b_mat; + + for (int i = 0; i < 4; i++) { + b_mat[i] = scales * float2( + float(b_byte[i] & 0x0f), + float(b_byte[i] >> 4)) + zeros; + } + + rc += b_mat * a_vec; + } + } + + threadgroup float2 tgp_memory[16][4]; + tgp_memory[threadgroup_index.x][threadgroup_index.y] = rc; + threadgroup_barrier(mem_flags::mem_threadgroup); + if (threadgroup_index.y == 0) { + for (unsigned i = 1; i < 4; i++) { + rc += tgp_memory[threadgroup_index.x][i]; + } + reinterpret_cast(outputData)[n2] = vec2T(rc); + } +} + +#define INSTANTIATE_INT4MM(DTYPE, GSIZE) \ +template \ +[[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int4pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 group_index [[threadgroup_position_in_grid]], \ + uint2 threadgroup_index [[thread_position_in_threadgroup]]) + +INSTANTIATE_INT4MM(float, 32); +INSTANTIATE_INT4MM(half, 32); +INSTANTIATE_INT4MM(float, 64); +INSTANTIATE_INT4MM(half, 64); +INSTANTIATE_INT4MM(float, 128); +INSTANTIATE_INT4MM(half, 128); +INSTANTIATE_INT4MM(float, 256); +INSTANTIATE_INT4MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT4MM(bfloat, 32); +INSTANTIATE_INT4MM(bfloat, 64); +INSTANTIATE_INT4MM(bfloat, 128); +INSTANTIATE_INT4MM(bfloat, 256); +#endif diff --git a/metal-perf/m1vec4block_int4mm.metal b/metal-perf/m1vec4block_int4mm.metal new file mode 100644 index 0000000..d98b2d1 --- /dev/null +++ b/metal-perf/m1vec4block_int4mm.metal @@ -0,0 +1,126 @@ +#include +using namespace metal; + +template struct Vec4Type {}; + +template <> struct Vec4Type { + using type = float4; +}; + +template <> struct Vec4Type { + using type = half4; +}; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec4Type { + using type = bfloat4; +}; +#endif + +template struct Vec2Type {}; + +template <> struct Vec2Type { + using type = float2; +}; + +template <> struct Vec2Type { + using type = half2; +}; + +#if __METAL_VERSION__ >= 310 +template <> struct Vec2Type { + using type = bfloat2; +}; +#endif + +// [encoder dispatchThreads:MTLSizeMake(N / 2, 4, 1) +// threadsPerThreadgroup:MTLSizeMake(16, 4, 1)]; + +template +kernel void int4pack_mm( + constant T * A [[buffer(0)]], + constant uchar * B [[buffer(1)]], + constant T * scalesAndZeros [[buffer(2)]], + device T * outputData [[buffer(3)]], + constant uint3 & sizes [[buffer(4)]], // M, K, N + uint2 group_index [[threadgroup_position_in_grid]], + uint2 threadgroup_index [[thread_position_in_threadgroup]]) { + + const uint K = sizes.y; + const uint N = sizes.z; + const uint nb = group_index.x; // 0..N/32-1 + const uint n2 = 16 * nb + threadgroup_index.x; // 0..N/2-1 + const uint ldb = min(32U, N - nb * 32); + const uint32_t k_block = (K + groupSize - 1) / groupSize; + + using vec2T = typename Vec2Type::type; + using vec4T = typename Vec4Type::type; + + constant vec4T *A_ptr = reinterpret_cast(A); + constant uchar *B_ptr = B + (nb * 16 * K); + + float2 rc = 0.0; + for (uint k = threadgroup_index.y * 4; k < K; k += 16) { + threadgroup_barrier(mem_flags::mem_none); + + const auto a_vec = float4(A_ptr[k/4]); + uchar4 b_byte; + for (int i = 0; i < 4; i++) { + b_byte[i] = B_ptr[((k + i) * ldb + (2*n2 % 32))/2]; + } + + uint kb = k / groupSize; + float2 scales, zeros; + for (int i = 0; i < 2; ++i) { + scales[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 0]; + zeros[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 1] - scales[i] * T(8); + } + + float4x2 b_mat; + + for (int i = 0; i < 4; i++) { + b_mat[i] = scales * float2( + float(b_byte[i] & 0x0f), + float(b_byte[i] >> 4)) + zeros; + } + + rc += b_mat * a_vec; + } + + threadgroup float2 tgp_memory[16][4]; + tgp_memory[threadgroup_index.x][threadgroup_index.y] = rc; + threadgroup_barrier(mem_flags::mem_threadgroup); + if (threadgroup_index.y == 0) { + for (unsigned i = 1; i < 4; i++) { + rc += tgp_memory[threadgroup_index.x][i]; + } + reinterpret_cast(outputData)[n2] = vec2T(rc); + } +} + +#define INSTANTIATE_INT4MM(DTYPE, GSIZE) \ +template \ +[[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] \ +kernel void int4pack_mm( \ + constant DTYPE * A [[buffer(0)]], \ + constant uchar * B [[buffer(1)]], \ + constant DTYPE * scalesAndZeros [[buffer(2)]], \ + device DTYPE * outputData [[buffer(3)]], \ + constant uint3 & sizes [[buffer(4)]], \ + uint2 group_index [[threadgroup_position_in_grid]], \ + uint2 threadgroup_index [[thread_position_in_threadgroup]]) + +INSTANTIATE_INT4MM(float, 32); +INSTANTIATE_INT4MM(half, 32); +INSTANTIATE_INT4MM(float, 64); +INSTANTIATE_INT4MM(half, 64); +INSTANTIATE_INT4MM(float, 128); +INSTANTIATE_INT4MM(half, 128); +INSTANTIATE_INT4MM(float, 256); +INSTANTIATE_INT4MM(half, 256); +#if __METAL_VERSION__ >= 310 +INSTANTIATE_INT4MM(bfloat, 32); +INSTANTIATE_INT4MM(bfloat, 64); +INSTANTIATE_INT4MM(bfloat, 128); +INSTANTIATE_INT4MM(bfloat, 256); +#endif diff --git a/metal-perf/reduce_mat4xmat4_int4mm.metal b/metal-perf/reduce_mat4xmat4_int4mm.metal index d8b0203..9253f60 100644 --- a/metal-perf/reduce_mat4xmat4_int4mm.metal +++ b/metal-perf/reduce_mat4xmat4_int4mm.metal @@ -27,6 +27,7 @@ kernel void int4pack_mm( device T * outputData [[buffer(3)]], constant uint3 & sizes [[buffer(4)]], // M, K, N uint2 thread_index [[thread_position_in_grid]]) { + const uint M = sizes.x; const uint K = sizes.y; const uint N = sizes.z; const uint m = thread_index.y; // 0..M/4-1 @@ -37,7 +38,6 @@ kernel void int4pack_mm( using vecT = typename Vec4Type::type; constant vecT *A_ptr = reinterpret_cast(A + m * 4 * K); - //constant uchar2 *B_ptr = reinterpret_cast(B + (nb * 16 * K)); constant uchar *B_ptr = B + (nb * 16 * K); float4x4 rc; @@ -46,26 +46,40 @@ kernel void int4pack_mm( } uint k = 0; for (uint32_t kb = 0; kb < k_block ; kb ++) { - const T scale0 = scalesAndZeros[(kb * N + 4 * n) * 2 + 0]; - const T zero0 = scalesAndZeros[(kb * N + 4 * n) * 2 + 1] - scale0 * T(8); - - const T scale1 = scalesAndZeros[(kb * N + 4 * n + 1) * 2 + 0]; - const T zero1 = scalesAndZeros[(kb * N + 4 * n + 1) * 2 + 1] - scale1 * T(8); - - const T scale2 = scalesAndZeros[(kb * N + 4 * n + 2) * 2 + 0]; - const T zero2 = scalesAndZeros[(kb * N + 4 * n + 2) * 2 + 1] - scale2 * T(8); - - const T scale3 = scalesAndZeros[(kb * N + 4 * n + 3) * 2 + 0]; - const T zero3 = scalesAndZeros[(kb * N + 4 * n + 3) * 2 + 1] - scale3 * T(8); - - const float4 scales = float4(scale0, scale1, scale2, scale3); - const float4 zeros = float4(zero0, zero1, zero2, zero3); + float4 scales, zeros; + for (int i = 0; i < 4; ++i) { + scales[i] = scalesAndZeros[(kb * N + 4 * n + i) * 2 + 0]; + zeros[i] = scalesAndZeros[(kb * N + 4 * n + i) * 2 + 1] - scales[i] * T(8); + } for(uint idx = 0; idx < groupSize && k < K; idx += 4, k += 4) { float4x4 a_mat; + + /* for(int j = 0; j < 4; ++j) { + a_mat[j] = float4(0.0); + } + for(int j = 0; j < 4 & m + j < M; ++j) { + a_mat[j] = float4(A_ptr[k/4 + j * K / 4]); + } + */ + + /* + for(uint j = 0; j < 4; ++j) { + j = min(j, M-m-1); a_mat[j] = float4(A_ptr[k/4 + j * K / 4]); } + */ + + if (M % 4 == 0) { + for(int j = 0; j < 4; ++j) { + a_mat[j] = float4(A_ptr[k/4 + j * K / 4]); + } + } else { + for(int j = 0; j < 4; ++j) { + a_mat[j] = j < M - m ? float4(A_ptr[k/4 + j * K / 4]) : float4(0.0); + } + } float4x4 t_b_mat; for(int j = 0; j < 4; ++j) { @@ -82,10 +96,9 @@ kernel void int4pack_mm( rc += t_b_mat * a_mat; } } - reinterpret_cast(outputData + 4 * m * N)[n] = vecT(rc[0]); - reinterpret_cast(outputData + (4 * m + 1) * N)[n] = vecT(rc[1]); - reinterpret_cast(outputData + (4 * m + 2) * N)[n] = vecT(rc[2]); - reinterpret_cast(outputData + (4 * m + 3) * N)[n] = vecT(rc[3]); + for (int i = 0; i < 4; ++i) { + reinterpret_cast(outputData + (4 * m + i) * N)[n] = vecT(rc[i]); + } } #define INSTANTIATE_INT4MM(DTYPE, GSIZE) \