Skip to content

Commit bec0ceb

Browse files
committed
ggml : add repeat impl for i64
1 parent 70e3d27 commit bec0ceb

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

examples/eval-callback/eval-callback.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne
5555
v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
5656
} else if (type == GGML_TYPE_F32) {
5757
v = *(float *) &data[i];
58+
} else if (type == GGML_TYPE_I64) {
59+
v = (float) *(int64_t *) &data[i];
5860
} else if (type == GGML_TYPE_I32) {
5961
v = (float) *(int32_t *) &data[i];
6062
} else if (type == GGML_TYPE_I16) {

ggml/src/ggml-cpu/ops.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2282,6 +2282,52 @@ static void ggml_compute_forward_repeat_f16(
22822282
}
22832283
}
22842284

2285+
static void ggml_compute_forward_repeat_i64(
2286+
const ggml_compute_params * params,
2287+
ggml_tensor * dst) {
2288+
2289+
const ggml_tensor * src0 = dst->src[0];
2290+
2291+
if (params->ith != 0) {
2292+
return;
2293+
}
2294+
2295+
GGML_ASSERT(ggml_can_repeat(src0, dst));
2296+
2297+
GGML_TENSOR_UNARY_OP_LOCALS
2298+
2299+
// guaranteed to be an integer due to the check in ggml_can_repeat
2300+
const int nr0 = (int)(ne0/ne00);
2301+
const int nr1 = (int)(ne1/ne01);
2302+
const int nr2 = (int)(ne2/ne02);
2303+
const int nr3 = (int)(ne3/ne03);
2304+
2305+
// TODO: support for transposed / permuted tensors
2306+
GGML_ASSERT(nb0 == sizeof(int64_t));
2307+
GGML_ASSERT(nb00 == sizeof(int64_t));
2308+
2309+
// TODO: maybe this is not optimal?
2310+
for (int i3 = 0; i3 < nr3; i3++) {
2311+
for (int k3 = 0; k3 < ne03; k3++) {
2312+
for (int i2 = 0; i2 < nr2; i2++) {
2313+
for (int k2 = 0; k2 < ne02; k2++) {
2314+
for (int i1 = 0; i1 < nr1; i1++) {
2315+
for (int k1 = 0; k1 < ne01; k1++) {
2316+
for (int i0 = 0; i0 < nr0; i0++) {
2317+
int64_t * y = (int64_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
2318+
int64_t * x = (int64_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
2319+
for (int i = 0; i < ne00; ++i) {
2320+
y[i] = x[i];
2321+
}
2322+
}
2323+
}
2324+
}
2325+
}
2326+
}
2327+
}
2328+
}
2329+
}
2330+
22852331
void ggml_compute_forward_repeat(
22862332
const ggml_compute_params * params,
22872333
ggml_tensor * dst) {
@@ -2300,6 +2346,10 @@ void ggml_compute_forward_repeat(
23002346
{
23012347
ggml_compute_forward_repeat_f32(params, dst);
23022348
} break;
2349+
case GGML_TYPE_I64:
2350+
{
2351+
ggml_compute_forward_repeat_i64(params, dst);
2352+
} break;
23032353
default:
23042354
{
23052355
GGML_ABORT("fatal error");

0 commit comments

Comments
 (0)