Skip to content

Commit 7ca81e9

Browse files
committed
mtl : add reshape and transpose handling
1 parent 1213af7 commit 7ca81e9

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

examples/mtl/mtl.m

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ int llama_mtl_eval(
258258

259259
switch (gf->nodes[i]->op) {
260260
case GGML_OP_RESHAPE:
261+
case GGML_OP_TRANSPOSE:
261262
{
262263
// noop
263264
} break;

ggml.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15011,15 +15011,19 @@ struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context **
1501115011

1501215012
// create the tensor
1501315013
// "view" operations are handled differently
15014+
// TODO: handle inplac ops - currentl a copy is always made
1501415015

1501515016
struct ggml_tensor * tensor = NULL;
1501615017

1501715018
switch (eop) {
1501815019
// TODO: implement other view ops
1501915020
case GGML_OP_RESHAPE:
1502015021
{
15021-
// TODO: implement other dims
15022-
tensor = ggml_reshape_3d(*ctx_eval, args[0], ne[0], ne[1], ne[2]);
15022+
tensor = ggml_reshape_4d(*ctx_eval, args[0], ne[0], ne[1], ne[2], ne[3]);
15023+
} break;
15024+
case GGML_OP_TRANSPOSE:
15025+
{
15026+
tensor = ggml_transpose(*ctx_eval, args[0]);
1502315027
} break;
1502415028
default:
1502515029
{

llama.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,15 +1279,14 @@ static bool llama_eval_internal(
12791279
ggml_set_name(Qcur, "Qcur");
12801280
ggml_set_name(Kcur, "Kcur");
12811281

1282-
// TODO: TMP !!!!
1283-
if (il == 0) {
1284-
ggml_set_name(Qcur, "mtl-check");
1285-
}
1286-
12871282
// store key and value to memory
12881283
{
12891284
// compute the transposed [N, n_embd] V matrix
12901285
struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
1286+
// TODO: TMP !!!!
1287+
if (il == 0) {
1288+
ggml_set_name(Vcur, "mtl-check");
1289+
}
12911290

12921291
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
12931292
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,

0 commit comments

Comments
 (0)