@@ -2859,15 +2859,27 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
    ACL_CHECK(aclDestroyTensor(acl_cos_tensor));
}

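+// Forward declarations of the CANN RoPE operator entry points used in
+// ggml_cann_rope() below; the signatures are assumed to match the
+// aclnnRotaryPositionEmbedding headers shipped with the CANN toolkit.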
+#ifdef __cplusplus
+extern "C" {
+#endif
+aclnnStatus aclnnRotaryPositionEmbeddingGetWorkspaceSize(
+    const aclTensor* x, const aclTensor* cos, const aclTensor* sin,
+    int64_t mode, const aclTensor* yOut, uint64_t* workspaceSize,
+    aclOpExecutor** executor);
+aclnnStatus aclnnRotaryPositionEmbedding(void* workspace,
+                                         uint64_t workspaceSize,
+                                         aclOpExecutor* executor,
+                                         aclrtStream stream);
+#ifdef __cplusplus
+}
+#endif
+
void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    // TODO: use ascendc
    // Only test with LLAMA model.
    ggml_tensor* src0 = dst->src[0];  // input
    ggml_tensor* src2 = dst->src[2];  // freq_factors

-    // TODO: with freq_factors
-    GGML_ASSERT(src2 == NULL);
-
    // param
    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
    // const int n_past = ((int32_t *) dst->op_params)[0];
@@ -2885,14 +2897,19 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

-    GGML_ASSERT(n_dims <= ne0);
+    // TODO: with freq_factors
+    GGML_ASSERT(src2 == NULL);
+
+    GGML_ASSERT(n_dims == ne0);
    GGML_ASSERT(n_dims % 2 == 0);

    // TODO: ext_factor != 0
    GGML_ASSERT(ext_factor == 0);
    // TODO: freq_scale != 1
    GGML_ASSERT(freq_scale == 1);

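+    // The aclnn-based RoPE path below is currently only exercised with F32
+    // activations.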
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
    const float theta_scale = powf(freq_base, -2.0f / n_dims);

    float corr_dims[2];
@@ -2924,177 +2941,30 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
                     theta_scale, is_neox);

-    // roll input
-    void* input_roll_buffer;
-    aclTensor* acl_minus_one_tensor;
-    void* minus_one_scale_buffer = nullptr;
-    ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));
-    ggml_cann_pool_alloc minus_one_scale_allocator(
-        ctx.pool(), sizeof(float_t) * src0->ne[0]);
-    if (!is_neox) {
-        // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...]
-        input_roll_buffer = roll_allocator.get();
-        int64_t input_roll_ne[4] = {2, src0->ne[1] * (src0->ne[0] / 2),
-                                    src0->ne[2], src0->ne[3]};
-        size_t input_roll_nb[GGML_MAX_DIMS];
-        input_roll_nb[0] = ggml_type_size(src0->type);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            input_roll_nb[i] = input_roll_nb[i - 1] * input_roll_ne[i - 1];
-        }
-        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
-            input_roll_buffer, ggml_cann_type_mapping(src0->type),
-            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
-            GGML_MAX_DIMS);
-        aclTensor* acl_input_tensor = ggml_cann_create_tensor(
-            src0->data, ggml_cann_type_mapping(src0->type),
-            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
-            GGML_MAX_DIMS);
-
-        int64_t shifts[] = {1};
-        int64_t dims[] = {3};
-        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
-
-        // init [-1, 1, -1, 1, ...]
-        minus_one_scale_buffer = minus_one_scale_allocator.get();
-
-        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
-        size_t minus_one_nb[GGML_MAX_DIMS];
-        minus_one_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
-        }
-        acl_minus_one_tensor = aclnn_ones(
-            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
-            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
-        int64_t dim = 3;
-        int64_t* index = new int64_t[src0->ne[0]];
-        for (int i = 0; i < src0->ne[0]; i++) {
-            index[i] = i / 2 * 2;
-        }
-        int64_t index_num = src0->ne[0];
-        float value = -1;
-        aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index,
-                                index_num, value);
-    } else {
-        // roll input: [q0,q1,q2,...] ->
-        // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1]
-        input_roll_buffer = roll_allocator.get();
-        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
-            input_roll_buffer, ggml_cann_type_mapping(src0->type),
-            ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
-        aclTensor* acl_input_tensor = ggml_cann_create_tensor(src0);
-
-        int64_t shifts[] = {src0->ne[0] / 2};
-        int64_t dims[] = {3};
-        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
-
-        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
-        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
-
-        // init [-1, -1, -1, 1, 1, 1, ...]
-        minus_one_scale_buffer = minus_one_scale_allocator.get();
-
-        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
-        size_t minus_one_nb[GGML_MAX_DIMS];
-        minus_one_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
-        }
-        acl_minus_one_tensor = aclnn_ones(
-            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
-            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
-        // -1 * first half
-        int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1};
-        size_t first_half_nb[GGML_MAX_DIMS];
-        first_half_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];
-        }
-        aclTensor* acl_first_half_tensor = ggml_cann_create_tensor(
-            minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne,
-            first_half_nb, GGML_MAX_DIMS);
-        bool inplace = true;
-        float scale = -1;
-        aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
-        ACL_CHECK(aclDestroyTensor(acl_first_half_tensor));
-    }
-
-    // TODO: n_dims < ne0
-    GGML_ASSERT(n_dims == src0->ne[0]);
-
-    // input * scale
-    ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(),
-                                                  ggml_nbytes(src0));
-    void* input_roll_mul_scale_buffer = roll_mul_scale_allocator.get();
-    size_t input_nb[GGML_MAX_DIMS];
-    input_nb[0] = ggml_type_size(src0->type);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        input_nb[i] = input_nb[i - 1] * src0->ne[i - 1];
-    }
-    aclTensor* acl_input_roll_mul_scale_tensor = ggml_cann_create_tensor(
-        input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type),
-        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
-    aclTensor* acl_input_roll_reshape_tensor = ggml_cann_create_tensor(
-        input_roll_buffer, ggml_cann_type_mapping(src0->type),
-        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
-
-    aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor,
-              acl_input_roll_mul_scale_tensor);
-
-    // output
-    aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
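+    // Apply RoPE with the aclnnRotaryPositionEmbedding operator. ggml's rope
+    // mode 0 is remapped to the operator's mode 1 below; the exact aclnn mode
+    // semantics are assumed to follow the CANN operator documentation.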
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+
+    void* workspaceAddr = nullptr;
+
+    int acl_mode = mode;
+    if (mode == 0) {
+        acl_mode = 1;
+    }
+
+    aclTensor* acl_x = ggml_cann_create_tensor(src0);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
-    void* output_fp32_buffer;
-    if (src0->type == GGML_TYPE_F32) {
-        aclnn_inplace_mul(ctx, acl_src0, acl_cos_reshape_tensor);
-        aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor,
-                          acl_sin_reshape_tensor);
-        aclnn_add(ctx, acl_src0, acl_input_roll_mul_scale_tensor, acl_dst);
-        // TODO: ne0 != n_dims in mode2
-    } else if (src0->type == GGML_TYPE_F16) {
-        size_t input_fp32_nb[GGML_MAX_DIMS];
-        input_fp32_nb[0] = sizeof(float_t);
-        for (int i = 1; i < GGML_MAX_DIMS; i++) {
-            input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];
-        }
-        ggml_cann_pool_alloc fp32_allocator1(
-            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
-        void* input_fp32_buffer1 = fp32_allocator1.get();
-        aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor(
-            input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne,
-            input_fp32_nb, GGML_MAX_DIMS);
-        ggml_cann_pool_alloc fp32_allocator2(
-            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
-        void* input_fp32_buffer2 = fp32_allocator2.get();
-        aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor(
-            input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne,
-            input_fp32_nb, GGML_MAX_DIMS);
-
-        ggml_cann_pool_alloc fp32_allocator(
-            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
-        output_fp32_buffer = fp32_allocator.get();
-        aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
-            output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne,
-            input_fp32_nb, GGML_MAX_DIMS);
-        aclnn_mul(ctx, acl_src0, acl_cos_reshape_tensor, input_fp32_tensor1);
-        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
-                  input_fp32_tensor2);
-        aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2,
-                  output_fp32_tensor);
-        aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);
-
-        ACL_CHECK(aclDestroyTensor(input_fp32_tensor1));
-        ACL_CHECK(aclDestroyTensor(input_fp32_tensor2));
-        ACL_CHECK(aclDestroyTensor(output_fp32_tensor));
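+    // Usual two-phase aclnn invocation: query the workspace size and obtain an
+    // executor, allocate the workspace from the backend pool if one is needed,
+    // then launch the operator on the context stream.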
+    ACL_CHECK(aclnnRotaryPositionEmbeddingGetWorkspaceSize(
+        acl_x, acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode, acl_dst, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
    }

-    ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
+    ACL_CHECK(aclnnRotaryPositionEmbedding(workspaceAddr, workspaceSize,
+                                           executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyTensor(acl_x));
    ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_src0));
+    ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}