
Commit 7ae3f5f

optimize op repeat
1 parent 6201f9f commit 7ae3f5f

File tree

1 file changed: +23 -39 lines changed

ggml-cann/aclnn_ops.cpp

Lines changed: 23 additions & 39 deletions
@@ -4,13 +4,12 @@
 #include <aclnnop/aclnn_cast.h>
 #include <aclnnop/aclnn_group_norm.h>
 #include <aclnnop/aclnn_softmax.h>
+#include <aclnnop/aclnn_repeat.h>
 
 #include <cmath>
 #include <cstring>
 #include <vector>
 
-// TODO: repeat is implemented through add to apply bcast. Optimize it.
-// change to use aclnnRepeat
 void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
     GGML_ASSERT(ggml_can_repeat(src, dst));
@@ -20,45 +19,30 @@ void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // Set dst to a zero tensor.
     ACL_CHECK(aclrtMemsetAsync(dst->data, nbytes, 0, nbytes, main_stream));
 
-    aclTensor* acl_src;
-    aclTensor* acl_dst;
+    aclTensor* acl_src = create_acl_tensor(src);
+    aclTensor* acl_dst = create_acl_tensor(dst);
 
-    // Short cut for same shape.
-    if (ggml_are_same_shape(src, dst)) {
-        ACL_CHECK(aclrtMemcpyAsync(dst->data, nbytes, src->data, nbytes,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE, main_stream));
-    } else {
-        if (need_bcast(dst, src)) {
-            BCAST_SHAPE(dst, src);
-            acl_dst = create_acl_tensor(dst, BCAST_PARAM(dst));
-            acl_src = create_acl_tensor(src, BCAST_PARAM(src));
-        } else {
-            acl_dst = create_acl_tensor(dst);
-            acl_src = create_acl_tensor(src);
-        }
-
-        // Add src0 to dst.
-        aclScalar* alpha = nullptr;
-        int alphaValue = 1;
-        alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_INT32);
-
-        uint64_t workspaceSize = 0;
-        aclOpExecutor* executor;
-        void* workspaceAddr = nullptr;
-
-        ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src, alpha,
-                                                  &workspaceSize, &executor));
-        if (workspaceSize > 0) {
-            workspaceAddr = ctx.alloc_buffer(workspaceSize);
-        }
-
-        ACL_CHECK(aclnnInplaceAdd(workspaceAddr, workspaceSize, executor,
-                                  main_stream));
-
-        ACL_CHECK(aclDestroyScalar(alpha));
-        ACL_CHECK(aclDestroyTensor(acl_src));
-        ACL_CHECK(aclDestroyTensor(acl_dst));
+    int64_t repeatsArray[] = {dst->ne[3] / src->ne[3], dst->ne[2] / src->ne[2],
+                              dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]};
+
+    aclIntArray *repeats = aclCreateIntArray(repeatsArray, GGML_MAX_DIMS);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnRepeatGetWorkspaceSize(acl_src, repeats, acl_dst, &workspaceSize, &executor));
+
+    if (workspaceSize > 0) {
+        workspaceAddr = ctx.alloc_buffer(workspaceSize);
     }
+
+    aclrtStream stream = ctx.stream();
+    ACL_CHECK(aclnnRepeat(workspaceAddr, workspaceSize, executor, stream));
+    ACL_CHECK(aclDestroyIntArray(repeats));
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+
 }
 
 void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
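Note on the new path: the repeat counts are just the per-dimension shape ratios dst->ne[i] / src->ne[i], listed from ne[3] down to ne[0] (the reverse of ggml's ne[] order, outermost dimension first), and the operator is then driven through the usual aclnn two-step GetWorkspaceSize/execute pattern. Below is a minimal, self-contained sketch of only the repeat-count computation; compute_repeats is a hypothetical helper, not part of this patch, and it assumes GGML_MAX_DIMS == 4 and that the repeats array is expected outermost-first, which is what the reversed ordering in the diff suggests.

// Standalone sketch (no ACL or ggml headers) of how the patch derives the
// per-dimension repeat counts. compute_repeats is a hypothetical helper name.
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr int kDims = 4;  // assumed GGML_MAX_DIMS

// src_ne / dst_ne follow ggml's convention: index 0 is the innermost dimension.
std::array<int64_t, kDims> compute_repeats(const std::array<int64_t, kDims>& src_ne,
                                           const std::array<int64_t, kDims>& dst_ne) {
    std::array<int64_t, kDims> repeats{};
    for (int i = 0; i < kDims; ++i) {
        // Mirrors the ggml_can_repeat precondition: every dst dimension must be
        // a whole multiple of the corresponding src dimension.
        assert(dst_ne[i] % src_ne[i] == 0);
        // Reverse the order so repeats[0] describes the outermost dimension,
        // matching the ne[3]..ne[0] ordering used in repeatsArray above.
        repeats[kDims - 1 - i] = dst_ne[i] / src_ne[i];
    }
    return repeats;
}

int main() {
    // Repeat a 2x3 tensor 4 times along dim 1 and 2 times along dim 2.
    std::array<int64_t, kDims> src_ne = {2, 3, 1, 1};
    std::array<int64_t, kDims> dst_ne = {2, 12, 2, 1};
    auto repeats = compute_repeats(src_ne, dst_ne);
    // Prints: 1 2 4 1  (outermost dimension first)
    for (int64_t r : repeats) std::printf("%lld ", (long long) r);
    std::printf("\n");
    return 0;
}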
