Skip to content

Commit 021e6d9

Browse files
committed
Steering
1 parent 63d2046 commit 021e6d9

File tree

5 files changed

+100
-0
lines changed

5 files changed

+100
-0
lines changed

examples/common.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,30 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
344344
break;
345345
}
346346
params.input_suffix = argv[i];
347+
} else if (arg == "--steering-add") {
348+
if (++i >= argc) {
349+
invalid_param = true;
350+
break;
351+
}
352+
params.steering_add = argv[i];
353+
} else if (arg == "--steering-sub") {
354+
if (++i >= argc) {
355+
invalid_param = true;
356+
break;
357+
}
358+
params.steering_sub = argv[i];
359+
} else if (arg == "--steering-mul") {
360+
if (++i >= argc) {
361+
invalid_param = true;
362+
break;
363+
}
364+
params.steering_mul = std::stof(argv[i]);
365+
} else if (arg == "--steering-lyr") {
366+
if (++i >= argc) {
367+
invalid_param = true;
368+
break;
369+
}
370+
params.steering_lyr = std::stoi(argv[i]);
347371
} else {
348372
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
349373
gpt_print_usage(argc, argv, default_params);

examples/common.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ struct gpt_params {
7272
bool use_mlock = false; // use mlock to keep model in memory
7373
bool mem_test = false; // compute maximum memory usage
7474
bool verbose_prompt = false; // print prompt tokens before generation
75+
76+
std::string steering_add = "";
77+
std::string steering_sub = "";
78+
float steering_mul = 1.0f;
79+
int steering_lyr = 20;
7580
};
7681

7782
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

examples/main/main.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,28 @@ int main(int argc, char ** argv) {
136136
return 0;
137137
}
138138

139+
if (params.steering_add.size() || params.steering_sub.size())
140+
{
141+
auto steering_add_tokens = ::llama_tokenize(ctx, params.steering_add, true);
142+
auto steering_sub_tokens = ::llama_tokenize(ctx, params.steering_sub, true);
143+
144+
if (steering_add_tokens.size() != steering_sub_tokens.size()) {
145+
llama_token space;
146+
llama_tokenize(ctx, " ", &space, 1, 0);
147+
148+
while (steering_add_tokens.size() < steering_sub_tokens.size()) steering_add_tokens.push_back(space);
149+
while (steering_sub_tokens.size() < steering_add_tokens.size()) steering_sub_tokens.push_back(space);
150+
}
151+
152+
llama_set_steering_write(ctx, params.steering_lyr, params.steering_mul/2);
153+
llama_eval(ctx, steering_add_tokens.data(), std::min((int)steering_add_tokens.size(), params.n_ctx), 0, params.n_threads);
154+
155+
llama_set_steering_write(ctx, params.steering_lyr, -params.steering_mul/2);
156+
llama_eval(ctx, steering_sub_tokens.data(), std::min((int)steering_sub_tokens.size(), params.n_ctx), 0, params.n_threads);
157+
158+
llama_set_steering_read(ctx, params.steering_lyr, 1);
159+
}
160+
139161
// Add a space in front of the first character to match OG llama tokenizer behavior
140162
params.prompt.insert(0, 1, ' ');
141163

llama.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,15 @@ struct llama_context {
229229
// input embedding (1-dimensional array: [n_embd])
230230
std::vector<float> embedding;
231231

232+
std::vector<float> steering_vector; // [n_ctx, n_embd]
233+
int steering_layer = 0;
234+
int steering_mode = 0;
235+
float steering_mul = 0.0f;
236+
237+
#define STEERING_OFF 0
238+
#define STEERING_WRITE 2
239+
#define STEERING_READ 3
240+
232241
// memory buffers used to evaluate the model
233242
// TODO: move in llama_state
234243
llama_ctx_buffer buf_compute;
@@ -269,6 +278,17 @@ struct llama_context {
269278
}
270279
};
271280

281+
// Arm the context for steering capture: on the next eval, activations
// entering layer `layer` are scaled by `mul` and accumulated into the
// context's steering vector (mode switches to STEERING_WRITE).
void llama_set_steering_write(struct llama_context * ctx, int layer, float mul) {
    ctx->steering_layer = layer;
    ctx->steering_mul   = mul;
    ctx->steering_mode  = STEERING_WRITE;
}
286+
// Arm the context for steering injection: on subsequent evals, the stored
// steering vector (scaled by `mul`) is applied at layer `layer`
// (mode switches to STEERING_READ).
void llama_set_steering_read(struct llama_context * ctx, int layer, float mul) {
    ctx->steering_layer = layer;
    ctx->steering_mul   = mul;
    ctx->steering_mode  = STEERING_READ;
}
291+
272292
template <typename T>
273293
static T checked_mul(T a, T b) {
274294
T ret = a * b;
@@ -1141,6 +1161,12 @@ static bool llama_eval_internal(
11411161
ggml_set_name(embd, "embd");
11421162
memcpy(embd->data, tokens, N*ggml_element_size(embd));
11431163

1164+
struct ggml_tensor * steer;
1165+
if (lctx.steering_mode != STEERING_OFF) {
1166+
steer = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_ctx, n_embd);
1167+
memcpy(steer->data, lctx.steering_vector.data(), ggml_nbytes(steer));
1168+
}
1169+
11441170
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
11451171

11461172
for (int il = 0; il < n_layer; ++il) {
@@ -1150,6 +1176,18 @@ static bool llama_eval_internal(
11501176

11511177
lctx.use_buf(ctx0, 0);
11521178

1179+
if (lctx.steering_mode != STEERING_OFF && il == lctx.steering_layer) {
1180+
steer->data = lctx.steering_vector.data();
1181+
1182+
struct ggml_tensor * src = ggml_scale(ctx0, inpL, ggml_new_f32(ctx0, lctx.steering_mul));
1183+
struct ggml_tensor * dst = ggml_view_2d(ctx0, steer, n_embd, N, n_embd * sizeof(float), n_past * n_embd * sizeof(float));
1184+
if (lctx.steering_mode == STEERING_WRITE) {
1185+
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, ggml_add(ctx0, src, dst), dst));
1186+
} else {
1187+
inpL = src;
1188+
}
1189+
}
1190+
11531191
// norm
11541192
{
11551193
cur = ggml_rms_norm(ctx0, inpL);
@@ -1363,6 +1401,12 @@ static bool llama_eval_internal(
13631401
memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
13641402
}
13651403

1404+
1405+
if (lctx.steering_mode == STEERING_WRITE) {
1406+
memcpy(lctx.steering_vector.data(), steer->data, ggml_nbytes(steer));
1407+
}
1408+
1409+
13661410
if (mem_per_token == 0) {
13671411
mem_per_token = ggml_used_mem(ctx0)/N;
13681412
}
@@ -2184,6 +2228,8 @@ struct llama_context * llama_init_from_file(
21842228

21852229
ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0().at(ctx->model.type));
21862230
ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
2231+
2232+
ctx->steering_vector.resize(hparams.n_ctx * hparams.n_embd);
21872233
}
21882234

21892235
return ctx;

llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,9 @@ extern "C" {
191191
LLAMA_API llama_token llama_token_eos();
192192
LLAMA_API llama_token llama_token_nl();
193193

194+
LLAMA_API void llama_set_steering_write(struct llama_context * ctx, int layer, float mul);
195+
LLAMA_API void llama_set_steering_read(struct llama_context * ctx, int layer, float mul);
196+
194197
// Sampling functions
195198

196199
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.

0 commit comments

Comments
 (0)