@@ -2282,6 +2282,52 @@ static void ggml_compute_forward_repeat_f16(
2282
2282
}
2283
2283
}
2284
2284
2285
+ static void ggml_compute_forward_repeat_i64 (
2286
+ const ggml_compute_params * params,
2287
+ ggml_tensor * dst) {
2288
+
2289
+ const ggml_tensor * src0 = dst->src [0 ];
2290
+
2291
+ if (params->ith != 0 ) {
2292
+ return ;
2293
+ }
2294
+
2295
+ GGML_ASSERT (ggml_can_repeat (src0, dst));
2296
+
2297
+ GGML_TENSOR_UNARY_OP_LOCALS
2298
+
2299
+ // guaranteed to be an integer due to the check in ggml_can_repeat
2300
+ const int nr0 = (int )(ne0/ne00);
2301
+ const int nr1 = (int )(ne1/ne01);
2302
+ const int nr2 = (int )(ne2/ne02);
2303
+ const int nr3 = (int )(ne3/ne03);
2304
+
2305
+ // TODO: support for transposed / permuted tensors
2306
+ GGML_ASSERT (nb0 == sizeof (int64_t ));
2307
+ GGML_ASSERT (nb00 == sizeof (int64_t ));
2308
+
2309
+ // TODO: maybe this is not optimal?
2310
+ for (int i3 = 0 ; i3 < nr3; i3++) {
2311
+ for (int k3 = 0 ; k3 < ne03; k3++) {
2312
+ for (int i2 = 0 ; i2 < nr2; i2++) {
2313
+ for (int k2 = 0 ; k2 < ne02; k2++) {
2314
+ for (int i1 = 0 ; i1 < nr1; i1++) {
2315
+ for (int k1 = 0 ; k1 < ne01; k1++) {
2316
+ for (int i0 = 0 ; i0 < nr0; i0++) {
2317
+ int64_t * y = (int64_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
2318
+ int64_t * x = (int64_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
2319
+ for (int i = 0 ; i < ne00; ++i) {
2320
+ y[i] = x[i];
2321
+ }
2322
+ }
2323
+ }
2324
+ }
2325
+ }
2326
+ }
2327
+ }
2328
+ }
2329
+ }
2330
+
2285
2331
void ggml_compute_forward_repeat (
2286
2332
const ggml_compute_params * params,
2287
2333
ggml_tensor * dst) {
@@ -2300,6 +2346,10 @@ void ggml_compute_forward_repeat(
2300
2346
{
2301
2347
ggml_compute_forward_repeat_f32 (params, dst);
2302
2348
} break ;
2349
+ case GGML_TYPE_I64:
2350
+ {
2351
+ ggml_compute_forward_repeat_i64 (params, dst);
2352
+ } break ;
2303
2353
default :
2304
2354
{
2305
2355
GGML_ABORT (" fatal error" );
0 commit comments