Commit 4dea1d5

Author: noemotiovon committed
[cann] rope optimization
1 parent cce5a90 commit 4dea1d5

File tree: 2 files changed, +91 additions, -178 deletions

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 39 additions & 169 deletions
@@ -2859,15 +2859,27 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     ACL_CHECK(aclDestroyTensor(acl_cos_tensor));
 }
 
+#ifdef __cplusplus
+extern "C" {
+#endif
+aclnnStatus aclnnRotaryPositionEmbeddingGetWorkspaceSize(
+    const aclTensor* x, const aclTensor* cos, const aclTensor* sin,
+    int64_t mode, const aclTensor* yOut, uint64_t* workspaceSize,
+    aclOpExecutor** executor);
+aclnnStatus aclnnRotaryPositionEmbedding(void* workspace,
+                                         uint64_t workspaceSize,
+                                         aclOpExecutor* executor,
+                                         aclrtStream stream);
+#ifdef __cplusplus
+}
+#endif
+
 void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // TODO: use ascendc
     // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src2 = dst->src[2];  // freq_factors
 
-    // TODO: with freq_factors
-    GGML_ASSERT(src2 == NULL);
-
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
     // const int n_past = ((int32_t *) dst->op_params)[0];
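The block above forward-declares two symbols from the CANN aclnn operator library; the extern "C" guard gives them C linkage so they resolve against the library when this C++ translation unit is linked. Like other aclnn operators, the op runs in two phases: a workspace-size query that also produces an executor, then the launch on a stream. A minimal sketch of that calling pattern, using the names declared above (the tensor handles acl_x, acl_cos, acl_sin, acl_dst are placeholders; the commit's actual call site appears in the last hunk of this file):

// Phase 1: query scratch-memory requirements and obtain an executor.
uint64_t workspaceSize = 0;
aclOpExecutor* executor = nullptr;
ACL_CHECK(aclnnRotaryPositionEmbeddingGetWorkspaceSize(
    acl_x, acl_cos, acl_sin, /*mode=*/1, acl_dst, &workspaceSize, &executor));

// Phase 2: allocate pool-backed scratch (kept alive past the launch),
// then run the fused kernel on the context's stream.
void* workspaceAddr = nullptr;
ggml_cann_pool_alloc workspace_allocator(ctx.pool());
if (workspaceSize > 0) {
    workspaceAddr = workspace_allocator.alloc(workspaceSize);
}
ACL_CHECK(aclnnRotaryPositionEmbedding(workspaceAddr, workspaceSize,
                                       executor, ctx.stream()));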
@@ -2885,14 +2897,19 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     memcpy(&beta_fast, (int32_t*)dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow, (int32_t*)dst->op_params + 10, sizeof(float));
 
-    GGML_ASSERT(n_dims <= ne0);
+    // TODO: with freq_factors
+    GGML_ASSERT(src2 == NULL);
+
+    GGML_ASSERT(n_dims == ne0);
     GGML_ASSERT(n_dims % 2 == 0);
 
     // TODO: ext_factor != 0
     GGML_ASSERT(ext_factor == 0);
     // TODO: freq_scale != 1
     GGML_ASSERT(freq_scale == 1);
 
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
     float corr_dims[2];
@@ -2924,177 +2941,30 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
                      theta_scale, is_neox);
 
-    // roll input
-    void* input_roll_buffer;
-    aclTensor* acl_minus_one_tensor;
-    void* minus_one_scale_buffer = nullptr;
-    ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));
-    ggml_cann_pool_alloc minus_one_scale_allocator(
-        ctx.pool(), sizeof(float_t) * src0->ne[0]);
-    if (!is_neox) {
-        // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...]
-        input_roll_buffer = roll_allocator.get();
-        int64_t input_roll_ne[4] = {2, src0->ne[1] * (src0->ne[0] / 2),
-                                    src0->ne[2], src0->ne[3]};
-        size_t input_roll_nb[GGML_MAX_DIMS];
-        input_roll_nb[0] = ggml_type_size(src0->type);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            input_roll_nb[i] = input_roll_nb[i - 1] * input_roll_ne[i - 1];
-        }
-        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
-            input_roll_buffer, ggml_cann_type_mapping(src0->type),
-            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
-            GGML_MAX_DIMS);
-        aclTensor* acl_input_tensor = ggml_cann_create_tensor(
-            src0->data, ggml_cann_type_mapping(src0->type),
-            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
-            GGML_MAX_DIMS);
-
-        int64_t shifts[] = {1};
-        int64_t dims[] = {3};
-        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
-
-        // init [-1, 1, -1, 1, ...]
-        minus_one_scale_buffer = minus_one_scale_allocator.get();
-
-        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
-        size_t minus_one_nb[GGML_MAX_DIMS];
-        minus_one_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
-        }
-        acl_minus_one_tensor = aclnn_ones(
-            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
-            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
-        int64_t dim = 3;
-        int64_t* index = new int64_t[src0->ne[0]];
-        for (int i = 0; i < src0->ne[0]; i++) {
-            index[i] = i / 2 * 2;
-        }
-        int64_t index_num = src0->ne[0];
-        float value = -1;
-        aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index,
-                                index_num, value);
-    } else {
-        // roll input: [q0,q1,q2,...] ->
-        // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1]
-        input_roll_buffer = roll_allocator.get();
-        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
-            input_roll_buffer, ggml_cann_type_mapping(src0->type),
-            ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
-        aclTensor* acl_input_tensor = ggml_cann_create_tensor(src0);
-
-        int64_t shifts[] = {src0->ne[0] / 2};
-        int64_t dims[] = {3};
-        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
-
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
-
-        // init [-1, -1, -1, 1, 1, 1, ...]
-        minus_one_scale_buffer = minus_one_scale_allocator.get();
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
 
-        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
-        size_t minus_one_nb[GGML_MAX_DIMS];
-        minus_one_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
-        }
-        acl_minus_one_tensor = aclnn_ones(
-            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
-            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
-        // -1 * first half
-        int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1};
-        size_t first_half_nb[GGML_MAX_DIMS];
-        first_half_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];
-        }
-        aclTensor* acl_first_half_tensor = ggml_cann_create_tensor(
-            minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne,
-            first_half_nb, GGML_MAX_DIMS);
-        bool inplace = true;
-        float scale = -1;
-        aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
-        ACL_CHECK(aclDestroyTensor(acl_first_half_tensor));
-    }
-
-    // TODO: n_dims < ne0
-    GGML_ASSERT(n_dims == src0->ne[0]);
-
-    // input * scale
-    ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(),
-                                                  ggml_nbytes(src0));
-    void* input_roll_mul_scale_buffer = roll_mul_scale_allocator.get();
-    size_t input_nb[GGML_MAX_DIMS];
-    input_nb[0] = ggml_type_size(src0->type);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        input_nb[i] = input_nb[i - 1] * src0->ne[i - 1];
-    }
-    aclTensor* acl_input_roll_mul_scale_tensor = ggml_cann_create_tensor(
-        input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type),
-        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
-    aclTensor* acl_input_roll_reshape_tensor = ggml_cann_create_tensor(
-        input_roll_buffer, ggml_cann_type_mapping(src0->type),
-        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
+    void* workspaceAddr = nullptr;
 
-    aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor,
-              acl_input_roll_mul_scale_tensor);
+    int acl_mode = mode;
+    if (mode == 0) {
+        acl_mode = 1;
+    }
 
-    // output
-    aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+    aclTensor* acl_x = ggml_cann_create_tensor(src0);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);
-    void* output_fp32_buffer;
-    if (src0->type == GGML_TYPE_F32) {
-        aclnn_inplace_mul(ctx, acl_src0, acl_cos_reshape_tensor);
-        aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor,
-                          acl_sin_reshape_tensor);
-        aclnn_add(ctx, acl_src0, acl_input_roll_mul_scale_tensor, acl_dst);
-        // TODO: ne0 != n_dims in mode2
-    } else if (src0->type == GGML_TYPE_F16) {
-        size_t input_fp32_nb[GGML_MAX_DIMS];
-        input_fp32_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];
-        }
-        ggml_cann_pool_alloc fp32_allocator1(
-            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
-        void* input_fp32_buffer1 = fp32_allocator1.get();
-        aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor(
-            input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne,
-            input_fp32_nb, GGML_MAX_DIMS);
-        ggml_cann_pool_alloc fp32_allocator2(
-            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
-        void* input_fp32_buffer2 = fp32_allocator2.get();
-        aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor(
-            input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne,
-            input_fp32_nb, GGML_MAX_DIMS);
-
-        ggml_cann_pool_alloc fp32_allocator(
-            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
-        output_fp32_buffer = fp32_allocator.get();
-        aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
-            output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne,
-            input_fp32_nb, GGML_MAX_DIMS);
-        aclnn_mul(ctx, acl_src0, acl_cos_reshape_tensor, input_fp32_tensor1);
-        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
-                  input_fp32_tensor2);
-        aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2,
-                  output_fp32_tensor);
-        aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);
-
-        ACL_CHECK(aclDestroyTensor(input_fp32_tensor1));
-        ACL_CHECK(aclDestroyTensor(input_fp32_tensor2));
-        ACL_CHECK(aclDestroyTensor(output_fp32_tensor));
+    ACL_CHECK(aclnnRotaryPositionEmbeddingGetWorkspaceSize(
+        acl_x, acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, acl_dst, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
     }
 
-    ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
+    ACL_CHECK(aclnnRotaryPositionEmbedding(workspaceAddr, workspaceSize,
+                                           executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyTensor(acl_x));
     ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_src0));
+    ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
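Net effect of this hunk: the hand-assembled pipeline (roll the input, build a ±1 sign mask, multiply and add, with an extra FP32 round-trip for F16 inputs) collapses into a single fused aclnnRotaryPositionEmbedding launch, which is where most of the file's 169 deleted lines come from. The deleted code computed dst = src0*cos + roll(src0)*sign*sin; in the neox branch (roll by ne0/2, sign -1 on the first half) that is the standard rotate-half formulation. A scalar reference for one row of d floats, matching the deleted neox composition (illustrative only; the commit does not document which data layout each acl_mode value selects on the device):

// Rotate-half RoPE over one row of d floats. cosv/sinv are the full-width
// per-element caches that aclnn_cache_init precomputes (in the neox layout
// the two halves repeat the same angles, so cosv[i] == cosv[i + h]).
static void rope_rotate_half_ref(float* y, const float* x, const float* cosv,
                                 const float* sinv, int64_t d) {
    const int64_t h = d / 2;
    for (int64_t i = 0; i < h; i++) {
        y[i]     = x[i] * cosv[i] - x[i + h] * sinv[i];          // sign -1 half
        y[i + h] = x[i + h] * cosv[i + h] + x[i] * sinv[i + h];  // sign +1 half
    }
}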

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 52 additions & 9 deletions
@@ -1669,12 +1669,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         }
         case GGML_OP_MUL_MAT: {
             switch (op->src[0]->type) {
+                case GGML_TYPE_Q8_0:
+                    if (op->src[0]->ne[0] <= QK8_0) {
+                        return false;
+                    }
                 case GGML_TYPE_F16:
                 case GGML_TYPE_F32:
-                case GGML_TYPE_Q8_0:
-                    // TODO: fix me
-                    // Current groupsize should not be greater than k-1 in
-                    // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
                 case GGML_TYPE_Q4_0:
                     return true;
                 default:
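The reordered Q8_0 case relies on a deliberate fallthrough: when the K dimension exceeds one quantization block, control drops into the `return true` shared with F16/F32/Q4_0. An equivalent restatement without the fallthrough (an illustrative helper, not part of the commit; QK8_0 is ggml's Q8_0 block size of 32 elements):

// Same predicate as the switch above, written without fallthrough:
// Q8_0 mul_mat is accepted only when K spans more than one quant block.
static bool cann_mul_mat_type_ok(ggml_type type, int64_t k) {
    switch (type) {
        case GGML_TYPE_Q8_0: return k > QK8_0;
        case GGML_TYPE_F16:
        case GGML_TYPE_F32:
        case GGML_TYPE_Q4_0: return true;
        default:             return false;
    }
}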
@@ -1706,9 +1706,56 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                     return false;
             }
         }
+        case GGML_OP_IM2COL: {
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F16:
+                    return true;
+                default:
+                    return false;
+            }
+            switch (op->src[1]->type) {
+                case GGML_TYPE_F32:
+                    return true;
+                default:
+                    return false;
+            }
+        }
+        case GGML_OP_CONT: {
+            switch (op->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                    return true;
+                default:
+                    return false;
+            }
+        }
+        case GGML_OP_ROPE: {
+            float freq_scale;
+            memcpy(&freq_scale, (int32_t*)op->op_params + 6, sizeof(float));
+            if (op->src[2] != NULL) {
+                return false;
+            }
+            if (op->src[0]->ne[0] != op->op_params[1]) {
+                return false;
+            }
+
+            if (op->op_params[7] != 0) {
+                return false;
+            }
+            if (freq_scale != 1) {
+                return false;
+            }
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
+                    return true;
+                default:
+                    return false;
+            }
+        }
+        case GGML_OP_CONCAT:
+        case GGML_OP_UPSCALE:
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
-        case GGML_OP_CONCAT:
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
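The new GGML_OP_ROPE gate mirrors the asserts added in ggml_cann_rope: no freq_factors (src[2]), full-width rotation (op_params[1] holds n_dims and must equal ne[0]), a zero ext_factor (slot 7, compared here as a raw int word, which is bit-identical to 0.0f), a freq_scale of exactly 1 (the float in slot 6), and F32 input only. A compact restatement of the same checks (hypothetical helper; the slot meanings are inferred from the memcpy offsets this commit uses):

// Decode the rope op_params the gate depends on and apply the same checks.
static bool cann_rope_supported(const ggml_tensor* op) {
    const int32_t n_dims = op->op_params[1];
    float freq_scale, ext_factor;
    memcpy(&freq_scale, (const int32_t*)op->op_params + 6, sizeof(float));
    memcpy(&ext_factor, (const int32_t*)op->op_params + 7, sizeof(float));

    return op->src[2] == NULL                  // freq_factors unsupported
        && op->src[0]->ne[0] == n_dims         // full-width rotation only
        && ext_factor == 0.0f
        && freq_scale == 1.0f
        && op->src[0]->type == GGML_TYPE_F32;
}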
@@ -1722,17 +1769,13 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_CLAMP:
-        case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
-        case GGML_OP_ROPE:
-        case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
