
Commit 43b35e3

Add support for sqrt on CUDA (#7953)
* cuda sqrt support
* enable cuda in pca
* fix comments in pca
* add test
* add sqrt to ggml_backend_cuda_supports_op
* fix test
* new line
* Use F32 sqrtf instead of F64 sqrt

Co-authored-by: Johannes Gäßler <[email protected]>
1 parent 19b7a83 commit 43b35e3
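
For context, here is a minimal sketch of building a graph that contains the op this commit enables on CUDA. It uses the public ggml API (ggml_init, ggml_new_tensor_1d, ggml_sqrt, ggml_new_graph, ggml_build_forward_expand); the tensor size and context size are arbitrary illustration values, and none of this code appears in the commit itself.

#include "ggml.h"

int main(void) {
    // A small context is enough for this illustration.
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * y  = ggml_sqrt(ctx, x); // y->op == GGML_OP_SQRT
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    // With this commit, a graph containing such a node can be executed by the
    // CUDA backend (e.g. one created with ggml_backend_cuda_init(0)) instead of
    // requiring the CPU backend.
    ggml_free(ctx);
    return 0;
}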

5 files changed: 71 additions, 8 deletions


examples/cvector-generator/pca.hpp

Lines changed: 8 additions & 8 deletions
@@ -64,15 +64,15 @@ struct pca_model {
     struct ggml_tensor * dev_eigenvector;
 
     pca_model(struct ggml_tensor * t_input) {
-        // TODO: enable GPU support when support for GGML_OP_SQRT is added
-        // #ifdef GGML_USE_CUDA
-        //     fprintf(stderr, "%s: using CUDA backend\n", __func__);
-        //     backend = ggml_backend_cuda_init(0); // init device 0
-        //     if (!backend) {
-        //         fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
-        //     }
-        // #endif
+#ifdef GGML_USE_CUDA
+        fprintf(stderr, "%s: using CUDA backend\n", __func__);
+        backend = ggml_backend_cuda_init(0); // init device 0
+        if (!backend) {
+            fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+        }
+#endif
 
+        // TODO: enable Metal support when support for GGML_OP_SQRT is added
         // #ifdef GGML_USE_METAL
         //     fprintf(stderr, "%s: using Metal backend\n", __func__);
         //     backend = ggml_backend_metal_init();
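
A practical note on the block that was just uncommented: ggml_backend_cuda_init() can still return NULL at runtime (for example when no device is present), so the caller needs a backend to fall back to. A minimal sketch of that pattern, assuming it sits inside a function (such as the pca_model constructor) with <stdio.h>, "ggml-backend.h" and "ggml-cuda.h" available; this is an illustration, not a verbatim excerpt of pca.hpp:

ggml_backend_t backend = NULL;
#ifdef GGML_USE_CUDA
    backend = ggml_backend_cuda_init(0); // init device 0
    if (!backend) {
        fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
    }
#endif
    if (!backend) {
        backend = ggml_backend_cpu_init(); // fall back to the CPU backend
    }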

ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
@@ -2267,6 +2267,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SQR:
             ggml_cuda_op_sqr(ctx, dst);
             break;
+        case GGML_OP_SQRT:
+            ggml_cuda_op_sqrt(ctx, dst);
+            break;
         case GGML_OP_CLAMP:
             ggml_cuda_op_clamp(ctx, dst);
             break;

@@ -2830,6 +2833,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SQRT:
         case GGML_OP_CLAMP:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
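
The second hunk is what makes the new kernel reachable in practice: the backend scheduler consults ggml_backend_cuda_supports_op() when deciding where to place each node, so listing GGML_OP_SQRT there lets SQRT nodes stay on the GPU. A sketch of a runtime check using the public ggml-backend API, reusing a ggml_context ctx and an F32 tensor x as in the sketch near the top; this snippet is not part of the diff:

ggml_backend_t cuda = ggml_backend_cuda_init(0);
struct ggml_tensor * y = ggml_sqrt(ctx, x); // node with op == GGML_OP_SQRT
if (cuda && ggml_backend_supports_op(cuda, y)) {
    // after this commit the CUDA backend reports support for this node
}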

ggml-cuda/unary.cu

Lines changed: 28 additions & 0 deletions
@@ -92,6 +92,15 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
+static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sqrtf(x[i]);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);

@@ -142,6 +151,11 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
+    sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;

@@ -284,3 +298,17 @@ void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
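
The launcher uses the usual ceiling-division grid calculation so that every element gets a thread; a worked example of the arithmetic with illustrative numbers (not taken from the diff):

const int k          = 1000;                              // number of elements
const int block_size = 256;                               // CUDA_SQRT_BLOCK_SIZE
const int num_blocks = (k + block_size - 1) / block_size; // = 4 (rounds up)
// 4 blocks * 256 threads = 1024 threads; the `if (i >= k) return;` guard in
// sqrt_f32 makes the 24 surplus threads do nothing.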

ggml-cuda/unary.cuh

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
+#define CUDA_SQRT_BLOCK_SIZE 256
 
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 

@@ -28,3 +29,5 @@ void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

tests/test-backend-ops.cpp

Lines changed: 28 additions & 0 deletions
@@ -1063,6 +1063,33 @@ struct test_sqr : public test_case {
     }
 };
 
+// GGML_OP_SQRT
+struct test_sqrt : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sqrt(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_sqrt(ctx, a);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        // fill with positive values
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, 0.0f, 100.0f);
+        }
+    }
+};
+
 // GGML_OP_CLAMP
 struct test_clamp : public test_case {
     const ggml_type type;

@@ -2200,6 +2227,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     }
 
     test_cases.emplace_back(new test_sqr());
+    test_cases.emplace_back(new test_sqrt());
     test_cases.emplace_back(new test_clamp());
 
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
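
The positive-only initialization in test_sqrt is what keeps the cross-backend comparison meaningful: sqrtf of a negative input is NaN, and NaN never compares equal to anything, including itself, so results would always appear to differ. A small self-contained illustration (not part of the test suite):

#include <cassert>
#include <cmath>

int main() {
    float bad = std::sqrt(-1.0f);    // NaN
    assert(std::isnan(bad));
    assert(!(bad == bad));           // NaN is never equal to itself

    float good = std::sqrt(100.0f);  // inputs drawn from [0, 100] stay well-defined
    assert(good == 10.0f);           // 100 has an exact float square root
    return 0;
}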
