Commit ae18016

JohannesGaessler authored and ggerganov committed
CUDA: non-contiguous (RMS) norm support (ggml-org#11659)
* CUDA: non-contiguous (RMS) norm support

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 2c6e2fa commit ae18016

File tree

6 files changed: +97 additions, -47 deletions

  ggml/src/ggml-cuda/ggml-cuda.cu
  ggml/src/ggml-cuda/norm.cu
  ggml/src/ggml-metal/ggml-metal.m
  ggml/src/ggml-vulkan/ggml-vulkan.cpp
  src/llama.cpp
  tests/test-backend-ops.cpp

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
@@ -38,6 +38,7 @@
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
 #include "ggml-cuda/gla.cuh"
+#include "ggml.h"

 #include <algorithm>
 #include <array>
@@ -3139,6 +3140,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+            return true;
         case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
             break;
@@ -3181,7 +3183,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
+            return true;
         case GGML_OP_GROUP_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
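The net effect of this hunk is that the CUDA backend now advertises support for NORM and RMS_NORM regardless of the input layout, while RMS_NORM_BACK keeps its contiguity and row-size requirements. A minimal standalone sketch of that support matrix, assuming a simplified helper rather than the real ggml_backend_cuda_device_supports_op and a stand-in warp-size constant:

// Sketch only: simplified stand-in for the real ggml_backend_cuda_device_supports_op.
#include "ggml.h"

#define SKETCH_WARP_SIZE 32 // stand-in for the backend's WARP_SIZE

static bool norm_like_supported(const struct ggml_tensor * op) {
    switch (op->op) {
        case GGML_OP_NORM:
        case GGML_OP_RMS_NORM:
            return true; // the strided kernels handle non-contiguous src0
        case GGML_OP_RMS_NORM_BACK:
            return ggml_is_contiguous(op->src[0]) && op->ne[0] % SKETCH_WARP_SIZE == 0;
        default:
            return false;
    }
}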

ggml/src/ggml-cuda/norm.cu

Lines changed: 57 additions & 32 deletions
@@ -1,12 +1,20 @@
 #include "norm.cuh"
+#include <cstdint>

 template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;

-    x   += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    const int row     = blockIdx.x;
+    const int channel = blockIdx.y;
+    const int sample  = blockIdx.z;
+    const int tid     = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

     float2 mean_var = make_float2(0.0f, 0.0f);

@@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
 }

 template <int block_size>
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void rms_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row     = blockIdx.x;
+    const int channel = blockIdx.y;
+    const int sample  = blockIdx.z;
+    const int tid     = threadIdx.x;

-    x   += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

     float tmp = 0.0f; // partial sum for thread in warp

@@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32(
     }
 }

-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }

@@ -207,13 +225,16 @@ static void group_norm_f32_cuda(
     }
 }

-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void rms_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
     if (ncols < 1024) {
         const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
     }
 }

@@ -229,23 +250,26 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *

 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();

-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;

     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
     GGML_ASSERT(eps >= 0.0f);

-    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }

 void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();

-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -271,23 +293,26 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();

-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;

     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
     GGML_ASSERT(eps >= 0.0f);

-    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
 }

 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
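The key mechanism in the hunks above: the launch grid becomes (nrows, nchannels, nsamples), each block normalizes one row, the source pointer is offset by element strides derived from ggml's byte strides (s01 = nb01/ts0, s02 = nb02/ts0, s03 = nb03/ts0, with nb00 == ts0 asserted so elements within a row stay packed), and the destination is always written densely. A self-contained CUDA sketch of that pattern, assuming one warp per block and a plain shuffle reduction rather than the production kernel:

// Sketch of the strided RMS-norm pattern, not the production kernel: assumes f32 data
// and blockDim.x == 32 (one warp per block). Strides are in elements, computed on the
// host as nb01/ts0, nb02/ts0, nb03/ts0.
#include <cstdint>

__global__ void rms_norm_strided_f32(
        const float * x, float * dst, const int ncols,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
        const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;

    // the source may be a strided view, the destination is packed row-major:
    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((int64_t(sample)*nchannels + channel)*nrows + row)*int64_t(ncols);

    float sumsq = 0.0f;
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        sumsq += x[col]*x[col];
    }
    // reduce across the 32 threads of the warp:
    for (int offset = 16; offset > 0; offset >>= 1) {
        sumsq += __shfl_xor_sync(0xFFFFFFFF, sumsq, offset);
    }

    const float scale = rsqrtf(sumsq/ncols + eps);
    for (int col = threadIdx.x; col < ncols; col += blockDim.x) {
        dst[col] = scale*x[col];
    }
}

Host-side this corresponds to launching with dim3 blocks_num(ne01, ne02, ne03) and dim3 block_dims(32, 1, 1), which is what rms_norm_f32_cuda above does for the ncols < 1024 case (with WARP_SIZE in place of 32).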

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 3 additions & 2 deletions
@@ -1206,10 +1206,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction;
         case GGML_OP_RMS_NORM:
-            return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+            return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
-        case GGML_OP_NORM:
             return true;
+        case GGML_OP_NORM:
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_ROPE:
             {
                 const int mode = ((const int32_t *) op->op_params)[2];

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 2 additions & 0 deletions
@@ -8182,9 +8182,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
+            return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_ACC:
         case GGML_OP_MUL:

src/llama.cpp

Lines changed: 4 additions & 2 deletions
@@ -4628,7 +4628,8 @@ struct llm_build_context {
                         ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
                 cb(k_pe, "k_pe", il);

-                kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
                 kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
@@ -6482,7 +6483,8 @@ struct llm_build_context {
                         ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
                 cb(k_pe, "k_pe", il);

-                kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
+                // TODO: the CUDA backend used to not support non-cont. (RMS) norm, investigate removing ggml_cont
+                kv_compressed = ggml_cont(ctx0, kv_compressed);
                 kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
                         model.layers[il].attn_kv_a_norm, NULL,
                         LLM_NORM_RMS, cb, il);
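The TODO above is the follow-up this commit enables: since the CUDA backend now handles strided (RMS) norm, the ggml_cont copy before llm_build_norm could in principle be dropped. A hypothetical sketch of that simplification, not part of this commit and contingent on the other backends also accepting non-contiguous input:

// Hypothetical follow-up, not part of this commit: normalize the non-contiguous
// view directly instead of materializing it with ggml_cont first.
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
        model.layers[il].attn_kv_a_norm, NULL,
        LLM_NORM_RMS, cb, il);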

tests/test-backend-ops.cpp

Lines changed: 27 additions & 11 deletions
@@ -1674,21 +1674,28 @@ struct test_silu_back : public test_case {
 struct test_norm : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    float eps;
+    const bool v; // whether a is a non-contiguous view
+    const float eps;

     std::string vars() override {
-        return VARS_TO_STR3(type, ne, eps);
+        return VARS_TO_STR4(type, ne, v, eps);
     }

     test_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
+            bool v = false,
             float eps = 1e-6f)
-        : type(type), ne(ne), eps(eps) {}
+        : type(type), ne(ne), v(v), eps(eps) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(a, "a");

+        if (v) {
+            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view of a");
+        }
+
         ggml_tensor * out = ggml_norm(ctx, a, eps);
         ggml_set_name(out, "out");

@@ -1700,22 +1707,29 @@ struct test_norm : public test_case {
 struct test_rms_norm : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    float eps;
+    const bool v; // whether a is a non-contiguous view
+    const float eps;

     std::string vars() override {
-        return VARS_TO_STR3(type, ne, eps);
+        return VARS_TO_STR4(type, ne, v, eps);
     }

     test_rms_norm(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {64, 5, 4, 3},
+            bool v = false,
             float eps = 1e-6f)
-        : type(type), ne(ne), eps(eps) {}
+        : type(type), ne(ne), v(v), eps(eps) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_param(ctx, a);
         ggml_set_name(a, "a");

+        if (v) {
+            a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view of a");
+        }
+
         ggml_tensor * out = ggml_rms_norm(ctx, a, eps);
         ggml_set_name(out, "out");

@@ -1741,7 +1755,7 @@ struct test_rms_norm : public test_case {
 struct test_rms_norm_back : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    float eps;
+    const float eps;

     std::string vars() override {
         return VARS_TO_STR3(type, ne, eps);
@@ -2919,7 +2933,7 @@ struct test_group_norm : public test_case {
     const float eps;

     std::string vars() override {
-        return VARS_TO_STR3(type, ne, num_groups);
+        return VARS_TO_STR4(type, ne, num_groups, eps);
     }

     test_group_norm(ggml_type type = GGML_TYPE_F32,
@@ -3964,9 +3978,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_silu_back());

-    for (float eps : {0.0f, 1e-7f, 1e-4f, 1e-1f}) {
-        test_cases.emplace_back(new test_norm    (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
-        test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
+        for (bool v : {false, true}) {
+            test_cases.emplace_back(new test_norm    (GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+            test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, v, eps));
+        }
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
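The v == true path builds a non-contiguous operand by taking a view that halves every extent while keeping the parent's byte strides nb[1..3], so consecutive rows of the view are separated by the parent's full row stride. A minimal standalone program showing the same construction outside the test harness (assumptions: CPU compute via ggml_graph_compute_with_ctx from ggml-cpu.h, a 16 MiB scratch context, and data filled with 1.0f just to have defined values):

// Standalone sketch, not part of test-backend-ops.cpp: build a non-contiguous view
// and run RMS norm on it with the default (CPU) graph compute.
#include "ggml.h"
#include "ggml-cpu.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 64, 5, 4, 3);

    // fill the (contiguous) parent with defined values
    float * data = (float *) a->data;
    for (int64_t i = 0; i < ggml_nelements(a); i++) {
        data[i] = 1.0f;
    }

    // non-contiguous view: half of every dimension, parent byte strides, offset 0
    struct ggml_tensor * av = ggml_view_4d(ctx, a,
        a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2,
        a->nb[1], a->nb[2], a->nb[3], 0);
    GGML_ASSERT(!ggml_is_contiguous(av));

    struct ggml_tensor * out = ggml_rms_norm(ctx, av, 1e-6f); // eps as in the tests

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    ggml_free(ctx);
    return 0;
}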
