Commit ae806f4

llama_opt_param_filter
1 parent 899f7a2

6 files changed, +89 -46 lines

examples/training/finetune.cpp

Lines changed: 2 additions & 0 deletions
@@ -86,6 +86,8 @@ int main(int argc, char ** argv) {
 
     struct llama_opt_params lopt_params {
         /*n_ctx_train =*/ 0,
+        /*param_filter =*/ llama_opt_param_filter_all,
+        /*param_filter_ud =*/ nullptr,
         /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
         /*get_opt_pars_ud =*/ &optimizer_params,
     };
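The example keeps the default llama_opt_param_filter_all, which marks every F32 weight as trainable. Below is a minimal sketch, not part of this commit, of how a caller could plug in a custom filter instead, e.g. to train only tensors whose name contains "output"; the function name is made up for illustration, while ggml_get_name, GGML_UNUSED, and the struct fields come from the API in this diff (strstr needs <cstring>):

// hypothetical filter: only tensors whose name contains "output" become trainable
static bool param_filter_output_only(const struct ggml_tensor * tensor, void * userdata) {
    GGML_UNUSED(userdata);
    return strstr(ggml_get_name(tensor), "output") != nullptr;
}

struct llama_opt_params lopt_params {
    /*n_ctx_train =*/ 0,
    /*param_filter =*/ param_filter_output_only,
    /*param_filter_ud =*/ nullptr,
    /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
    /*get_opt_pars_ud =*/ &optimizer_params,
};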

ggml/include/ggml.h

Lines changed: 1 addition & 1 deletion
@@ -773,7 +773,7 @@ extern "C" {
     // Tensor flags
     GGML_API void ggml_set_input(struct ggml_tensor * tensor);
     GGML_API void ggml_set_output(struct ggml_tensor * tensor);
-    GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor);
+    GGML_API void ggml_set_param(struct ggml_tensor * tensor);
     GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
 
     //
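For code that calls ggml_set_param directly, the update is mechanical; a small before/after sketch (ctx and t are placeholder names, not from the commit):

// before: the ggml_context argument was accepted but never used
ggml_set_param(ctx, t);

// after: the parameter flag is set directly on the tensor
ggml_set_param(t);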

ggml/src/ggml.c

Lines changed: 1 addition & 2 deletions
@@ -6163,8 +6163,7 @@ void ggml_set_output(struct ggml_tensor * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
 }
 
-void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
-    GGML_UNUSED(ctx); // TODO: remove this parameter
+void ggml_set_param(struct ggml_tensor * tensor) {
     tensor->flags |= GGML_TENSOR_FLAG_PARAM;
 }
 
include/llama.h

Lines changed: 9 additions & 0 deletions
@@ -1259,9 +1259,18 @@ extern "C" {
     // training
     //
 
+    // function that returns whether or not a given tensor is a trainable parameter
+    typedef bool (*llama_opt_param_filter)(const struct ggml_tensor * tensor, void * userdata);
+
+    // always returns true
+    bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata);
+
     struct llama_opt_params {
         uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
 
+        llama_opt_param_filter param_filter; // callback for determining which tensors are trainable parameters
+        void * param_filter_ud; // userdata for determining which tensors are trainable parameters
+
         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
         void * get_opt_pars_ud; // userdata for calculating optimizer parameters
     };
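param_filter_ud is passed through to the callback unchanged, so a filter can carry state. A hedged sketch, not part of this commit, of a filter that freezes all blocks below a cutoff index; the struct and function names are invented, and it assumes the usual "blk.<i>." tensor naming used by llama.cpp (sscanf needs <cstdio>, ggml_get_name is the public ggml helper):

struct freeze_state {
    int first_trainable_block; // blocks below this index stay frozen
};

static bool param_filter_freeze_lower(const struct ggml_tensor * tensor, void * userdata) {
    const struct freeze_state * st = (const struct freeze_state *) userdata;
    int il = -1;
    if (sscanf(ggml_get_name(tensor), "blk.%d.", &il) == 1) {
        return il >= st->first_trainable_block;
    }
    return true; // non-block tensors (embeddings, output head, norms) remain trainable
}

// usage: set /*param_filter =*/ param_filter_freeze_lower and
//        /*param_filter_ud =*/ &state in llama_opt_params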

src/llama.cpp

Lines changed: 38 additions & 5 deletions
@@ -4830,7 +4830,6 @@ struct llama_model_loader {
             n_created++;
         }
 
-        ggml_set_param(nullptr, tensor);
         return tensor;
 
     }
@@ -22636,10 +22635,20 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
 // training
 //
 
-static struct ggml_opt_optimizer_params llama_get_default_optimizer_params(void * userdata) {
-    struct ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);
-    result.adamw.alpha = 1e-6f;
-    return result;
+bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
+    GGML_UNUSED(tensor);
+    GGML_UNUSED(userdata);
+    return true;
+}
+
+static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
+    if (!tensor || tensor->type != GGML_TYPE_F32) {
+        return;
+    }
+    if (!param_filter(tensor, userdata)) {
+        return;
+    }
+    ggml_set_param(tensor);
 }
 
 void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params) {
@@ -22656,6 +22665,30 @@ void llama_opt_init(struct llama_context * lctx, struct llama_model * model, struct llama_opt_params lopt_params) {
     opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
 
     lctx->opt_ctx = ggml_opt_init(opt_params);
+
+    llama_opt_param_filter param_filter = lopt_params.param_filter;
+    void * param_filter_ud = lopt_params.param_filter_ud;
+
+    llama_set_param(model->tok_embd, param_filter, param_filter_ud);
+    llama_set_param(model->type_embd, param_filter, param_filter_ud);
+    llama_set_param(model->pos_embd, param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm, param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
+    llama_set_param(model->output, param_filter, param_filter_ud);
+    llama_set_param(model->output_b, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
+    llama_set_param(model->cls, param_filter, param_filter_ud);
+    llama_set_param(model->cls_b, param_filter, param_filter_ud);
+    llama_set_param(model->cls_out, param_filter, param_filter_ud);
+    llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
+
+    for (struct llama_layer & layer : model->layers) {
+        for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
+            llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
+        }
+    }
 }
 
 static void llama_opt_epoch_iter(
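The per-layer loop above walks llama_layer as a flat array of ggml_tensor pointers, so any tensor later added to the struct is picked up automatically. A hypothetical sanity check, not part of the commit, that documents the layout this relies on:

// necessary (though not sufficient) condition for iterating llama_layer as an
// array of tensor pointers: its size must be a whole number of pointer slots
static_assert(sizeof(struct llama_layer) % sizeof(struct ggml_tensor *) == 0,
              "llama_layer is expected to contain only ggml_tensor pointers");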
