#include "norm.cuh"
+#include <cstdint>

template <int block_size>
-static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;

-    x   += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    const int row     = blockIdx.x;
+    const int channel = blockIdx.y;
+    const int sample  = blockIdx.z;
+    const int tid     = threadIdx.x;
+
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    float2 mean_var = make_float2(0.0f, 0.0f);

@@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
    }

template <int block_size>
-static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
+static __global__ void rms_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row     = blockIdx.x;
+    const int channel = blockIdx.y;
+    const int sample  = blockIdx.z;
+    const int tid     = threadIdx.x;

-    x   += int64_t(row)*ncols;
-    dst += int64_t(row)*ncols;
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    float tmp = 0.0f; // partial sum for thread in warp

@@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32(
    }
}

-static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
-        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
-        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}

@@ -207,13 +225,16 @@ static void group_norm_f32_cuda(
    }
}

-static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+static void rms_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
-        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}

@@ -229,23 +250,26 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *

void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

-    GGML_ASSERT(ggml_is_contiguous(src0));
-
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

-    norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

-    GGML_ASSERT(ggml_is_contiguous(src0));
-
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

@@ -271,23 +293,26 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)

void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

-    GGML_ASSERT(ggml_is_contiguous(src0));
-
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

-    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+
+    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
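
Not part of the patch above, just a minimal standalone sketch of the indexing math it introduces: the host wrappers divide the ggml byte strides nb01/nb02/nb03 by the element size to obtain element strides s01/s02/s03, the kernels use those to find the start of a (row, channel, sample) slice of x, and dst is always written contiguously from the logical sizes. The tensor shape and row padding below are invented for illustration.

// Standalone C++ sketch (not from the patch) of the stride conversion in
// ggml_cuda_op_norm/ggml_cuda_op_rms_norm and the offsets computed in
// norm_f32/rms_norm_f32. Shape and padding here are made up.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ts    = sizeof(float);   // ggml_type_size(GGML_TYPE_F32)
    const int64_t ne[4] = {8, 4, 3, 2};    // ne00..ne03: cols, rows, channels, samples
    // Byte strides of a non-contiguous view: consecutive rows are 16 elements apart.
    const int64_t nb[4] = {ts, 16*ts, 16*4*ts, 16*4*3*ts};

    // Host side: byte strides -> element strides (nb00 must equal the element size).
    const int64_t s01 = nb[1] / ts;
    const int64_t s02 = nb[2] / ts;
    const int64_t s03 = nb[3] / ts;

    // Device side: offset of the first element of one row of x ...
    const int64_t row = 1, channel = 2, sample = 1;
    const int64_t x_off = sample*s03 + channel*s02 + row*s01;
    // ... while dst is contiguous, so its offset uses the logical sizes instead.
    const int64_t dst_off = ((sample*ne[2] + channel)*ne[1] + row)*ne[0];

    printf("s01=%lld s02=%lld s03=%lld x_off=%lld dst_off=%lld\n",
           (long long) s01, (long long) s02, (long long) s03,
           (long long) x_off, (long long) dst_off);
    return 0;
}

With element strides in hand, the launch grid can simply be dim3(ne01, ne02, ne03), so each block normalizes exactly one row and no contiguity assertion is needed.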