
Commit bfaa676

ggml : improve ggml_is_contiguous logic (ggml-org#7856)

* ggml : improve ggml_is_contiguous logic (ggml-ci)
* ggml : support more contiguous cases (ggml-ci)

1 parent 704a35b · commit bfaa676


ggml.c (35 additions, 40 deletions)
@@ -3212,35 +3212,42 @@ GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1];
 }
 
-GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
-    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) {
+    size_t next_nb = ggml_type_size(tensor->type);
+    if (tensor->ne[0] != ggml_blck_size(tensor->type) && tensor->nb[0] != next_nb) {
+        return false;
+    }
+    next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        if (tensor->ne[i] != 1) {
+            if (i > n) {
+                if (tensor->nb[i] != next_nb) {
+                    return false;
+                }
+                next_nb *= tensor->ne[i];
+            } else {
+                // this dimension does not need to be contiguous
+                next_nb = tensor->ne[i]*tensor->nb[i];
+            }
+        }
+    }
+    return true;
+}
 
-    return
-        tensor->nb[0] == ggml_type_size(tensor->type) &&
-        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
-        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
-        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous_0(tensor);
 }
 
 GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
-    return ggml_is_contiguous(tensor);
+    return ggml_is_contiguous_n(tensor, 0);
 }
 
 GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
-    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
-
-    return
-        tensor->nb[0] == ggml_type_size(tensor->type) &&
-        tensor->nb[2] == tensor->nb[1]*tensor->ne[1] &&
-        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+    return ggml_is_contiguous_n(tensor, 1);
 }
 
 GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
-    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
-
-    return
-        tensor->nb[0] == ggml_type_size(tensor->type) &&
-        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+    return ggml_is_contiguous_n(tensor, 2);
 }
 
 GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
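
Note on the new helper (editor's sketch, not part of the commit): ggml_is_contiguous_n(tensor, n) generalizes the old hand-written stride checks. Dimension 0 must be packed (nb[0] equal to the type size, except when ne[0] is exactly one block of a quantized type), dimensions 1..n may have arbitrary strides, dimensions above n must follow row-major strides, and size-1 dimensions are skipped outright, which is what lets more tensor layouts count as contiguous than before. A minimal, hypothetical standalone program illustrating the resulting semantics (assumes ggml.h from this repo is on the include path):

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 4x4 f32 tensor: ne = [4,4,1,1], nb = [4,16,64,64]
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4);

    // view of the first two columns of every row: each row is packed
    // (nb[0] == sizeof(float)), but consecutive rows are 16 bytes apart
    // instead of 8, so the view is not fully contiguous
    struct ggml_tensor * v = ggml_view_2d(ctx, a, 2, 4, a->nb[1], 0);

    printf("contiguous:   %d\n", ggml_is_contiguous(v));   // 0
    printf("contiguous_1: %d\n", ggml_is_contiguous_1(v)); // 1: dims > 1 are contiguous
    printf("contiguous_2: %d\n", ggml_is_contiguous_2(v)); // 1: dims > 2 are contiguous

    ggml_free(ctx);
    return 0;
}

In other words, ggml_is_contiguous_1 is the right predicate for kernels that process one row at a time and only need each row to be packed.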
@@ -3272,20 +3279,20 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
-        (t0->ne[0] == t1->ne[0] ) &&
-        (t0->ne[1] == t1->ne[1] ) &&
-        (t0->ne[2] == t1->ne[2] ) &&
-        (t0->ne[3] == t1->ne[3] );
+        (t0->ne[0] == t1->ne[0]) &&
+        (t0->ne[1] == t1->ne[1]) &&
+        (t0->ne[2] == t1->ne[2]) &&
+        (t0->ne[3] == t1->ne[3]);
 }
 
 bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
-        (t0->nb[0] == t1->nb[0] ) &&
-        (t0->nb[1] == t1->nb[1] ) &&
-        (t0->nb[2] == t1->nb[2] ) &&
-        (t0->nb[3] == t1->nb[3] );
+        (t0->nb[0] == t1->nb[0]) &&
+        (t0->nb[1] == t1->nb[1]) &&
+        (t0->nb[2] == t1->nb[2]) &&
+        (t0->nb[3] == t1->nb[3]);
 }
 
 // check if t1 can be represented as a repeatition of t0
@@ -4078,32 +4085,26 @@ float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i) {
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
                 return ((int8_t *)(tensor->data))[i];
             }
         case GGML_TYPE_I16:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
                 return ((int16_t *)(tensor->data))[i];
             }
         case GGML_TYPE_I32:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
                 return ((int32_t *)(tensor->data))[i];
             }
         case GGML_TYPE_F16:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
                 return GGML_FP16_TO_FP32(((ggml_fp16_t *)(tensor->data))[i]);
             }
         case GGML_TYPE_BF16:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
                 return GGML_BF16_TO_FP32(((ggml_bf16_t *)(tensor->data))[i]);
             }
         case GGML_TYPE_F32:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 return ((float *)(tensor->data))[i];
             }
         default:
@@ -4125,32 +4126,26 @@ void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value) {
     switch (tensor->type) {
         case GGML_TYPE_I8:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int8_t));
                 ((int8_t *)(tensor->data))[i] = value;
             } break;
         case GGML_TYPE_I16:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int16_t));
                 ((int16_t *)(tensor->data))[i] = value;
             } break;
         case GGML_TYPE_I32:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(int32_t));
                 ((int32_t *)(tensor->data))[i] = value;
             } break;
         case GGML_TYPE_F16:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_fp16_t));
                 ((ggml_fp16_t *)(tensor->data))[i] = GGML_FP32_TO_FP16(value);
             } break;
         case GGML_TYPE_BF16:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(ggml_bf16_t));
                 ((ggml_bf16_t *)(tensor->data))[i] = GGML_FP32_TO_BF16(value);
             } break;
         case GGML_TYPE_F32:
             {
-                GGML_ASSERT(tensor->nb[0] == sizeof(float));
                 ((float *)(tensor->data))[i] = value;
             } break;
         default:
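
Why the asserts can go (editor's note, hedged): in this version of ggml.c the 1-d getters and setters first divert non-contiguous tensors to the strided _nd accessors, and ggml_is_contiguous() now implies nb[0] == ggml_type_size(type) for these plain types, so the per-case GGML_ASSERT on nb[0] was redundant. A sketch of the getter-side pattern (the helper name get_f32_1d_sketch is hypothetical; it mirrors, but is not a verbatim copy of, the surrounding code and assumes ggml.h):

#include "ggml.h"

// hypothetical helper illustrating the dispatch in ggml_get_f32_1d
static float get_f32_1d_sketch(const struct ggml_tensor * t, int i) {
    if (!ggml_is_contiguous(t)) {
        // strided path: unravel the flat index and read per-dimension
        int64_t id[4] = { 0, 0, 0, 0 };
        ggml_unravel_index(t, i, &id[0], &id[1], &id[2], &id[3]);
        return ggml_get_f32_nd(t, (int) id[0], (int) id[1], (int) id[2], (int) id[3]);
    }
    // contiguous path: ggml_is_contiguous() already guarantees a packed
    // dim 0, so flat indexing into t->data is safe without asserting nb[0]
    switch (t->type) {
        case GGML_TYPE_F32: return ((float *)(t->data))[i];
        default:            GGML_ASSERT(false); return 0.0f;
    }
}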
@@ -7343,7 +7338,7 @@ struct ggml_tensor * ggml_add_rel_pos_inplace(
     return ggml_add_rel_pos_impl(ctx, a, pw, ph, true);
 }
 
-// gmml_unary
+// ggml_unary
 
 static struct ggml_tensor * ggml_unary_impl(
         struct ggml_context * ctx,

Comments (0)