Skip to content

Commit 65cfe13

Browse files
noemotiovonnoemotiovon
andauthored
CANN: Support operator SIN COS ARGMAX (#12709)
* [CANN]support sin cos argmax Signed-off-by: noemotiovon <[email protected]> * [CANN]codestyle adjustment Signed-off-by: noemotiovon <[email protected]> * [CANN]Remove redundant code Signed-off-by: noemotiovon <[email protected]> --------- Signed-off-by: noemotiovon <[email protected]> Co-authored-by: noemotiovon <[email protected]>
1 parent 3f9da22 commit 65cfe13

File tree

3 files changed

+97
-0
lines changed

3 files changed

+97
-0
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include <aclnnop/aclnn_triu.h>
5252
#include <aclnnop/aclnn_upsample_nearest_2d.h>
5353
#include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
54+
#include <aclnnop/aclnn_argmax.h>
5455
#include <float.h>
5556

5657
#include <cmath>
@@ -3440,3 +3441,46 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
34403441
ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
34413442
ACL_CHECK(aclDestroyTensor(acl_dst));
34423443
}
3444+
3445+
3446+
void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst){
3447+
ggml_tensor * src0 = dst->src[0];
3448+
3449+
aclTensor* acl_src = ggml_cann_create_tensor(src0);
3450+
aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3);
3451+
3452+
uint64_t workspaceSize = 0;
3453+
aclOpExecutor* executor;
3454+
void* workspaceAddr = nullptr;
3455+
3456+
ACL_CHECK(aclnnArgMaxGetWorkspaceSize(acl_src, 3, false, acl_dst,
3457+
&workspaceSize, &executor));
3458+
if (workspaceSize > 0) {
3459+
ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
3460+
workspaceAddr = workspace_allocator.get();
3461+
}
3462+
ACL_CHECK(aclnnArgMax(workspaceAddr, workspaceSize, executor, ctx.stream()));
3463+
3464+
ACL_CHECK(aclDestroyTensor(acl_src));
3465+
ACL_CHECK(aclDestroyTensor(acl_dst));
3466+
}
3467+
3468+
void ggml_cann_cos(ggml_backend_cann_context& ctx, ggml_tensor* dst){
3469+
ggml_tensor * src0 = dst->src[0];
3470+
3471+
aclTensor* acl_src = ggml_cann_create_tensor(src0);
3472+
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
3473+
aclnn_cos(ctx, acl_src, acl_dst);
3474+
ACL_CHECK(aclDestroyTensor(acl_src));
3475+
ACL_CHECK(aclDestroyTensor(acl_dst));
3476+
}
3477+
3478+
void ggml_cann_sin(ggml_backend_cann_context& ctx, ggml_tensor* dst){
3479+
ggml_tensor * src0 = dst->src[0];
3480+
3481+
aclTensor* acl_src = ggml_cann_create_tensor(src0);
3482+
aclTensor* acl_dst = ggml_cann_create_tensor(dst);
3483+
aclnn_sin(ctx, acl_src, acl_dst);
3484+
ACL_CHECK(aclDestroyTensor(acl_src));
3485+
ACL_CHECK(aclDestroyTensor(acl_dst));
3486+
}

ggml/src/ggml-cann/aclnn_ops.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,47 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
484484
*/
485485
void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
486486

487+
/**
488+
* @brief Computes the index of the maximum value along the specified dimension
489+
* of a ggml tensor using the CANN backend.
490+
*
491+
* @details This function performs an argmax operation on the input tensor.
492+
* It finds the index of the maximum value along the specified axis
493+
* and stores these indices in the destination tensor `dst`. The
494+
* operation is executed using the CANN backend for optimized performance.
495+
*
496+
* @param ctx The CANN context used for operations.
497+
* @param dst The destination tensor where the indices of the maximum values will be stored.
498+
* dst->op is `GGML_OP_ARGMAX`.
499+
*/
500+
void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
501+
502+
/**
503+
* @brief Computes the cosine of each element in a ggml tensor using the CANN backend.
504+
*
505+
* @details This function applies the cosine function element-wise to the input tensor.
506+
* The computed cosine values are stored in the destination tensor `dst`.
507+
* The operation is optimized using the CANN backend for improved performance.
508+
*
509+
* @param ctx The CANN context used for operations.
510+
* @param dst The destination tensor where the cosine values will be stored.
511+
* dst->op is `GGML_OP_COS`.
512+
*/
513+
void ggml_cann_cos(ggml_backend_cann_context& ctx, ggml_tensor* dst);
514+
515+
/**
516+
* @brief Computes the sine of each element in a ggml tensor using the CANN backend.
517+
*
518+
* @details This function applies the sine function element-wise to the input tensor.
519+
* The computed sine values are stored in the destination tensor `dst`.
520+
* The operation is optimized using the CANN backend for improved performance.
521+
*
522+
* @param ctx The CANN context used for operations.
523+
* @param dst The destination tensor where the sine values will be stored.
524+
* dst->op is `GGML_OP_SIN`.
525+
*/
526+
void ggml_cann_sin(ggml_backend_cann_context& ctx, ggml_tensor* dst);
527+
487528
template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
488529
aclTensor*, uint64_t*, aclOpExecutor**),
489530
aclnnStatus execute(void*, uint64_t, aclOpExecutor*, aclrtStream)>

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,15 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
14201420
case GGML_OP_ARGSORT:
14211421
ggml_cann_argsort(ctx, dst);
14221422
break;
1423+
case GGML_OP_ARGMAX:
1424+
ggml_cann_argmax(ctx, dst);
1425+
break;
1426+
case GGML_OP_COS:
1427+
ggml_cann_cos(ctx, dst);
1428+
break;
1429+
case GGML_OP_SIN:
1430+
ggml_cann_sin(ctx, dst);
1431+
break;
14231432
default:
14241433
return false;
14251434
}
@@ -1802,6 +1811,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
18021811
case GGML_OP_ARANGE:
18031812
case GGML_OP_TIMESTEP_EMBEDDING:
18041813
case GGML_OP_LEAKY_RELU:
1814+
case GGML_OP_ARGMAX:
1815+
case GGML_OP_COS:
1816+
case GGML_OP_SIN:
18051817
return true;
18061818
default:
18071819
return false;

0 commit comments

Comments
 (0)