Skip to content

Commit 09552bc

Browse files
committed
fix norm
1 parent e0bbf3f commit 09552bc

File tree

3 files changed

+41
-84
lines changed

3 files changed

+41
-84
lines changed

ggml-cann/aclnn_ops.cpp

Lines changed: 13 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include "aclnn_ops.h"
22

3-
#include <aclnnop/aclnn_batch_norm.h>
3+
#include <aclnnop/aclnn_layer_norm.h>
44
#include <aclnnop/aclnn_cast.h>
55

66
#include <cmath>
@@ -368,77 +368,32 @@ void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
368368

369369
float eps;
370370
memcpy(&eps, dst->op_params, sizeof(float));
371-
float *weight_host, *bias_host;
372-
int64_t channel = dst->ne[2];
373-
374-
weight_host = new float[channel];
375-
bias_host = new float[channel];
376-
377-
for (int i = 0; i < channel; i++) {
378-
weight_host[i] = 1;
379-
bias_host[i] = 0;
380-
}
381-
382-
aclrtStream stream = ctx.stream();
383-
384-
// Input tensors.
385-
void *buffer, *acl_weight, *acl_bias, *acl_mean, *acl_invstd;
386-
ACL_CHECK(aclrtMalloc(&buffer, 4 * channel * sizeof(float),
387-
ACL_MEM_MALLOC_HUGE_FIRST));
388-
acl_weight = buffer;
389-
acl_bias = acl_weight + sizeof(float) * channel;
390-
acl_mean = acl_bias + sizeof(float) * channel;
391-
acl_invstd = acl_mean + sizeof(float) * channel;
392-
393-
// Set input params.
394-
ACL_CHECK(aclrtMemcpyAsync(acl_weight, channel, weight_host, channel,
395-
ACL_MEMCPY_HOST_TO_DEVICE, stream));
396-
ACL_CHECK(aclrtMemcpyAsync(acl_bias, channel, bias_host, channel,
397-
ACL_MEMCPY_HOST_TO_DEVICE, stream));
398-
delete[] weight_host;
399-
delete[] bias_host;
400-
401-
// Create input tensors.
402-
int64_t input_tensor_shape[] = {channel};
403-
size_t input_tensor_stride[] = {1};
404-
aclTensor* weight =
405-
create_acl_tensor(acl_weight, ACL_FLOAT, sizeof(float),
406-
input_tensor_shape, input_tensor_stride, 1);
407-
aclTensor* bias =
408-
create_acl_tensor(acl_bias, ACL_FLOAT, sizeof(float),
409-
input_tensor_shape, input_tensor_stride, 1);
410-
aclTensor* mean =
411-
create_acl_tensor(acl_mean, ACL_FLOAT, sizeof(float),
412-
input_tensor_shape, input_tensor_stride, 1);
413-
aclTensor* invstd =
414-
create_acl_tensor(acl_invstd, ACL_FLOAT, sizeof(float),
415-
input_tensor_shape, input_tensor_stride, 1);
416371

417372
uint64_t workspaceSize = 0;
418373
aclOpExecutor* executor;
419374
void* workspaceAddr = nullptr;
420375

421-
ACL_CHECK(aclnnBatchNormGetWorkspaceSize(
422-
acl_src, weight, bias, nullptr, nullptr, false, 0, eps, acl_dst, mean,
423-
invstd, &workspaceSize, &executor));
376+
std::vector<int64_t> normData = {dst->ne[0]};
377+
aclIntArray* norm = aclCreateIntArray(normData.data(), normData.size());
378+
ACL_CHECK(aclnnLayerNormGetWorkspaceSize(acl_src, norm, nullptr, nullptr, eps,
379+
acl_dst, nullptr, nullptr,
380+
&workspaceSize, &executor));
424381

425382
if (workspaceSize > 0) {
426383
ACL_CHECK(aclrtMalloc(&workspaceAddr, workspaceSize,
427384
ACL_MEM_MALLOC_HUGE_FIRST));
428385
}
429386

430-
ACL_CHECK(aclnnBatchNorm(workspaceAddr, workspaceSize, executor, stream));
387+
aclrtStream stream = ctx.stream();
431388

432-
ACL_CHECK(aclDestroyTensor(weight));
433-
ACL_CHECK(aclDestroyTensor(bias));
434-
ACL_CHECK(aclDestroyTensor(mean));
435-
ACL_CHECK(aclDestroyTensor(invstd));
389+
ACL_CHECK(aclnnLayerNorm(workspaceAddr, workspaceSize, executor, stream));
436390

437-
// TODO: optimize argsort kernel or free tmp buffers after stream sync.
438-
ACL_CHECK(aclrtSynchronizeStream(stream));
439-
ACL_CHECK(aclrtFree(buffer));
391+
ACL_CHECK(aclDestroyIntArray(norm));
392+
ACL_CHECK(aclDestroyTensor(acl_src));
393+
ACL_CHECK(aclDestroyTensor(acl_dst));
440394

441395
if (workspaceSize > 0) {
442396
ACL_CHECK(aclrtFree(workspaceAddr));
443397
}
444-
}
398+
}
399+

ggml-cann/bcast.cpp

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ aclDataType type_mapping(ggml_type type) {
2727
* Transform ggml_tensor to acl_tensor. Note that ggml_tensor dimension order
2828
* is reversed compared to acl_tensor.
2929
*
30-
* If bcast_ne and bcast_stride is nullptr, use ggml_tensor's ne and nb.
31-
* otherwise, use bcast_ne bcast_stride, which means tensor dims should be
30+
* If bcast_ne and bcast_nb are nullptr, use ggml_tensor's ne and nb.
31+
* Otherwise, use bcast_ne and bcast_nb, which means tensor dims should be
3232
* changed to satisfy the broadcast. @sa: get_bcast_shape.
3333
*/
3434
aclTensor* create_acl_tensor(const ggml_tensor* tensor, int64_t* bcast_ne,
35-
int64_t* bcast_stride, int64_t bcast_dims) {
35+
size_t* bcast_nb, int64_t bcast_dims, aclFormat format) {
3636
size_t size = ggml_nbytes(tensor);
3737
void* deviceAddr = nullptr;
3838

@@ -53,13 +53,13 @@ aclTensor* create_acl_tensor(const ggml_tensor* tensor, int64_t* bcast_ne,
5353
for (int i = 0; i < GGML_MAX_DIMS; i++) {
5454
acl_ne[i] = tensor->ne[i];
5555
// The step size of acl is in elements.
56-
acl_stride[i] = tensor->nb[i] / tensor->nb[0];
56+
acl_stride[i] = tensor->nb[i] / ggml_type_size(tensor->type);
5757
}
5858
} else {
5959
// With bcast
6060
for (int i = 0; i < bcast_dims; i++) {
6161
acl_ne[i] = bcast_ne[i];
62-
acl_stride[i] = bcast_stride[i] / tensor->nb[0];
62+
acl_stride[i] = bcast_nb[i] / ggml_type_size(tensor->type);
6363
}
6464
}
6565

@@ -69,13 +69,13 @@ aclTensor* create_acl_tensor(const ggml_tensor* tensor, int64_t* bcast_ne,
6969

7070
aclTensor* acl_tensor =
7171
aclCreateTensor(acl_ne, dims, type_mapping(tensor->type), acl_stride, 0,
72-
aclFormat::ACL_FORMAT_ND, acl_ne, dims, deviceAddr);
72+
format, acl_ne, dims, deviceAddr);
7373

7474
return acl_tensor;
7575
}
7676

7777
aclTensor* create_acl_tensor(void* data_ptr, aclDataType dtype, size_t type_size, int64_t* ne,
78-
size_t* nb, int64_t dims) {
78+
size_t* nb, int64_t dims, aclFormat format) {
7979

8080
int64_t tmp_ne[GGML_MAX_DIMS * 2];
8181
int64_t tmp_stride[GGML_MAX_DIMS * 2];
@@ -90,7 +90,7 @@ aclTensor* create_acl_tensor(void* data_ptr, aclDataType dtype, size_t type_size
9090

9191
aclTensor* acl_tensor =
9292
aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, 0,
93-
aclFormat::ACL_FORMAT_ND, tmp_ne, dims, data_ptr);
93+
format, tmp_ne, dims, data_ptr);
9494

9595
return acl_tensor;
9696
}
@@ -132,26 +132,26 @@ aclTensor* create_acl_tensor(void* data_ptr, aclDataType dtype, size_t type_size
132132
*/
133133
int64_t get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* src1,
134134
int64_t* bcast_ne_src0, int64_t* bcast_ne_src1,
135-
int64_t* bcast_stride_src0,
136-
int64_t* bcast_stride_src1) {
135+
size_t* bcast_nb_src0,
136+
size_t* bcast_nb_src1) {
137137
GGML_ASSERT(ggml_can_repeat(src1, src0));
138138
int bcast_dim_cnt = 0;
139139
for (int i = 0; i < GGML_MAX_DIMS; i++) {
140140
int64_t nr = src0->ne[i] / src1->ne[i];
141141
bcast_ne_src0[bcast_dim_cnt] = src0->ne[i] / nr;
142142
bcast_ne_src1[bcast_dim_cnt] = src1->ne[i];
143-
bcast_stride_src0[bcast_dim_cnt] = src0->nb[i];
144-
bcast_stride_src1[bcast_dim_cnt] = src1->nb[i];
143+
bcast_nb_src0[bcast_dim_cnt] = src0->nb[i];
144+
bcast_nb_src1[bcast_dim_cnt] = src1->nb[i];
145145
bcast_dim_cnt++;
146146
if (nr != 1) {
147147
// Need to add an extra dim.
148148
bcast_ne_src0[bcast_dim_cnt] = nr;
149149
bcast_ne_src1[bcast_dim_cnt] = 1;
150-
bcast_stride_src0[bcast_dim_cnt] =
151-
bcast_stride_src0[bcast_dim_cnt - 1] *
150+
bcast_nb_src0[bcast_dim_cnt] =
151+
bcast_nb_src0[bcast_dim_cnt - 1] *
152152
bcast_ne_src0[bcast_dim_cnt - 1];
153-
bcast_stride_src1[bcast_dim_cnt] =
154-
bcast_stride_src1[bcast_dim_cnt - 1] *
153+
bcast_nb_src1[bcast_dim_cnt] =
154+
bcast_nb_src1[bcast_dim_cnt - 1] *
155155
bcast_ne_src1[bcast_dim_cnt - 1];
156156
bcast_dim_cnt++;
157157
}

ggml-cann/bcast.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,30 @@ aclDataType type_mapping(ggml_type type);
1010

1111
aclTensor* create_acl_tensor(const ggml_tensor* tensor,
1212
int64_t* bcast_ne = nullptr,
13-
int64_t* bcast_stride = nullptr,
14-
int64_t bcast_dims = 0);
13+
size_t* bcast_nb = nullptr,
14+
int64_t bcast_dims = 0,
15+
aclFormat format = ACL_FORMAT_ND);
1516

16-
aclTensor* create_acl_tensor(void* data_ptr, aclDataType dtype, size_t type_size, int64_t* ne,
17-
size_t* nb, int64_t dims);
17+
aclTensor* create_acl_tensor(void* data_ptr, aclDataType dtype,
18+
size_t type_size, int64_t* ne, size_t* nb,
19+
int64_t dims, aclFormat format = ACL_FORMAT_ND);
1820

1921
bool need_bcast(const ggml_tensor* t0, const ggml_tensor* t1);
2022

2123
int64_t get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* src1,
2224
int64_t* bcast_ne_src0, int64_t* bcast_ne_src1,
23-
int64_t* bcast_stride_src0, int64_t* bcast_stride_src1);
25+
size_t* bcast_nb_src0, size_t* bcast_nb_src1);
2426

2527
// Bcast macro to avoid duplicate code.
2628
#define BCAST_SHAPE(src0, src1) \
2729
int64_t bcast_ne_##src0[GGML_MAX_DIMS * 2]; \
2830
int64_t bcast_ne_##src1[GGML_MAX_DIMS * 2]; \
29-
int64_t bcast_stride_##src0[GGML_MAX_DIMS * 2]; \
30-
int64_t bcast_stride_##src1[GGML_MAX_DIMS * 2]; \
31+
size_t bcast_nb_##src0[GGML_MAX_DIMS * 2]; \
32+
size_t bcast_nb_##src1[GGML_MAX_DIMS * 2]; \
3133
int64_t bcast_dims = \
3234
get_bcast_shape(src0, src1, bcast_ne_##src0, bcast_ne_##src1, \
33-
bcast_stride_##src0, bcast_stride_##src1);
35+
bcast_nb_##src0, bcast_nb_##src1);
3436

35-
#define BCAST_PARAM(src) bcast_ne_##src, bcast_stride_##src, bcast_dims
37+
#define BCAST_PARAM(src) bcast_ne_##src, bcast_nb_##src, bcast_dims
3638

37-
#endif //CANN_BCAST_H
39+
#endif // CANN_BCAST_H

0 commit comments

Comments
 (0)