Skip to content

Commit 202084d

Browse files
JohannesGaessler authored and ggerganov committed
tests: add gradient tests for all backends (ggml/932)
* tests: add gradient checking to test-backend-ops
* remove old comment
* reorder includes
* adjust SIN/COS parameters
* add documentation, use supports_op if possible
1 parent dbbebca commit 202084d

File tree

10 files changed

+1080
-92
lines changed

10 files changed

+1080
-92
lines changed

ggml/include/ggml.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,7 +1272,7 @@ extern "C" {
12721272
size_t nb1,
12731273
size_t nb2,
12741274
size_t nb3,
1275-
size_t offset);
1275+
size_t offset); // in bytes
12761276

12771277
// b -> view(a,offset,nb1,nb2,3), return view(a)
12781278
GGML_API struct ggml_tensor * ggml_set_inplace(
@@ -1282,35 +1282,35 @@ extern "C" {
12821282
size_t nb1,
12831283
size_t nb2,
12841284
size_t nb3,
1285-
size_t offset);
1285+
size_t offset); // in bytes
12861286

12871287
GGML_API struct ggml_tensor * ggml_set_1d(
12881288
struct ggml_context * ctx,
12891289
struct ggml_tensor * a,
12901290
struct ggml_tensor * b,
1291-
size_t offset);
1291+
size_t offset); // in bytes
12921292

12931293
GGML_API struct ggml_tensor * ggml_set_1d_inplace(
12941294
struct ggml_context * ctx,
12951295
struct ggml_tensor * a,
12961296
struct ggml_tensor * b,
1297-
size_t offset);
1297+
size_t offset); // in bytes
12981298

12991299
// b -> view(a,offset,nb1,nb2,3), return modified a
13001300
GGML_API struct ggml_tensor * ggml_set_2d(
13011301
struct ggml_context * ctx,
13021302
struct ggml_tensor * a,
13031303
struct ggml_tensor * b,
13041304
size_t nb1,
1305-
size_t offset);
1305+
size_t offset); // in bytes
13061306

13071307
// b -> view(a,offset,nb1,nb2,3), return view(a)
13081308
GGML_API struct ggml_tensor * ggml_set_2d_inplace(
13091309
struct ggml_context * ctx,
13101310
struct ggml_tensor * a,
13111311
struct ggml_tensor * b,
13121312
size_t nb1,
1313-
size_t offset);
1313+
size_t offset); // in bytes
13141314

13151315
// a -> b, return view(b)
13161316
GGML_API struct ggml_tensor * ggml_cpy(

ggml/src/ggml-backend.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,10 @@ GGML_CALL static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const
827827
op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
828828
case GGML_OP_MUL_MAT:
829829
return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
830+
case GGML_OP_ROPE_BACK:
831+
return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
832+
case GGML_OP_IM2COL_BACK:
833+
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
830834
default:
831835
return true;
832836
}

ggml/src/ggml-cuda.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "ggml-cuda/rope.cuh"
2828
#include "ggml-cuda/scale.cuh"
2929
#include "ggml-cuda/softmax.cuh"
30+
#include "ggml-cuda/sum.cuh"
3031
#include "ggml-cuda/sumrows.cuh"
3132
#include "ggml-cuda/tsembd.cuh"
3233
#include "ggml-cuda/unary.cuh"
@@ -2180,6 +2181,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
21802181
ggml_cuda_dup(ctx, dst);
21812182
break;
21822183
case GGML_OP_ADD:
2184+
case GGML_OP_ADD1: // TODO: more efficient implementation
21832185
ggml_cuda_op_add(ctx, dst);
21842186
break;
21852187
case GGML_OP_SUB:
@@ -2196,6 +2198,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
21962198
break;
21972199
case GGML_OP_UNARY:
21982200
switch (ggml_get_unary_op(dst)) {
2201+
case GGML_UNARY_OP_NEG:
2202+
ggml_cuda_op_neg(ctx, dst);
2203+
break;
21992204
case GGML_UNARY_OP_GELU:
22002205
ggml_cuda_op_gelu(ctx, dst);
22012206
break;
@@ -2304,6 +2309,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
23042309
case GGML_OP_POOL_2D:
23052310
ggml_cuda_op_pool2d(ctx, dst);
23062311
break;
2312+
case GGML_OP_SUM:
2313+
ggml_cuda_op_sum(ctx, dst);
2314+
break;
23072315
case GGML_OP_SUM_ROWS:
23082316
ggml_cuda_op_sum_rows(ctx, dst);
23092317
break;
@@ -2748,6 +2756,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
27482756
switch (op->op) {
27492757
case GGML_OP_UNARY:
27502758
switch (ggml_get_unary_op(op)) {
2759+
case GGML_UNARY_OP_NEG:
27512760
case GGML_UNARY_OP_GELU:
27522761
case GGML_UNARY_OP_SILU:
27532762
case GGML_UNARY_OP_RELU:
@@ -2877,6 +2886,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
28772886
case GGML_OP_TRANSPOSE:
28782887
case GGML_OP_NORM:
28792888
case GGML_OP_ADD:
2889+
case GGML_OP_ADD1:
28802890
case GGML_OP_SUB:
28812891
case GGML_OP_MUL:
28822892
case GGML_OP_DIV:
@@ -2896,7 +2906,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
28962906
case GGML_OP_ROPE:
28972907
return ggml_is_contiguous(op->src[0]);
28982908
case GGML_OP_IM2COL:
2909+
return op->src[0]->type == GGML_TYPE_F16;
28992910
case GGML_OP_POOL_2D:
2911+
case GGML_OP_SUM:
29002912
case GGML_OP_SUM_ROWS:
29012913
case GGML_OP_ARGSORT:
29022914
case GGML_OP_ACC:

ggml/src/ggml-cuda/cross-entropy-loss.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "common.cuh"
22
#include "cross-entropy-loss.cuh"
3-
#include "sumrows.cuh"
3+
#include "sum.cuh"
44

55
#include <cmath>
66
#include <cstdint>
@@ -102,5 +102,5 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor *
102102
cross_entropy_loss_f32<<<blocks_num, blocks_dim, shmem, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
103103

104104
// Combine results from individual blocks:
105-
sum_rows_f32_cuda(dst_tmp.ptr, dst_d, blocks_num.x, 1, stream);
105+
sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
106106
}

ggml/src/ggml-cuda/sum.cu

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#include "sumrows.cuh"
2+
#include "sum.cuh"
3+
4+
#include <cstdint>
5+
6+
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
7+
#include <cub/cub.cuh>
8+
using namespace cub;
9+
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
10+
11+
// Reduces the `ne` floats starting at `x` to their sum, written to `dst`, with all work queued on `stream`.
// `pool` supplies the temporary device buffer needed by the CUB path; it is unused in the fallback path.
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
12+
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
13+
size_t tmp_size = 0;
14+
// First call passes no temp storage: per the CUB convention this only fills in tmp_size.
// NOTE(review): the status returned by DeviceReduce::Sum is discarded on both calls - consider checking it.
// NOTE(review): ne is int64_t; confirm the CUB version in use accepts a 64-bit item count without truncation.
DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream);
15+
ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
16+
// Second call performs the actual reduction using the buffer sized above.
DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);
17+
#else
18+
// Use (inefficient) sum_rows implementation as a fallback.
19+
// For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14.
20+
// Presumably treats the input as a single row of ne columns - verify against sum_rows_f32_cuda's signature.
sum_rows_f32_cuda(x, dst, ne, 1, stream);
21+
GGML_UNUSED(pool);
22+
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
23+
}
24+
25+
// Backend entry point for GGML_OP_SUM: sums every element of dst->src[0] into dst via sum_f32_cuda.
// Requires F32 source and destination and a contiguous source; anything else trips a GGML_ASSERT.
void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
26+
const ggml_tensor * src0 = dst->src[0];
27+
28+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
29+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
30+
// The source is read below as one flat array of ne floats, hence the contiguity requirement.
GGML_ASSERT(ggml_is_contiguous(src0));
31+
32+
const float * src0_d = (const float *) src0->data;
33+
float * dst_d = (float *) dst->data;
34+
35+
const int64_t ne = ggml_nelements(src0);
36+
37+
// Pool and stream both come from the backend context.
ggml_cuda_pool & pool = ctx.pool();
38+
cudaStream_t stream = ctx.stream();
39+
40+
// NOTE(review): dst is not asserted to be a scalar here - presumably guaranteed by the graph builder; confirm.
sum_f32_cuda(pool, src0_d, dst_d, ne, stream);
41+
}

ggml/src/ggml-cuda/sum.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "common.cuh"
2+
3+
// Sums the `ne` floats at `x` into `dst` on `stream`; `pool` provides temporary device storage when needed.
void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);
4+
5+
// GGML_OP_SUM for the CUDA backend: reduces dst->src[0] (F32, contiguous) into dst.
void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/unary.cu

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
#include "unary.cuh"
22

3+
// Element-wise negation kernel: each thread writes dst[i] = -x[i] for its own index i, for i in [0, k).
static __global__ void neg_f32(const float * x, float * dst, const int k) {
4+
// Flat global index of this thread across the 1D launch.
const int i = blockDim.x*blockIdx.x + threadIdx.x;
5+
6+
// Threads past the end (the last block may be only partially used) do nothing.
if (i >= k) {
7+
return;
8+
}
9+
10+
dst[i] = -x[i];
11+
}
12+
313
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
414
const float GELU_COEF_A = 0.044715f;
515
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
@@ -119,6 +129,11 @@ static __global__ void cos_f32(const float * x, float * dst, const int k) {
119129
dst[i] = cosf(x[i]);
120130
}
121131

132+
// Launches neg_f32 on `stream` over k elements using blocks of CUDA_NEG_BLOCK_SIZE threads.
static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
132+
// Round up so that a partially filled last block still covers the tail elements.
// NOTE(review): k == 0 yields num_blocks == 0 - confirm callers never pass an empty tensor.
const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
134+
neg_f32<<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
135+
}
136+
122137
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
123138
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
124139
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -184,6 +199,20 @@ static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
184199
cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
185200
}
186201

202+
// Backend entry point for GGML_UNARY_OP_NEG: writes the element-wise negation of dst->src[0] into dst.
// Requires a contiguous F32 source and an F32 destination; anything else trips a GGML_ASSERT.
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
203+
const ggml_tensor * src0 = dst->src[0];
204+
const float * src0_d = (const float *)src0->data;
205+
float * dst_d = (float *)dst->data;
206+
cudaStream_t stream = ctx.stream();
207+
208+
// The kernel indexes the data as one flat array, hence the contiguity requirement.
GGML_ASSERT(ggml_is_contiguous(src0));
209+
210+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
211+
GGML_ASSERT( dst->type == GGML_TYPE_F32);
212+
213+
// NOTE(review): the element count is narrowed to the launcher's `const int k` parameter;
// presumably tensors here stay below INT_MAX elements - confirm.
neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
214+
}
215+
187216
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
188217
const ggml_tensor * src0 = dst->src[0];
189218
const float * src0_d = (const float *)src0->data;

ggml/src/ggml-cuda/unary.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "common.cuh"
22

3+
#define CUDA_NEG_BLOCK_SIZE 256
34
#define CUDA_GELU_BLOCK_SIZE 256
45
#define CUDA_SILU_BLOCK_SIZE 256
56
#define CUDA_TANH_BLOCK_SIZE 256
@@ -12,6 +13,8 @@
1213
#define CUDA_SIN_BLOCK_SIZE 256
1314
#define CUDA_COS_BLOCK_SIZE 256
1415

16+
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
17+
1518
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
1619

1720
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml.c

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5267,6 +5267,7 @@ struct ggml_tensor * ggml_concat(
52675267
bool is_node = false;
52685268

52695269
if (a->grad || b->grad) {
5270+
GGML_ABORT("fatal error"); // TODO: implement
52705271
is_node = true;
52715272
}
52725273

@@ -5388,6 +5389,7 @@ struct ggml_tensor * ggml_leaky_relu(
53885389
bool is_node = false;
53895390

53905391
if (!inplace && (a->grad)) {
5392+
GGML_ABORT("fatal error"); // TODO: not implemented
53915393
is_node = true;
53925394
}
53935395

@@ -5826,6 +5828,7 @@ static struct ggml_tensor * ggml_set_impl(
58265828
// make a view of the destination
58275829
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
58285830

5831+
GGML_ASSERT(offset < (size_t)(1 << 30));
58295832
int32_t params[] = { nb1, nb2, nb3, offset, inplace ? 1 : 0 };
58305833
ggml_set_op_params(result, params, sizeof(params));
58315834

@@ -6783,14 +6786,12 @@ struct ggml_tensor * ggml_rope_back(
67836786
GGML_ASSERT(ggml_is_vector(b));
67846787
GGML_ASSERT(b->type == GGML_TYPE_I32);
67856788
GGML_ASSERT(a->ne[2] == b->ne[0]);
6786-
GGML_ASSERT(c == NULL && "freq factors not implemented yet");
6787-
6788-
GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
67896789

67906790
bool is_node = false;
67916791

67926792
if (a->grad) {
6793-
is_node = false; // TODO: implement backward
6793+
GGML_ASSERT(false && "backwards pass not implemented");
6794+
is_node = false;
67946795
}
67956796

67966797
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
@@ -6808,6 +6809,7 @@ struct ggml_tensor * ggml_rope_back(
68086809
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
68096810
result->src[0] = a;
68106811
result->src[1] = b;
6812+
result->src[2] = c;
68116813

68126814
return result;
68136815
}
@@ -7361,6 +7363,11 @@ struct ggml_tensor * ggml_argsort(
73617363
enum ggml_sort_order order) {
73627364
bool is_node = false;
73637365

7366+
if (a->grad) {
7367+
GGML_ABORT("fatal error"); // TODO: not implemented
7368+
is_node = true;
7369+
}
7370+
73647371
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, GGML_MAX_DIMS, a->ne);
73657372

73667373
ggml_set_op_params_i32(result, 0, (int32_t) order);
@@ -10953,9 +10960,6 @@ static void ggml_compute_forward_sum_f32(
1095310960
return;
1095410961
}
1095510962

10956-
assert(ggml_is_scalar(dst));
10957-
10958-
1095910963
assert(ggml_is_scalar(dst));
1096010964
assert(src0->nb[0] == sizeof(float));
1096110965

@@ -18356,14 +18360,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
1835618360
if (src0->grad || src1->grad) {
1835718361
GGML_ASSERT(src0->type == tensor->type);
1835818362
GGML_ASSERT(tensor->grad->type == tensor->type);
18359-
GGML_ASSERT(tensor->grad->type == src1->grad->type);
18363+
GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type);
1836018364

1836118365
tensor_grad_view = ggml_view_4d(ctx,
18362-
tensor->grad,
18363-
src1->grad->ne[0],
18364-
src1->grad->ne[1],
18365-
src1->grad->ne[2],
18366-
src1->grad->ne[3],
18366+
tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
1836718367
nb1, nb2, nb3, offset);
1836818368
}
1836918369

@@ -18432,9 +18432,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
1843218432

1843318433
memcpy(&offset, tensor->op_params, sizeof(offset));
1843418434

18435-
size_t nb1 = tensor->nb[1];
18436-
size_t nb2 = tensor->nb[2];
18437-
size_t nb3 = tensor->nb[3];
18435+
size_t nb1 = tensor->nb[1];
18436+
size_t nb2 = tensor->nb[2];
18437+
size_t nb3 = tensor->nb[3];
1843818438

1843918439
if (src0->type != src0->grad->type) {
1844018440
// gradient is typically F32, but src0 could be other type

0 commit comments

Comments
 (0)