|
2 | 2 |
|
3 | 3 | #include <aclnnop/aclnn_layer_norm.h>
|
4 | 4 | #include <aclnnop/aclnn_cast.h>
|
| 5 | +#include <aclnnop/aclnn_group_norm.h> |
5 | 6 |
|
6 | 7 | #include <cmath>
|
7 | 8 | #include <cstring>
|
@@ -397,3 +398,57 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
|
397 | 398 | }
|
398 | 399 | }
|
399 | 400 |
|
| 401 | +void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) { |
| 402 | + ggml_tensor* src = dst->src[0]; |
| 403 | + |
| 404 | + aclTensor* acl_src = create_acl_tensor(src); |
| 405 | + aclTensor* acl_dst = create_acl_tensor(dst); |
| 406 | + |
| 407 | + const float eps = 1e-6f; // TODO: make this a parameter |
| 408 | + int n_groups = dst->op_params[0]; |
| 409 | + |
| 410 | + uint64_t workspaceSize = 0; |
| 411 | + aclOpExecutor* executor; |
| 412 | + void* workspaceAddr = nullptr; |
| 413 | + |
| 414 | + int64_t N = src->ne[3]; |
| 415 | + int64_t C = src->ne[2]; |
| 416 | + int64_t HxW = src->ne[1] * src->ne[0]; |
| 417 | + |
| 418 | + size_t type_size = ggml_type_size(src->type); |
| 419 | + int64_t ne[] = {n_groups, N}; |
| 420 | + size_t nb[] = {type_size, type_size * n_groups}; |
| 421 | + size_t n_bytes = N * n_groups; |
| 422 | + void* buffer; |
| 423 | + ACL_CHECK(aclrtMalloc(&buffer, n_bytes * 2, ACL_MEM_MALLOC_HUGE_FIRST)); |
| 424 | + aclTensor* acl_mean_out = |
| 425 | + create_acl_tensor(buffer, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND); |
| 426 | + aclTensor* acl_rstd_out = create_acl_tensor( |
| 427 | + (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND); |
| 428 | + |
| 429 | + ACL_CHECK(aclnnGroupNormGetWorkspaceSize( |
| 430 | + acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst, |
| 431 | + acl_mean_out, acl_rstd_out, &workspaceSize, &executor)); |
| 432 | + |
| 433 | + if (workspaceSize > 0) { |
| 434 | + ACL_CHECK(aclrtMalloc(&workspaceAddr, workspaceSize, |
| 435 | + ACL_MEM_MALLOC_HUGE_FIRST)); |
| 436 | + } |
| 437 | + |
| 438 | + aclrtStream stream = ctx.stream(); |
| 439 | + |
| 440 | + ACL_CHECK(aclnnGroupNorm(workspaceAddr, workspaceSize, executor, stream)); |
| 441 | + |
| 442 | + ACL_CHECK(aclDestroyTensor(acl_src)); |
| 443 | + ACL_CHECK(aclDestroyTensor(acl_dst)); |
| 444 | + ACL_CHECK(aclDestroyTensor(acl_mean_out)); |
| 445 | + ACL_CHECK(aclDestroyTensor(acl_rstd_out)); |
| 446 | + |
| 447 | + // TODO: free after sync. |
| 448 | + ACL_CHECK(aclrtSynchronizeStream(stream)); |
| 449 | + ACL_CHECK(aclrtFree(buffer)); |
| 450 | + |
| 451 | + if (workspaceSize > 0) { |
| 452 | + ACL_CHECK(aclrtFree(workspaceAddr)); |
| 453 | + } |
| 454 | +} |
0 commit comments