70 changes: 63 additions & 7 deletions metal-perf/int4mm.mm
@@ -306,31 +306,87 @@ void dispatchThreads(id<MTLComputeCommandEncoder> encoder,
using Int4MMOpDescriptor<groupSize>::N;
void dispatchThreads(id<MTLComputeCommandEncoder> encoder,
unsigned maxThreadsPerGroup) const override {
constexpr auto blockSize = 8;
if (maxThreadsPerGroup < blockSize * blockSize) {
throw std::runtime_error("Can't dispatch!");
}
[encoder dispatchThreads:MTLSizeMake(N/4, M/4, 1)
threadsPerThreadgroup:MTLSizeMake(blockSize, blockSize, 1)];
[encoder dispatchThreads:MTLSizeMake(N / 4, (M + 3) / 4, 1)
threadsPerThreadgroup:MTLSizeMake(std::min(maxThreadsPerGroup, M), 1, 1)];
}
};
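// The reworked dispatch above (the second dispatchThreads call, replacing the
// fixed 8x8 threadgroup) rounds the M grid dimension up via (M + 3) / 4, so M
// values that are not multiples of 4 (including M == 1) are still covered, and
// sizes the threadgroup as a 1-D strip of min(maxThreadsPerGroup, M) threads.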

template <unsigned groupSize>
struct Int4MM1OpDescriptor : public Int4MMOpDescriptor<groupSize> {
using Int4MMOpDescriptor<groupSize>::Int4MMOpDescriptor;
using Int4MMOpDescriptor<groupSize>::K;
using Int4MMOpDescriptor<groupSize>::N;
void dispatchThreads(id<MTLComputeCommandEncoder> encoder,
unsigned maxThreadsPerGroup) const override {
[encoder dispatchThreads:MTLSizeMake(N / 2, 1, 1)
threadsPerThreadgroup:MTLSizeMake(16, 1, 1)];
}
};

template <unsigned groupSize>
struct Int4MM1Vec4OpDescriptor : public Int4MMOpDescriptor<groupSize> {
using Int4MMOpDescriptor<groupSize>::Int4MMOpDescriptor;
using Int4MMOpDescriptor<groupSize>::K;
using Int4MMOpDescriptor<groupSize>::N;
void dispatchThreads(id<MTLComputeCommandEncoder> encoder,
unsigned maxThreadsPerGroup) const override {
[encoder dispatchThreads:MTLSizeMake(N / 4, 1, 1)
threadsPerThreadgroup:MTLSizeMake(8, 1, 1)];
}
};

template <unsigned groupSize>
struct Int4MM1BlockOpDescriptor : public Int4MMOpDescriptor<groupSize> {
using Int4MMOpDescriptor<groupSize>::Int4MMOpDescriptor;
using Int4MMOpDescriptor<groupSize>::K;
using Int4MMOpDescriptor<groupSize>::N;
void dispatchThreads(id<MTLComputeCommandEncoder> encoder,
unsigned maxThreadsPerGroup) const override {
[encoder dispatchThreads:MTLSizeMake(N / 2, 4, 1)
threadsPerThreadgroup:MTLSizeMake(16, 4, 1)];
}
};

template <unsigned groupSize>
struct Int4MM1Vec4BlockOpDescriptor : public Int4MMOpDescriptor<groupSize> {
using Int4MMOpDescriptor<groupSize>::Int4MMOpDescriptor;
using Int4MMOpDescriptor<groupSize>::K;
using Int4MMOpDescriptor<groupSize>::N;
void dispatchThreads(id<MTLComputeCommandEncoder> encoder,
unsigned maxThreadsPerGroup) const override {
[encoder dispatchThreads:MTLSizeMake(N / 2, 4, 1)
threadsPerThreadgroup:MTLSizeMake(16, 4, 1)];
}
};
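// The four Int4MM1* descriptors above inherit the buffer and pipeline setup of
// Int4MMOpDescriptor and differ only in dispatch geometry; they drive the
// m1*_int4mm kernels for the M == 1 (GEMV) case benchmarked below, where each
// thread (or small thread block along y) accumulates two or four adjacent
// output columns.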

int main() {
unsigned M, N, K;
std::tie(M, N, K) = std::make_tuple(32, 4128, 4096);
//std::tie(M, N, K) = std::make_tuple(32, 4128, 4128);
std::tie(M, N, K) = std::make_tuple(1, 4128, 4128);
constexpr unsigned groupSize = 32;
@autoreleasepool {
id<MTLDevice> device = getMetalDevice();
std::cout << "Using device " << device.name.UTF8String << std::endl;
std::cout << "Dimensions (M, N, K) = (" << M << ", " << N << ", " << K << ")" << std::endl;
Int4MMOpDescriptor<groupSize> naive_int4mm(device, "naive_int4mm", M, N, K);
Int4MMOpDescriptor<groupSize> reduce_vec4_int4mm(device, "reduce_vec4_int4mm", M, N,
K);
Int4MMMat4OpDescriptor<groupSize> reduce_mat4_int4mm(device, "reduce_mat4_int4mm", M, N,
K);
Int4MMMat4xMat4OpDescriptor<groupSize> reduce_mat4xmat4_int4mm(device, "reduce_mat4xmat4_int4mm", M, N,
K);
Int4MM1OpDescriptor<groupSize> m1_int4mm(device, "m1_int4mm", M, N, K);
Int4MM1Vec4OpDescriptor<groupSize> m1vec4_int4mm(device, "m1vec4_int4mm", M, N, K);
Int4MM1BlockOpDescriptor<groupSize> m1block_int4mm(device, "m1block_int4mm", M, N, K);
Int4MM1Vec4BlockOpDescriptor<groupSize> m1vec4block_int4mm(device, "m1vec4block_int4mm", M, N, K);
Int4MM1Vec4BlockOpDescriptor<groupSize> m1vec4block2_int4mm(device, "m1vec4block2_int4mm", M, N, K);

// Benchmarks
m1vec4block2_int4mm.benchmark<BFloat16>();
m1vec4block_int4mm.benchmark<BFloat16>();
m1block_int4mm.benchmark<BFloat16>();
m1vec4_int4mm.benchmark<BFloat16>();
m1_int4mm.benchmark<BFloat16>();
reduce_mat4xmat4_int4mm.benchmark<BFloat16>();
reduce_mat4_int4mm.benchmark<BFloat16>();
reduce_vec4_int4mm.benchmark<BFloat16>();
112 changes: 112 additions & 0 deletions metal-perf/m1_int4mm.metal
@@ -0,0 +1,112 @@
#include <metal_stdlib>
using namespace metal;

template <typename T> struct Vec4Type {};

template <> struct Vec4Type<float> {
using type = float4;
};

template <> struct Vec4Type<half> {
using type = half4;
};

#if __METAL_VERSION__ >= 310
template <> struct Vec4Type<bfloat> {
using type = bfloat4;
};
#endif

template <typename T> struct Vec2Type {};

template <> struct Vec2Type<float> {
using type = float2;
};

template <> struct Vec2Type<half> {
using type = half2;
};

#if __METAL_VERSION__ >= 310
template <> struct Vec2Type<bfloat> {
using type = bfloat2;
};
#endif

// [encoder dispatchThreads:MTLSizeMake(N / 2, 1, 1)
// threadsPerThreadgroup:MTLSizeMake(16, 1, 1)];
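//
// Thread n2 = 16 * nb + tid.x owns one adjacent pair of output columns and
// writes elements 2*n2 and 2*n2+1 of the single output row (M == 1). B is
// packed two int4 values per byte in 32-column blocks of K * 16 bytes each;
// scalesAndZeros is laid out as [K / groupSize][N][2] holding (scale, zero
// point) per quantization group. Subtracting 8 * scale from the stored zero
// point shifts the unsigned nibbles 0..15 into the signed range -8..7 before
// the multiply-add below.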

template<typename T, unsigned groupSize>
kernel void int4pack_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
uint2 group_index [[threadgroup_position_in_grid]],
uint2 threadgroup_index [[thread_position_in_threadgroup]]) {

const uint K = sizes.y;
const uint N = sizes.z;
const uint nb = group_index.x; // 0..N/32-1
const uint n2 = 16 * nb + threadgroup_index.x; // 0..N/2-1
const uint ldb = min(32U, N - nb * 32);
const uint32_t k_block = (K + groupSize - 1) / groupSize;

using vecT = typename Vec2Type<T>::type;

constant T *A_ptr = A;
constant uchar *B_ptr = B + (nb * 16 * K);

float2 rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
float2 scales, zeros;
for (int i = 0; i < 2; ++i) {
scales[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 0];
zeros[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 1] - scales[i] * T(8);
}

for(uint idx = 0; idx < groupSize && k < K; idx++, k++) {
const auto a_val = float(A_ptr[k]);
uchar b_byte0 = B_ptr[(k * ldb + (2*n2 % 32))/2];
//uchar b_byte1 = B_ptr[(k * ldb + (2*n2 % 16))/2 + 1];

float2 b_val = float2(
float(b_byte0 & 0x0f),
float(b_byte0 >> 4));

float2 b_vec = scales * b_val + zeros;

rc += a_val * b_vec;
}
}
reinterpret_cast<device vecT*>(outputData)[n2] = vecT(rc);
}

#define INSTANTIATE_INT4MM(DTYPE, GSIZE) \
template \
[[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int4pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint2 group_index [[threadgroup_position_in_grid]], \
uint2 threadgroup_index [[thread_position_in_threadgroup]])

INSTANTIATE_INT4MM(float, 32);
INSTANTIATE_INT4MM(half, 32);
INSTANTIATE_INT4MM(float, 64);
INSTANTIATE_INT4MM(half, 64);
INSTANTIATE_INT4MM(float, 128);
INSTANTIATE_INT4MM(half, 128);
INSTANTIATE_INT4MM(float, 256);
INSTANTIATE_INT4MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT4MM(bfloat, 32);
INSTANTIATE_INT4MM(bfloat, 64);
INSTANTIATE_INT4MM(bfloat, 128);
INSTANTIATE_INT4MM(bfloat, 256);
#endif
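For context, here is a minimal host-side sketch of how one of these specializations could be fetched by its host_name and dispatched. The library, queue, and buffer variable names (lib, queue, bufA, bufB, bufSZ, bufOut) are assumptions for illustration, not part of this PR; the descriptor classes in int4mm.mm wrap the same sequence behind their dispatchThreads overrides.

// Hypothetical host-side dispatch of the groupSize = 32 bfloat variant,
// assuming `lib` is the MTLLibrary compiled from m1_int4mm.metal; the Metal
// calls are standard API, the surrounding variable names are assumed.
id<MTLFunction> fn = [lib newFunctionWithName:@"int4pack_mm_32_bfloat"];
NSError *error = nil;
id<MTLComputePipelineState> pso =
    [device newComputePipelineStateWithFunction:fn error:&error];
id<MTLCommandBuffer> cmdBuf = [queue commandBuffer];
id<MTLComputeCommandEncoder> enc = [cmdBuf computeCommandEncoder];
[enc setComputePipelineState:pso];
[enc setBuffer:bufA offset:0 atIndex:0];    // activations, M x K
[enc setBuffer:bufB offset:0 atIndex:1];    // packed int4 weights
[enc setBuffer:bufSZ offset:0 atIndex:2];   // scalesAndZeros
[enc setBuffer:bufOut offset:0 atIndex:3];  // output, M x N
uint32_t sizes[4] = {M, K, N, 0};           // MSL uint3 occupies 16 bytes
[enc setBytes:sizes length:sizeof(sizes) atIndex:4];
// Same geometry as Int4MM1OpDescriptor::dispatchThreads.
[enc dispatchThreads:MTLSizeMake(N / 2, 1, 1)
    threadsPerThreadgroup:MTLSizeMake(16, 1, 1)];
[enc endEncoding];
[cmdBuf commit];
[cmdBuf waitUntilCompleted];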
120 changes: 120 additions & 0 deletions metal-perf/m1block_int4mm.metal
@@ -0,0 +1,120 @@
#include <metal_stdlib>
using namespace metal;

template <typename T> struct Vec4Type {};

template <> struct Vec4Type<float> {
using type = float4;
};

template <> struct Vec4Type<half> {
using type = half4;
};

#if __METAL_VERSION__ >= 310
template <> struct Vec4Type<bfloat> {
using type = bfloat4;
};
#endif

template <typename T> struct Vec2Type {};

template <> struct Vec2Type<float> {
using type = float2;
};

template <> struct Vec2Type<half> {
using type = half2;
};

#if __METAL_VERSION__ >= 310
template <> struct Vec2Type<bfloat> {
using type = bfloat2;
};
#endif

// [encoder dispatchThreads:MTLSizeMake(N / 2, 4, 1)
// threadsPerThreadgroup:MTLSizeMake(16, 4, 1)];
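//
// Same column-pair mapping as m1_int4mm.metal, but the K loop is strided
// across the four threads along threadgroup_index.y (thread y handles
// k = y, y + 4, y + 8, ...). Each thread keeps a partial float2 sum, the
// partials are staged in threadgroup memory, and the y == 0 thread reduces
// them and writes the final pair of output columns.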

template<typename T, unsigned groupSize>
kernel void int4pack_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
uint2 group_index [[threadgroup_position_in_grid]],
uint2 threadgroup_index [[thread_position_in_threadgroup]]) {

const uint K = sizes.y;
const uint N = sizes.z;
const uint nb = group_index.x; // 0..N/32-1
const uint n2 = 16 * nb + threadgroup_index.x; // 0..N/2-1
const uint ldb = min(32U, N - nb * 32);
const uint32_t k_block = (K + groupSize - 1) / groupSize;

using vecT = typename Vec2Type<T>::type;

constant T *A_ptr = A;
constant uchar *B_ptr = B + (nb * 16 * K);

float2 rc = 0.0;
for (uint k = threadgroup_index.y; k < K; k += 4) {
threadgroup_barrier(mem_flags::mem_none);

const auto a_val = float(A_ptr[k]);
uchar b_byte = B_ptr[(k * ldb + (2*n2 % 32))/2];

float2 b_val = float2(
float(b_byte & 0x0f),
float(b_byte >> 4));

uint kb = k / groupSize;

float2 scales, zeros;
for (int i = 0; i < 2; ++i) {
scales[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 0];
zeros[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 1] - scales[i] * T(8);
}

float2 b_vec = scales * b_val + zeros;
rc += a_val * b_vec;
}

threadgroup float2 tgp_memory[16][4];
tgp_memory[threadgroup_index.x][threadgroup_index.y] = rc;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (threadgroup_index.y == 0) {
for (unsigned i = 1; i < 4; i++) {
rc += tgp_memory[threadgroup_index.x][i];
}
reinterpret_cast<device vecT*>(outputData)[n2] = vecT(rc);
}
}

#define INSTANTIATE_INT4MM(DTYPE, GSIZE) \
template \
[[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int4pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint2 group_index [[threadgroup_position_in_grid]], \
uint2 threadgroup_index [[thread_position_in_threadgroup]])

INSTANTIATE_INT4MM(float, 32);
INSTANTIATE_INT4MM(half, 32);
INSTANTIATE_INT4MM(float, 64);
INSTANTIATE_INT4MM(half, 64);
INSTANTIATE_INT4MM(float, 128);
INSTANTIATE_INT4MM(half, 128);
INSTANTIATE_INT4MM(float, 256);
INSTANTIATE_INT4MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT4MM(bfloat, 32);
INSTANTIATE_INT4MM(bfloat, 64);
INSTANTIATE_INT4MM(bfloat, 128);
INSTANTIATE_INT4MM(bfloat, 256);
#endif