Commit 90172d5

CANN: fix backend ops fail
1 parent 7932b93 commit 90172d5

File tree

3 files changed: +19 additions, -21 deletions

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 1 addition & 5 deletions
@@ -358,8 +358,6 @@ void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
 void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
-    GGML_ASSERT(src->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     float min;
     float max;
@@ -1090,8 +1088,6 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     float eps;
     memcpy(&eps, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(eps > 0.0f);
-
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
     void* workspaceAddr = nullptr;
@@ -3152,7 +3148,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // TODO: use ascendc
     // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0];  // input
-    ggml_tensor* src2 = dst->src[2];  // freq_factors
+    // ggml_tensor* src2 = dst->src[2];  // freq_factors, not used now.
 
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
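Note on the deletions: the F32-type asserts (here and in the templated activation wrappers in aclnn_ops.h below) and the eps > 0 assert hard-aborted the process instead of letting the backend decline the op; type and parameter screening is presumably meant to live in ggml_backend_cann_supports_op instead. Commenting out src2 makes explicit that the CANN ROPE path ignores freq_factors for now.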

ggml/src/ggml-cann/aclnn_ops.h

Lines changed: 0 additions & 6 deletions
@@ -535,9 +535,6 @@ template <aclnnStatus getWorkspaceSize(const aclTensor*, aclTensor*, uint64_t*,
 void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
 
-    GGML_ASSERT(src->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
     aclTensor* acl_src = ggml_cann_create_tensor(src);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);

@@ -566,9 +563,6 @@ template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
 void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
 
-    GGML_ASSERT(src->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-
     aclTensor* acl_src = ggml_cann_create_tensor(src);
     aclTensor* acl_dst = ggml_cann_create_tensor(dst);

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 18 additions & 10 deletions
@@ -1458,11 +1458,6 @@ static void ggml_backend_cann_free(ggml_backend_t backend) {
     ACL_CHECK(aclrtSynchronizeDevice());
     ACL_CHECK(aclrtResetDevice(cann_ctx->device));
 
-    // finalize when last backend freed.
-    if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
-        ACL_CHECK(aclFinalize());
-    }
-
     delete cann_ctx;
     delete backend;
 }
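Presumably the issue: backends are not guaranteed to be freed in ascending device order, so keying aclFinalize() on the last device index could tear down the ACL runtime while other CANN backends or their buffers are still live. Dropping the call leaves runtime teardown to process exit.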
@@ -1688,11 +1683,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         }
         case GGML_OP_MUL_MAT: {
             switch (op->src[0]->type) {
-                case GGML_TYPE_Q8_0:
                 case GGML_TYPE_F16:
                 case GGML_TYPE_F32:
-                case GGML_TYPE_Q4_0:
                     return true;
+                case GGML_TYPE_Q8_0:
+                case GGML_TYPE_Q4_0:
+                    // only support contiguous for quantized types.
+                    return ggml_is_contiguous(op->src[0]) &&
+                           ggml_is_contiguous(op->src[1]);
                 default:
                     return false;
             }
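For reference, the contiguity requirement matches ggml's stride check: a tensor is contiguous when its byte strides nb[] are exactly the cumulative products of its element counts ne[], with no gaps. A simplified sketch of that test (not this commit's code; the real check is ggml_is_contiguous):

    // Simplified sketch: nb[] are byte strides, ne[] element counts.
    // Views and permutations fail this, so quantized mul_mat on such
    // tensors is now declared unsupported instead of failing on device.
    static bool is_contiguous_sketch(const ggml_tensor * t) {
        return t->nb[0] == ggml_type_size(t->type) &&
               t->nb[1] == t->nb[0] * t->ne[0] / ggml_blck_size(t->type) &&
               t->nb[2] == t->nb[1] * t->ne[1] &&
               t->nb[3] == t->nb[2] * t->ne[2];
    }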
@@ -1738,13 +1736,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         }
         case GGML_OP_ROPE: {
             // TODO: with ops-test v == 1
-            float * ext_factor = (float*)((int32_t*)op->op_params + 7);
+            float ext_factor = 0.0f;
+            memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
             // TODO: n_dims <= ne0
             if (op->src[0]->ne[0] != op->op_params[1]) {
                 return false;
             }
             // TODO: ext_factor != 0
-            if (*ext_factor != 0) {
+            if (ext_factor != 0) {
                 return false;
             }
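Beyond tidiness, this is a correctness fix: casting op_params to float * and dereferencing it type-puns through an incompatible pointer, which is undefined behavior under C++ strict-aliasing rules. The new (const float *) op->op_params + 7 addresses the same bytes as the old int32_t arithmetic, both element types being 4 bytes wide. A standalone illustration of the idiom (hypothetical helper, assuming the same 32-bit slot layout):

    #include <cstdint>
    #include <cstring>

    // Read the float stored in 4-byte op_params slot 7. The memcpy form
    // is well-defined and compiles to a plain load, whereas
    // *((const float *) params + 7) violates strict aliasing.
    static float read_ext_factor(const int32_t * params) {
        float v;
        std::memcpy(&v, params + 7, sizeof(float));
        return v;
    }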

@@ -1766,6 +1765,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             }
             return true;
         }
+        case GGML_OP_POOL_2D: {
+            const int32_t * opts = (const int32_t *) op->op_params;
+            const int k0 = opts[1];
+            const int k1 = opts[2];
+            const int p0 = opts[5];
+            const int p1 = opts[6];
+            // value of paddingH should be at most half of kernelH
+            // value of paddingW should be at most half of kernelW
+            return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
+        }
         case GGML_OP_DUP:
         case GGML_OP_IM2COL:
         case GGML_OP_CONCAT:
@@ -1785,7 +1794,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_CLAMP:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
-        case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_ARGSORT:
         case GGML_OP_ACC:
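With the dedicated GGML_OP_POOL_2D case added above, the op leaves this unconditional list. The new guard reflects the CANN pooling constraint that padding may not exceed half the kernel extent on each axis; a toy mirror of it (hypothetical helper, not this commit's code):

    // Hypothetical mirror of the new POOL_2D guard: padding must be
    // at most half the kernel size on each axis.
    static bool pool2d_padding_ok(int k0, int k1, int p0, int p1) {
        return p0 <= k0 / 2 && p1 <= k1 / 2;
    }

For example, a 3x3 kernel permits padding 0 or 1 per side; padding 2 makes supports_op return false, so the op is scheduled on another backend instead of failing on the device.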
