Commit b9110bb

slaren authored and arthw committed
llama : fix buffer checks for mamba and rwk (ggml-org#10111)
* llama : fix buffer checks for mamba and rwk
* llama : fix missing worst case flag during reserve
* cuda : fix supports_op for norm
* disable sched SET_CAUSE
1 parent 83ad0ab commit b9110bb

File tree: 4 files changed (+35, -12 lines)

ggml/src/ggml-backend.cpp

Lines changed: 1 addition & 1 deletion

@@ -1508,7 +1508,7 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co
     return -1;
 }
 
-#if 1
+#if 0
 #define GGML_SCHED_MAX_SPLITS_DEBUG 4096
 static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML_SCHED_MAX_SPLIT_INPUTS][128]; // debug only
 #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
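
Commentary (not part of the commit): flipping the guard from #if 1 to #if 0 compiles the scheduler's debug instrumentation out entirely. SET_CAUSE records, per graph node, why the scheduler assigned it to a particular backend; it is debug-only, as the comment says. A minimal sketch of the surrounding pattern, assuming the #else branch of this block defines the macro as a no-op, which is the usual arrangement in ggml-backend.cpp:

    #if 0   // debug path: record a human-readable cause string per node (now disabled)
    #define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
    #else   // release path: the macro expands to nothing, no cause strings are recorded
    #define SET_CAUSE(node, ...)
    #endif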

ggml/src/ggml-cuda.cu

Lines changed: 4 additions & 2 deletions

@@ -3107,18 +3107,20 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             }
             return false;
         } break;
+        case GGML_OP_NORM:
+        case GGML_OP_RMS_NORM:
+            return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
+            break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
-        case GGML_OP_NORM:
         case GGML_OP_ADD:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-        case GGML_OP_RMS_NORM:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
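
Commentary (not part of the commit): GGML_OP_NORM and GGML_OP_RMS_NORM move out of the unconditional "always supported" list; the CUDA backend now claims support only when the source row is contiguous and its length is a multiple of WARP_SIZE, so other shapes are left for the scheduler to place on another backend instead of being accepted and then failing in the kernel. A minimal sketch mirroring the condition, assuming WARP_SIZE is 32 as on current NVIDIA GPUs:

    // Sketch only, not the backend code itself.
    static bool cuda_norm_supported_sketch(const struct ggml_tensor * op) {
        const int warp_size = 32; // assumption: matches the CUDA backend's WARP_SIZE
        return ggml_is_contiguous(op->src[0]) && op->ne[0] % warp_size == 0;
    }
    // e.g. a contiguous RMS_NORM over rows of 4096 elements is reported as supported
    // (4096 % 32 == 0), while rows of 100 elements are not (100 % 32 != 0).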

ggml/src/ggml.c

Lines changed: 1 addition & 0 deletions

@@ -7272,6 +7272,7 @@ struct ggml_tensor * ggml_ssm_conv(
     const int64_t n_s = sx->ne[2];
 
     // TODO: maybe support other strides than 1?
+    // FIXME: this is always true?
     GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t);
     GGML_ASSERT(sx->ne[1] == d_inner);
     GGML_ASSERT(n_t >= 0);
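
Commentary (not part of the commit): the FIXME points out that this assert looks tautological, since n_t appears to be derived from sx->ne[0] and d_conv just above, so sx->ne[0] == d_conv - 1 + n_t holds by construction. The shape relation itself is still the useful thing to know when building inputs: a stride-1 causal conv that yields n_t outputs per row needs d_conv - 1 carried-over columns plus n_t new ones. A minimal sketch with assumed example sizes (not taken from the commit):

    #include "ggml.h"

    // Build inputs that satisfy the asserts above: c is the conv weight
    // [d_conv, d_inner], sx is the padded input [d_conv - 1 + n_t, d_inner, n_s].
    static void ssm_conv_shape_example(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,   // metadata only, no tensor data is allocated
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t d_conv = 4, d_inner = 1536, n_t = 512, n_s = 1; // example values
        struct ggml_tensor * c  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_conv, d_inner);
        struct ggml_tensor * sx = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_conv - 1 + n_t, d_inner, n_s);
        struct ggml_tensor * y  = ggml_ssm_conv(ctx, sx, c); // assert holds: 515 == 4 - 1 + 512
        (void) y;

        ggml_free(ctx);
    }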

src/llama.cpp

Lines changed: 29 additions & 9 deletions
@@ -7131,7 +7131,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, w->ne[0], 512);
+                ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
                 op_tensor = ggml_mul_mat(ctx, w, b);
             } break;
         case GGML_OP_MUL_MAT_ID:
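
Commentary (not part of the commit): ggml_mul_mat broadcasts the weight w over dimensions 2 and 3 of its second operand, so the dummy activation used for probing has to carry w->ne[2] and w->ne[3] for the probe op to have the same shape family, and pass the same shape checks, as the op built during inference. A sketch of the relation, keeping the 512-row dummy from the diff:

    // The shape rule ggml_mul_mat enforces is roughly (paraphrased, not verbatim):
    //   b->ne[0] == w->ne[0],  b->ne[2] % w->ne[2] == 0,  b->ne[3] % w->ne[3] == 0
    // which the 4-D dummy satisfies for 2-D, 3-D, and 4-D weights alike.
    ggml_tensor * b     = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
    ggml_tensor * probe = ggml_mul_mat(ctx, w, b); // same broadcast shape as the real op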
@@ -7171,18 +7171,38 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
             } break;
         case GGML_OP_SSM_CONV:
             {
-                // TODO: ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d);
-                op_tensor = ggml_ssm_conv(ctx, nullptr, w);
+                // FIXME
+                ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789);
+                op_tensor = ggml_ssm_conv(ctx, conv_x, w);
             } break;
         case GGML_OP_SSM_SCAN:
             {
-                // TODO: ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C);
-                op_tensor = ggml_ssm_scan(ctx, nullptr, nullptr, nullptr, w, nullptr, nullptr);
+                // FIXME
+                const int64_t d_state = w->ne[0];
+                const int64_t d_inner = w->ne[1];
+                const int64_t n_seq_tokens = 512;
+                const int64_t n_seqs = 1;
+                ggml_tensor * s  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs);
+                ggml_tensor * x  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
+                ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
+                ggml_tensor * B  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
+                ggml_tensor * C  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
+                op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C);
             } break;
         case GGML_OP_RWKV_WKV:
             {
-                // TODO: ggml_rwkv_wkv(ctx, k, v, r, layer->time_mix_first, w, *wkv_state);
-                op_tensor = ggml_rwkv_wkv(ctx, nullptr, nullptr, nullptr, w, nullptr, nullptr);
+                // FIXME
+                const int64_t S = 123;
+                const int64_t H = 123;
+                const int64_t n_tokens = 123;
+                const int64_t n_seqs = 123;
+                ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, 1, H, n_tokens);
+                ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
+                ggml_tensor * r = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
+                ggml_tensor * tf = w;
+                ggml_tensor * td = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, S, H, n_tokens);
+                ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H);
+                op_tensor = ggml_rwkv_wkv(ctx, k, v, r, tf, td, state);
             } break;
         default:
             GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
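
Commentary (not part of the commit): weight_buft_supported builds a throwaway op over dummy tensors in a no_alloc context and then asks whether a given backend or buffer type supports that op; only types and shapes matter, never data, which is why the placeholder sizes (12345, 6789, 123) are acceptable and why the FIXME comments flag them as arbitrary. The previous code passed nullptr operands, which the constructors for SSM_CONV, SSM_SCAN and RWKV_WKV cannot tolerate. A minimal sketch of the overall probing pattern, assuming ggml_backend_dev_supports_op is available with this signature (an assumption about the surrounding API, simplified from llama.cpp):

    // Sketch only: probe whether a device supports a MUL_MAT over weight w.
    static bool probe_op_support_sketch(ggml_backend_dev_t dev, ggml_tensor * w) {
        ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead()*8,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,   // metadata only, no tensor data is allocated
        };
        ggml_context * ctx = ggml_init(params);

        // dummy second operand shaped to match the weight, as in the diff above
        ggml_tensor * b  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]);
        ggml_tensor * op = ggml_mul_mat(ctx, w, b);

        const bool ok = ggml_backend_dev_supports_op(dev, op); // inspects types/shapes only
        ggml_free(ctx);
        return ok;
    }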
@@ -7462,7 +7482,7 @@ static bool llm_load_tensors(
 
             // tensors with "bias" suffix are always used with GGML_OP_ADD
             ggml_op op;
-            bool bias = strcmp(tn.suffix, "bias") == 0;
+            bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
             if (bias) {
                 op = GGML_OP_ADD;
             } else {
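
Commentary (not part of the commit): the added guard matters because not every tensor name carries a suffix, and calling strcmp with a null pointer is undefined behaviour. A tiny sketch of the guarded comparison, using an assumed plain C string for the suffix:

    #include <string.h>
    #include <stdbool.h>

    // Returns true only for a non-null "bias" suffix; strcmp is never reached with NULL.
    static bool is_bias_suffix(const char * suffix) {
        return suffix != NULL && strcmp(suffix, "bias") == 0;
    }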
@@ -19690,7 +19710,7 @@ struct llama_context * llama_new_context_with_model(
         int n_nodes_tg = ggml_graph_n_nodes(gf_tg);
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-        gf_pp = llama_build_graph(*ctx, ubatch_pp, false);
+        gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
         if (!ggml_backend_sched_reserve(ctx->sched, gf_pp)) {
             LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
             llama_free(ctx);
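
Commentary (not part of the commit; the parameter name is inferred from the commit message "fix missing worst case flag during reserve"): the prompt-processing graph rebuilt here exists only to reserve scheduler buffers, so it has to be built in worst-case mode; otherwise a later, larger ubatch could force ggml-alloc to reallocate in the middle of inference, which is exactly what this final reserve pass is meant to rule out. A sketch of the reserve flow under that assumption:

    // llama_build_graph is internal to llama.cpp; the third argument is taken
    // to be the worst-case flag.
    ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, /*worst_case=*/ true);
    if (!ggml_backend_sched_reserve(ctx->sched, gf_pp)) {
        LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
        llama_free(ctx);
    }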
