Skip to content

Commit da0e67c

Browse files
committed
add op sum_rows
1 parent 0016c0b commit da0e67c

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

ggml-cann.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
429429
return false;
430430
case GGML_OP_CONT:
431431
ggml_cann_cont(ctx, dst);
432+
break;
432433
case GGML_OP_NONE:
433434
case GGML_OP_RESHAPE:
434435
case GGML_OP_VIEW:
@@ -445,12 +446,13 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
445446
case GGML_OP_ALIBI:
446447
case GGML_OP_IM2COL:
447448
case GGML_OP_POOL_2D:
448-
case GGML_OP_SUM_ROWS:
449449
return false;
450+
case GGML_OP_SUM_ROWS:
451+
ggml_cann_sum_rows(ctx, dst);
452+
break;
450453
case GGML_OP_ARGSORT:
451454
ggml_cann_argsort(ctx, dst);
452455
break;
453-
return false;
454456
default:
455457
return false;
456458
}
@@ -651,25 +653,21 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
651653
case GGML_OP_CPY:
652654
return false;
653655
case GGML_OP_DUP:
654-
return true;
655656
case GGML_OP_REPEAT:
656657
case GGML_OP_CONCAT:
657658
case GGML_OP_NONE:
658659
case GGML_OP_RESHAPE:
659660
case GGML_OP_VIEW:
660661
case GGML_OP_PERMUTE:
661662
case GGML_OP_TRANSPOSE:
662-
return true;
663663
case GGML_OP_NORM:
664-
return true;
665664
case GGML_OP_ADD:
666665
case GGML_OP_MUL:
667666
case GGML_OP_DIV:
668667
return true;
669668
case GGML_OP_RMS_NORM:
670669
return false;
671670
case GGML_OP_SCALE:
672-
return true;
673671
case GGML_OP_SQR:
674672
case GGML_OP_CLAMP:
675673
case GGML_OP_CONT:
@@ -682,18 +680,15 @@ GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
682680
case GGML_OP_ALIBI:
683681
case GGML_OP_IM2COL:
684682
case GGML_OP_POOL_2D:
685-
case GGML_OP_SUM_ROWS:
686683
return false;
684+
case GGML_OP_SUM_ROWS:
687685
case GGML_OP_ARGSORT:
688-
return true;
689686
case GGML_OP_ACC:
690-
return true;
691687
case GGML_OP_GROUP_NORM:
692688
return true;
693689
case GGML_OP_UPSCALE:
694690
return false;
695691
case GGML_OP_PAD:
696-
return true;
697692
case GGML_OP_ARANGE:
698693
return true;
699694
case GGML_OP_TIMESTEP_EMBEDDING:

ggml-cann/aclnn_ops.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <aclnnop/aclnn_layer_norm.h>
66
#include <aclnnop/aclnn_repeat.h>
77
#include <aclnnop/aclnn_softmax.h>
8+
#include <aclnnop/aclnn_reduce_sum.h>
89

910
#include <cmath>
1011
#include <cstring>
@@ -475,4 +476,33 @@ void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
475476

476477
ACL_CHECK(aclDestroyTensor(acl_src1));
477478
ACL_CHECK(aclDestroyTensor(acl_dst));
479+
}
480+
481+
/**
 * Execute GGML_OP_SUM_ROWS on the CANN backend.
 *
 * Reduces dst->src[0] with aclnnReduceSum (keep_dims = true) and writes the
 * result into dst. The reduction is enqueued on ctx.stream().
 *
 * @param ctx CANN backend context; provides the stream and, when the operator
 *            requests one, the workspace buffer.
 * @param dst Destination tensor. Its input is dst->src[0]; dst->ne[0] must
 *            be 1 (asserted).
 */
void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];

    // The reduced dimension is kept with extent 1 in the output.
    GGML_ASSERT(dst->ne[0] == 1);

    aclTensor* acl_src = create_acl_tensor(src);
    aclTensor* acl_dst = create_acl_tensor(dst);

    // NOTE(review): ACL axis 3 is presumably ggml's ne[0] (the row direction),
    // which would match the assertion above -- confirm against the dimension
    // order produced by create_acl_tensor.
    int64_t reduce_dims_host[] = {3};
    aclIntArray* reduce_dims = aclCreateIntArray(reduce_dims_host, 1);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor = nullptr;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnReduceSumGetWorkspaceSize(acl_src, reduce_dims, true,
                                             type_mapping(src->type), acl_dst,
                                             &workspaceSize, &executor));
    if (workspaceSize > 0) {
        workspaceAddr = ctx.alloc_buffer(workspaceSize);
    }

    aclrtStream stream = ctx.stream();
    ACL_CHECK(aclnnReduceSum(workspaceAddr, workspaceSize, executor, stream));

    // Release every ACL handle created above; the int array was previously
    // leaked on each call.
    ACL_CHECK(aclDestroyIntArray(reduce_dims));
    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}

ggml-cann/aclnn_ops.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
4545

4646
void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst);
4747

48+
/**
 * GGML_OP_SUM_ROWS for the CANN backend.
 *
 * Reads the input from dst->src[0], reduces it with aclnnReduceSum (keeping
 * the reduced dimension) and stores the result in dst. The work is submitted
 * on ctx.stream(). Asserts that dst->ne[0] == 1.
 */
void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
49+
4850
template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
4951
aclTensor*, uint64_t*, aclOpExecutor**),
5052
aclnnStatus execute(void*, uint64_t, aclOpExecutor*, aclrtStream)>

0 commit comments

Comments
 (0)