Skip to content

Commit d76a86d

Browse files
authored
opencl: Noncontiguous norm, rms_norm, disable fp16 for some ops (#12217)
* opencl: support noncontiguous `norm` * opencl: support noncontiguous `rms_norm` * opencl: disable fp16 for `ADD`, `MUL`, `SCALE`, `RELU`, `GELU`, `SILU`, `CLAMP`
1 parent 776f9e5 commit d76a86d

File tree

2 files changed

+65
-31
lines changed

2 files changed

+65
-31
lines changed

ggml/src/ggml-opencl/ggml-opencl.cpp

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,17 +1007,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
10071007
case GGML_OP_ADD:
10081008
case GGML_OP_SCALE:
10091009
case GGML_OP_MUL:
1010-
return true;
1010+
return op->src[0]->type == GGML_TYPE_F32;
10111011
case GGML_OP_UNARY:
10121012
switch (ggml_get_unary_op(op)) {
10131013
case GGML_UNARY_OP_GELU:
10141014
case GGML_UNARY_OP_SILU:
10151015
case GGML_UNARY_OP_RELU:
1016-
return ggml_is_contiguous(op->src[0]);
1016+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
10171017
default:
10181018
return false;
10191019
}
10201020
case GGML_OP_CLAMP:
1021+
return op->src[0]->type == GGML_TYPE_F32;
10211022
case GGML_OP_SOFT_MAX:
10221023
case GGML_OP_NORM:
10231024
case GGML_OP_RMS_NORM:
@@ -2573,26 +2574,33 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const
25732574
memcpy(&eps, dst->op_params, sizeof(float));
25742575

25752576
const int ne00 = src0 ? src0->ne[0] : 0;
2576-
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
2577+
const int ne01 = src0 ? src0->ne[1] : 0;
2578+
const int ne02 = src0 ? src0->ne[2] : 0;
2579+
const int ne03 = src0 ? src0->ne[3] : 0;
25772580

2578-
GGML_ASSERT(ggml_is_contiguous_1(src0));
2581+
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
2582+
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
2583+
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
25792584

25802585
const int nth = MIN(64, ne00);
25812586

25822587
cl_kernel kernel = backend_ctx->kernel_norm;
25832588

2584-
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
2585-
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
2586-
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
2587-
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
2588-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
2589-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01));
2590-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps));
2591-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth, NULL));
2592-
2593-
const int64_t nrows = ggml_nrows(src0);
2589+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
2590+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
2591+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
2592+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
2593+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
2594+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
2595+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
2596+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
2597+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
2598+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
2599+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
2600+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
2601+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL));
25942602

2595-
size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
2603+
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
25962604
size_t local_work_size[] = {(size_t)nth, 1, 1};
25972605

25982606
#ifdef GGML_OPENCL_PROFILING
@@ -2630,16 +2638,19 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
26302638
memcpy(&eps, dst->op_params, sizeof(float));
26312639

26322640
const int ne00 = src0 ? src0->ne[0] : 0;
2641+
const int ne01 = src0 ? src0->ne[1] : 0;
2642+
const int ne02 = src0 ? src0->ne[2] : 0;
2643+
const int ne03 = src0 ? src0->ne[3] : 0;
2644+
26332645
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
2646+
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
2647+
const cl_ulong nb03 = src0 ? src0->nb[3] : 0;
26342648

26352649
GGML_ASSERT(ne00 % 4 == 0);
2636-
GGML_ASSERT(ggml_is_contiguous_1(src0));
26372650

26382651
const int nth = MIN(64, ne00);
26392652

2640-
const int64_t nrows = ggml_nrows(src0);
2641-
2642-
size_t global_work_size[] = {(size_t)nrows*nth, 1, 1};
2653+
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
26432654
size_t local_work_size[] = {(size_t)nth, 1, 1};
26442655

26452656
cl_kernel kernel = backend_ctx->kernel_rms_norm;
@@ -2654,15 +2665,20 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
26542665
sizeof(local_work_size), local_work_size,
26552666
sizeof(size_t), &sgs, NULL));
26562667

2657-
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
2658-
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
2659-
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
2660-
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
2661-
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
2662-
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01));
2663-
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps));
2668+
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
2669+
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
2670+
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
2671+
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
2672+
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
2673+
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
2674+
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
2675+
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
2676+
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
2677+
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
2678+
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
2679+
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps));
26642680
// This is local memory - the size depends on subgroup size.
2665-
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(float)*nth/sgs, NULL));
2681+
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs, NULL));
26662682

26672683
#ifdef GGML_OPENCL_PROFILING
26682684
cl_event evt;

ggml/src/ggml-opencl/kernels/ggml-opencl.cl

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -506,14 +506,23 @@ kernel void kernel_norm(
506506
global float * dst,
507507
ulong offsetd,
508508
int ne00,
509+
int ne01,
510+
int ne02,
511+
int ne03,
509512
ulong nb01,
513+
ulong nb02,
514+
ulong nb03,
510515
float eps,
511516
local float * sum
512517
) {
513518
src0 = (global void*)((global char*)src0 + offset0);
514519
dst = (global void*)((global char*)dst + offsetd);
515520

516-
global float * x = (global float *) ((global char *) src0 + get_group_id(0)*nb01);
521+
int i03 = get_group_id(2);
522+
int i02 = get_group_id(1);
523+
int i01 = get_group_id(0);
524+
525+
global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
517526

518527
// MEAN
519528
// parallel sum
@@ -533,7 +542,7 @@ kernel void kernel_norm(
533542

534543
// recenter and VARIANCE
535544
barrier(CLK_LOCAL_MEM_FENCE);
536-
global float * y = dst + get_group_id(0)*ne00;
545+
global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
537546
sum[get_local_id(0)] = 0.0f;
538547
for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
539548
y[i00] = x[i00] - mean;
@@ -566,14 +575,23 @@ kernel void kernel_rms_norm(
566575
global float * dst,
567576
ulong offsetd,
568577
int ne00,
578+
int ne01,
579+
int ne02,
580+
int ne03,
569581
ulong nb01,
582+
ulong nb02,
583+
ulong nb03,
570584
float eps,
571585
local float * sum // Note, the size depends on number of subgroups
572586
) {
573587
src0 = (global void*)((global char*)src0 + offset0);
574588
dst = (global float*)((global char*)dst + offsetd);
575589

576-
global float4 * x = (global float4 *) ((global char *) src0 + get_group_id(0)*nb01);
590+
int i03 = get_group_id(2);
591+
int i02 = get_group_id(1);
592+
int i01 = get_group_id(0);
593+
594+
global float4 * x = (global float4 *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01);
577595
global float * x_scalar = (global float *) x;
578596
float4 sumf = 0;
579597
float all_sum = 0;
@@ -607,7 +625,7 @@ kernel void kernel_rms_norm(
607625
const float mean = sum[0];
608626
const float scale = 1.0f/sqrt(mean + eps);
609627

610-
global float4 * y = (global float4 *) (dst + get_group_id(0)*ne00);
628+
global float4 * y = (global float4 *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
611629
global float * y_scalar = (global float *) y;
612630
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
613631
y[i00] = x[i00] * scale;

0 commit comments

Comments
 (0)