Skip to content

Commit 98d224e

Browse files
lucylq authored and facebook-github-bot committed
Add temperature to llama runner (#1969)
Summary: Pull Request resolved: #1969. Adds a temperature argument to the llama runner so that CI tests can produce deterministic output. Reviewed By: larryliu0820. Differential Revision: D53735445. fbshipit-source-id: 1a30f2edf366f2f3ca08a14cda20f75184b68801
1 parent da5ab27 commit 98d224e

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

examples/models/llama2/main.cpp

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,11 @@ DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
1919

2020
DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
2121

22+
DEFINE_double(
23+
temperature,
24+
0.8f,
25+
"Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");
26+
2227
int32_t main(int32_t argc, char** argv) {
2328
gflags::ParseCommandLineFlags(&argc, &argv, true);
2429

@@ -31,8 +36,10 @@ int32_t main(int32_t argc, char** argv) {
3136

3237
const char* prompt = FLAGS_prompt.c_str();
3338

39+
double temperature = FLAGS_temperature;
40+
3441
// create llama runner
35-
::torch::executor::Runner runner(model_path, tokenizer_path);
42+
::torch::executor::Runner runner(model_path, tokenizer_path, temperature);
3643

3744
// generate
3845
runner.generate(prompt);

examples/models/llama2/runner/runner.cpp

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,11 @@
2727
namespace torch {
2828
namespace executor {
2929

30-
Runner::Runner(const char* model_path, const char* tokenizer_path) {
30+
Runner::Runner(
31+
const char* model_path,
32+
const char* tokenizer_path,
33+
float temperature) {
3134
// Constants definition
32-
float temperature = 0.8f;
3335
float topp = 0.9f;
3436
unsigned long long rng_seed =
3537
(unsigned int)time(nullptr); // seed rng with time by default

examples/models/llama2/runner/runner.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,10 @@ namespace executor {
2626

2727
class Runner {
2828
public:
29-
explicit Runner(const char* model_path, const char* tokenizer_path);
29+
explicit Runner(
30+
const char* model_path,
31+
const char* tokenizer_path,
32+
float temperature = 0.8f);
3033

3134
Error generate(
3235
const std::string& prompt,

0 commit comments

Comments
 (0)