Skip to content

Commit aee082f

Browse files
committed
metal : GGML_OP_RMS_NORM
1 parent 967727a commit aee082f

File tree

3 files changed

+51
-41
lines changed

3 files changed

+51
-41
lines changed

ggml/src/ggml-common.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,13 @@ typedef struct {
642642
int32_t ne1;
643643
uint64_t nb1;
644644
} ggml_metal_kargs_mul_mv_id;
645+
646+
typedef struct {
647+
int32_t ne00;
648+
int32_t ne00_4;
649+
uint64_t nb01;
650+
float eps;
651+
} ggml_metal_kargs_rms_norm;
645652
#endif
646653

647654
#endif // GGML_COMMON_DECL

ggml/src/ggml-metal.m

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2618,20 +2618,28 @@ static void ggml_metal_encode_node(
26182618
float eps;
26192619
memcpy(&eps, dst->op_params, sizeof(float));
26202620

2621+
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
2622+
26212623
int nth = 32; // SIMD width
26222624

2623-
while (nth < ne00/4 && nth < 1024) {
2625+
while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
26242626
nth *= 2;
26252627
}
26262628

2627-
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
2629+
nth = MIN(nth, ne00/4);
2630+
2631+
ggml_metal_kargs_rms_norm args = {
2632+
/*.ne00 =*/ ne00,
2633+
/*.ne00_4 =*/ ne00/4,
2634+
/*.nb01 =*/ nb01,
2635+
/*.eps =*/ eps,
2636+
};
26282637

26292638
[encoder setComputePipelineState:pipeline];
2630-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2631-
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
2632-
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
2633-
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
2634-
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
2639+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
2640+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
2641+
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2642+
26352643
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
26362644

26372645
const int64_t nrows = ggml_nrows(src0);

ggml/src/ggml-metal.metal

Lines changed: 29 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,50 +1288,45 @@ kernel void kernel_norm(
12881288
}
12891289

12901290
kernel void kernel_rms_norm(
1291-
device const void * src0,
1292-
device float * dst,
1293-
constant int64_t & ne00,
1294-
constant uint64_t & nb01,
1295-
constant float & eps,
1296-
threadgroup float * buf [[threadgroup(0)]],
1297-
uint tgpig[[threadgroup_position_in_grid]],
1298-
uint tpitg[[thread_position_in_threadgroup]],
1299-
uint sgitg[[simdgroup_index_in_threadgroup]],
1300-
uint tiisg[[thread_index_in_simdgroup]],
1301-
uint ntg[[threads_per_threadgroup]]) {
1302-
device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
1291+
constant ggml_metal_kargs_rms_norm & args,
1292+
device const char * src0,
1293+
device char * dst,
1294+
threadgroup float * shmem_f32 [[threadgroup(0)]],
1295+
uint tgpig[[threadgroup_position_in_grid]],
1296+
ushort tpitg[[thread_position_in_threadgroup]],
1297+
ushort sgitg[[simdgroup_index_in_threadgroup]],
1298+
ushort tiisg[[thread_index_in_simdgroup]],
1299+
ushort ntg[[threads_per_threadgroup]]) {
1300+
if (sgitg == 0) {
1301+
shmem_f32[tiisg] = 0.0f;
1302+
}
13031303

1304-
float4 sumf = 0;
1305-
float all_sum = 0;
1304+
device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
1305+
1306+
float sumf = 0.0f;
13061307

13071308
// parallel sum
1308-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1309-
sumf += x[i00] * x[i00];
1309+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
1310+
sumf += dot(x[i00], x[i00]);
13101311
}
1311-
all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
1312-
all_sum = simd_sum(all_sum);
1313-
if (ntg > N_SIMDWIDTH) {
1314-
if (sgitg == 0) {
1315-
buf[tiisg] = 0.0f;
1316-
}
1312+
sumf = simd_sum(sumf);
13171313

1318-
threadgroup_barrier(mem_flags::mem_threadgroup);
1314+
threadgroup_barrier(mem_flags::mem_threadgroup);
13191315

1320-
if (tiisg == 0) {
1321-
buf[sgitg] = all_sum;
1322-
}
1316+
if (tiisg == 0) {
1317+
shmem_f32[sgitg] = sumf;
1318+
}
13231319

1324-
threadgroup_barrier(mem_flags::mem_threadgroup);
1320+
threadgroup_barrier(mem_flags::mem_threadgroup);
13251321

1326-
all_sum = buf[tiisg];
1327-
all_sum = simd_sum(all_sum);
1328-
}
1322+
sumf = shmem_f32[tiisg];
1323+
sumf = simd_sum(sumf);
13291324

1330-
const float mean = all_sum/ne00;
1331-
const float scale = 1.0f/sqrt(mean + eps);
1325+
const float mean = sumf/args.ne00;
1326+
const float scale = 1.0f/sqrt(mean + args.eps);
13321327

1333-
device float4 * y = (device float4 *) (dst + tgpig*ne00);
1334-
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
1328+
device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
1329+
for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
13351330
y[i00] = x[i00] * scale;
13361331
}
13371332
}

0 commit comments

Comments
 (0)