4
4
#include < aclnnop/aclnn_cast.h>
5
5
#include < aclnnop/aclnn_group_norm.h>
6
6
#include < aclnnop/aclnn_softmax.h>
7
+ #include < aclnnop/aclnn_repeat.h>
7
8
8
9
#include < cmath>
9
10
#include < cstring>
10
11
#include < vector>
11
12
12
- // TODO: repeat is implemented through add to apply bcast. Optimize it.
13
- // change to use aclnnRepeat
14
13
/**
 * Repeat (tile) `src` along each dimension so it fills the shape of `dst`,
 * using the native aclnnRepeat operator (this replaced the older emulation
 * that broadcast-added src into a zeroed dst).
 *
 * @param ctx  CANN backend context providing the stream and workspace buffers.
 * @param dst  Destination tensor; dst->src[0] is the tensor to repeat.
 *             ggml_can_repeat(src, dst) must hold, so every dst->ne[i] is an
 *             exact multiple of src->ne[i].
 */
void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];
    GGML_ASSERT(ggml_can_repeat(src, dst));

    // NOTE(review): the two declarations below were elided by the diff's hunk
    // header; reconstructed from their uses in the memset call — confirm
    // against the upstream file.
    size_t nbytes = ggml_nbytes(dst);
    aclrtStream main_stream = ctx.stream();

    // Set dst to a zero tensor (kept from the original flow; aclnnRepeat
    // writes the whole of dst afterwards).
    ACL_CHECK(aclrtMemsetAsync(dst->data, nbytes, 0, nbytes, main_stream));

    aclTensor* acl_src = create_acl_tensor(src);
    aclTensor* acl_dst = create_acl_tensor(dst);

    // Repeat count per dimension. ggml stores ne[0] as the innermost
    // dimension, while the acl repeats array is given outermost-first, hence
    // the reversed order ne[3]..ne[0].
    int64_t repeatsArray[] = {dst->ne[3] / src->ne[3], dst->ne[2] / src->ne[2],
                              dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]};
    aclIntArray* repeats = aclCreateIntArray(repeatsArray, GGML_MAX_DIMS);

    // Standard aclnn two-phase call: query workspace size, then launch.
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor = nullptr;  // fix: initialize before the query fills it
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnRepeatGetWorkspaceSize(acl_src, repeats, acl_dst,
                                          &workspaceSize, &executor));
    if (workspaceSize > 0) {
        // Buffer is owned by ctx; it must stay valid until the async op on
        // main_stream completes.
        workspaceAddr = ctx.alloc_buffer(workspaceSize);
    }

    // fix: reuse main_stream instead of fetching a second, identical
    // `aclrtStream stream = ctx.stream();` local as the diff did.
    ACL_CHECK(aclnnRepeat(workspaceAddr, workspaceSize, executor, main_stream));

    ACL_CHECK(aclDestroyIntArray(repeats));
    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}
63
47
64
48
void ggml_cann_add (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
0 commit comments