Skip to content

Commit 7941b6b

Browse files
committed
metal : GGML_OP_NORM
1 parent aee082f commit 7941b6b

File tree

3 files changed

+79
-44
lines changed

3 files changed

+79
-44
lines changed

ggml/src/ggml-common.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,13 @@ typedef struct {
643643
uint64_t nb1;
644644
} ggml_metal_kargs_mul_mv_id;
645645

646+
typedef struct {
647+
int32_t ne00;
648+
int32_t ne00_4;
649+
uint64_t nb01;
650+
float eps;
651+
} ggml_metal_kargs_norm;
652+
646653
typedef struct {
647654
int32_t ne00;
648655
int32_t ne00_4;

ggml/src/ggml-metal.m

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2681,22 +2681,35 @@ static void ggml_metal_encode_node(
26812681
} break;
26822682
case GGML_OP_NORM:
26832683
{
2684+
GGML_ASSERT(ne00 % 4 == 0);
26842685
GGML_ASSERT(ggml_is_contiguous_1(src0));
26852686

26862687
float eps;
26872688
memcpy(&eps, dst->op_params, sizeof(float));
26882689

2689-
const int nth = MIN(256, ne00);
2690-
26912690
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
26922691

2692+
int nth = 32; // SIMD width
2693+
2694+
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
2695+
nth *= 2;
2696+
}
2697+
2698+
nth = MIN(nth, ne00/4);
2699+
2700+
ggml_metal_kargs_norm args = {
2701+
/*.ne00 =*/ ne00,
2702+
/*.ne00_4 =*/ ne00/4,
2703+
/*.nb01 =*/ nb01,
2704+
/*.eps =*/ eps,
2705+
};
2706+
26932707
[encoder setComputePipelineState:pipeline];
2694-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2695-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2696-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2697-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2698-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
2699-
[encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
2708+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
2709+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2710+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2711+
2712+
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
27002713

27012714
const int64_t nrows = ggml_nrows(src0);
27022715

ggml/src/ggml-metal.metal

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1236,53 +1236,68 @@ kernel void kernel_ssm_scan_f32(
12361236
}
12371237

12381238
kernel void kernel_norm(
1239-
device const void * src0,
1240-
device float * dst,
1241-
constant int64_t & ne00,
1242-
constant uint64_t & nb01,
1243-
constant float & eps,
1244-
threadgroup float * sum [[threadgroup(0)]],
1245-
uint tgpig[[threadgroup_position_in_grid]],
1246-
uint tpitg[[thread_position_in_threadgroup]],
1247-
uint ntg[[threads_per_threadgroup]]) {
1248-
device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
1249-
// MEAN
1250-
// parallel sum
1251-
sum[tpitg] = 0.0f;
1252-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1253-
sum[tpitg] += x[i00];
1239+
constant ggml_metal_kargs_norm & args,
1240+
device const char * src0,
1241+
device char * dst,
1242+
threadgroup float * shmem_f32 [[threadgroup(0)]],
1243+
uint tgpig[[threadgroup_position_in_grid]],
1244+
ushort tpitg[[thread_position_in_threadgroup]],
1245+
ushort sgitg[[simdgroup_index_in_threadgroup]],
1246+
ushort tiisg[[thread_index_in_simdgroup]],
1247+
ushort ntg[[threads_per_threadgroup]]) {
1248+
if (sgitg == 0) {
1249+
shmem_f32[tiisg] = 0.0f;
12541250
}
1255-
// reduce
1251+
1252+
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
1253+
1254+
float4 sumf4(0.0f);
1255+
1256+
float sumf = 0.0f;
1257+
1258+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
1259+
sumf4 += x[i00];
1260+
}
1261+
sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
1262+
sumf = simd_sum(sumf);
1263+
12561264
threadgroup_barrier(mem_flags::mem_threadgroup);
1257-
for (uint i = ntg/2; i > 0; i /= 2) {
1258-
if (tpitg < i) {
1259-
sum[tpitg] += sum[tpitg + i];
1260-
}
1261-
threadgroup_barrier(mem_flags::mem_threadgroup);
1265+
1266+
if (tiisg == 0) {
1267+
shmem_f32[sgitg] = sumf;
12621268
}
1263-
const float mean = sum[0] / ne00;
12641269

1265-
// recenter and VARIANCE
12661270
threadgroup_barrier(mem_flags::mem_threadgroup);
1267-
device float * y = dst + tgpig*ne00;
1268-
sum[tpitg] = 0.0f;
1269-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1271+
1272+
sumf = shmem_f32[tiisg];
1273+
sumf = simd_sum(sumf);
1274+
1275+
const float mean = sumf/args.ne00;
1276+
1277+
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
1278+
1279+
sumf = 0.0f;
1280+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
12701281
y[i00] = x[i00] - mean;
1271-
sum[tpitg] += y[i00] * y[i00];
1282+
sumf += dot(y[i00], y[i00]);
12721283
}
1284+
sumf = simd_sum(sumf);
12731285

1274-
// reduce
12751286
threadgroup_barrier(mem_flags::mem_threadgroup);
1276-
for (uint i = ntg/2; i > 0; i /= 2) {
1277-
if (tpitg < i) {
1278-
sum[tpitg] += sum[tpitg + i];
1279-
}
1280-
threadgroup_barrier(mem_flags::mem_threadgroup);
1287+
1288+
if (tiisg == 0) {
1289+
shmem_f32[sgitg] = sumf;
12811290
}
1282-
const float variance = sum[0] / ne00;
12831291

1284-
const float scale = 1.0f/sqrt(variance + eps);
1285-
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
1292+
threadgroup_barrier(mem_flags::mem_threadgroup);
1293+
1294+
sumf = shmem_f32[tiisg];
1295+
sumf = simd_sum(sumf);
1296+
1297+
const float variance = sumf/args.ne00;
1298+
1299+
const float scale = 1.0f/sqrt(variance + args.eps);
1300+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
12861301
y[i00] = y[i00] * scale;
12871302
}
12881303
}

0 commit comments

Comments
 (0)