Skip to content

Commit a8ea03d

Browse files
authored
ggml : add ggml_repeat_4d (#13824)
1 parent 05f6ac6 commit a8ea03d

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

ggml/include/ggml.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,15 @@ extern "C" {
935935
struct ggml_tensor * a,
936936
struct ggml_tensor * b);
937937

938+
// repeat a to the specified shape
939+
GGML_API struct ggml_tensor * ggml_repeat_4d(
940+
struct ggml_context * ctx,
941+
struct ggml_tensor * a,
942+
int64_t ne0,
943+
int64_t ne1,
944+
int64_t ne2,
945+
int64_t ne3);
946+
938947
// sums repetitions in a into shape of b
939948
GGML_API struct ggml_tensor * ggml_repeat_back(
940949
struct ggml_context * ctx,

ggml/src/ggml.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,6 +2312,26 @@ struct ggml_tensor * ggml_repeat(
23122312
return result;
23132313
}
23142314

2315+
struct ggml_tensor * ggml_repeat_4d(
2316+
struct ggml_context * ctx,
2317+
struct ggml_tensor * a,
2318+
int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) {
2319+
const bool can_repeat = ggml_is_empty(a) || (
2320+
(ne0 % a->ne[0] == 0) &&
2321+
(ne1 % a->ne[1] == 0) &&
2322+
(ne2 % a->ne[2] == 0) &&
2323+
(ne3 % a->ne[3] == 0)
2324+
);
2325+
GGML_ASSERT(can_repeat);
2326+
2327+
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3);
2328+
2329+
result->op = GGML_OP_REPEAT;
2330+
result->src[0] = a;
2331+
2332+
return result;
2333+
}
2334+
23152335
// ggml_repeat_back
23162336

23172337
struct ggml_tensor * ggml_repeat_back(

0 commit comments

Comments
 (0)