Skip to content

Commit 4b55e48

Browse files
committed
add group norm
1 parent 09552bc commit 4b55e48

File tree

3 files changed

+61
-1
lines changed

3 files changed

+61
-1
lines changed

ggml-cann.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,8 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
400400
ggml_cann_norm(ctx, dst);
401401
break;
402402
case GGML_OP_GROUP_NORM:
403-
return false;
403+
ggml_cann_group_norm(ctx, dst);
404+
break;
404405
case GGML_OP_CONCAT:
405406
ggml_cann_concat(ctx, dst);
406407
break;
@@ -679,7 +680,9 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
679680
case GGML_OP_ARGSORT:
680681
return true;
681682
case GGML_OP_ACC:
683+
return false;
682684
case GGML_OP_GROUP_NORM:
685+
return true;
683686
case GGML_OP_UPSCALE:
684687
return false;
685688
case GGML_OP_PAD:

ggml-cann/aclnn_ops.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <aclnnop/aclnn_layer_norm.h>
44
#include <aclnnop/aclnn_cast.h>
5+
#include <aclnnop/aclnn_group_norm.h>
56

67
#include <cmath>
78
#include <cstring>
@@ -397,3 +398,57 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
397398
}
398399
}
399400

401+
/**
 * Group normalization on the CANN backend.
 *
 * Normalizes dst->src[0] into dst by calling aclnnGroupNorm. The group count
 * is read from dst->op_params[0]; epsilon is currently fixed at 1e-6.
 *
 * The operator also produces per-(batch, group) mean and rstd outputs. They
 * are not needed by the caller, so both are written to one temporary device
 * buffer that is released before returning.
 *
 * @param ctx  Backend context; supplies the stream the operator runs on.
 * @param dst  Destination tensor. dst->src[0] is the input tensor.
 */
void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];

    aclTensor* acl_src = create_acl_tensor(src);
    aclTensor* acl_dst = create_acl_tensor(dst);

    const float eps = 1e-6f;  // TODO: make this a parameter
    int n_groups = dst->op_params[0];

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor = nullptr;
    void* workspaceAddr = nullptr;

    // Input is treated as [N, C, H, W] with ggml's reversed ne[] ordering.
    int64_t N = src->ne[3];
    int64_t C = src->ne[2];
    int64_t HxW = src->ne[1] * src->ne[0];

    // Mean / rstd outputs: one value per (batch, group) pair.
    // NOTE(review): the tensors are declared ACL_FLOAT while the strides use
    // the element size of src->type; this only agrees when src is F32 -
    // confirm callers never pass another type.
    size_t type_size = ggml_type_size(src->type);
    int64_t ne[] = {n_groups, N};
    size_t nb[] = {type_size, type_size * n_groups};

    // Size in BYTES of one statistics tensor. nb[] holds byte strides, so the
    // tensor spans nb[1] * N bytes. (Previously this was the element count
    // N * n_groups, which under-allocated the buffer and made the rstd tensor
    // overlap the mean tensor.)
    size_t n_bytes = nb[1] * N;

    // One allocation backs both outputs: mean first, rstd right after it.
    void* buffer = nullptr;
    ACL_CHECK(aclrtMalloc(&buffer, n_bytes * 2, ACL_MEM_MALLOC_HUGE_FIRST));
    aclTensor* acl_mean_out =
        create_acl_tensor(buffer, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);
    aclTensor* acl_rstd_out = create_acl_tensor(
        (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);

    // The two nullptr arguments are presumably the optional gamma / beta
    // (scale / shift) tensors - verify against the aclnnGroupNorm reference.
    ACL_CHECK(aclnnGroupNormGetWorkspaceSize(
        acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst,
        acl_mean_out, acl_rstd_out, &workspaceSize, &executor));

    if (workspaceSize > 0) {
        ACL_CHECK(aclrtMalloc(&workspaceAddr, workspaceSize,
                              ACL_MEM_MALLOC_HUGE_FIRST));
    }

    aclrtStream stream = ctx.stream();

    ACL_CHECK(aclnnGroupNorm(workspaceAddr, workspaceSize, executor, stream));

    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
    ACL_CHECK(aclDestroyTensor(acl_mean_out));
    ACL_CHECK(aclDestroyTensor(acl_rstd_out));

    // TODO: free after sync.
    // The stream is synchronized here so the temporary device buffers are no
    // longer in use by the queued operator when they are freed below.
    ACL_CHECK(aclrtSynchronizeStream(stream));
    ACL_CHECK(aclrtFree(buffer));

    if (workspaceSize > 0) {
        ACL_CHECK(aclrtFree(workspaceAddr));
    }
}

ggml-cann/aclnn_ops.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
3939

4040
void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
4141

42+
// Group normalization: normalizes dst->src[0] into dst via aclnnGroupNorm.
// The number of groups is read from dst->op_params[0].
void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
43+
4244
template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
4345
aclTensor*, uint64_t*, aclOpExecutor**),
4446
aclnnStatus execute(void*, uint64_t, aclOpExecutor*, aclrtStream)>

0 commit comments

Comments
 (0)