@@ -55,6 +55,19 @@ static int printe(const char * fmt, ...) {
 class Opt {
   public:
     int init(int argc, const char ** argv) {
+        ctx_params_ = llama_context_default_params();
+        model_params_ = llama_model_default_params();
+        context_size_default = ctx_params_.n_batch;
+        ngl_default = model_params_.n_gpu_layers;
+        common_params_sampling sampling;
+        temperature_default = sampling.temp;
+
+        if (argc < 2) {
+            printe("Error: No arguments provided.\n");
+            help();
+            return 1;
+        }
+
         // Parse arguments
         if (parse(argc, argv)) {
             printe("Error: Failed to parse arguments.\n");
@@ -68,12 +81,21 @@ class Opt {
             return 2;
         }

+        ctx_params_.n_batch = context_size_ >= 0 ? context_size_ : context_size_default;
+        model_params_.n_gpu_layers = ngl_ >= 0 ? ngl_ : ngl_default;
+        temperature_ = temperature_ >= 0 ? temperature_ : temperature_default;
+
         return 0; // Success
     }

+    llama_context_params ctx_params_;
+    llama_model_params model_params_;
     std::string model_;
     std::string user_;
-    int context_size_ = -1, ngl_ = -1;
+    int context_size_default = -1, ngl_default = -1;
+    float temperature_default = -1;
+    int context_size_ = -1, ngl_ = -1;
+    float temperature_ = -1;
     bool verbose_ = false;

   private:
@@ -89,6 +111,17 @@ class Opt {
         }

         option_value = std::atoi(argv[++i]);
+
+        return 0;
+    }
+
+    int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = std::atof(argv[++i]);
+
         return 0;
     }

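Note on the overload added above: `temperature_` is a `float` while `context_size_` and `ngl_` are `int`, so overload resolution routes `--temperature` through the new `std::atof` variant and the integer flags through the existing `std::atoi` one. A minimal standalone sketch of that dispatch (illustrative only; the `read_value` name and the hard-coded argv are not part of this patch):

// Illustrative sketch -- mirrors the two handle_option_with_value overloads
// from the patch; not part of run.cpp.
#include <cstdio>
#include <cstdlib>

static int read_value(int argc, const char ** argv, int & i, int & out) {
    if (i + 1 >= argc) { return 1; }
    out = std::atoi(argv[++i]);   // integer flags (-c, -n) take this path
    return 0;
}

static int read_value(int argc, const char ** argv, int & i, float & out) {
    if (i + 1 >= argc) { return 1; }
    out = std::atof(argv[++i]);   // float flags (--temperature) take this path
    return 0;
}

int main() {
    const char * argv[] = { "llama-run", "--temperature", "0.5", "--ngl", "99" };
    int   argc        = 5;
    float temperature = -1;
    int   ngl         = -1;

    int i = 1;
    read_value(argc, argv, i, temperature); // i is at "--temperature": picks the float overload
    ++i;
    read_value(argc, argv, i, ngl);         // i is at "--ngl": picks the int overload

    std::printf("temperature=%.1f ngl=%d\n", temperature, ngl);
    return 0;
}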
@@ -103,6 +136,10 @@ class Opt {
                 if (handle_option_with_value(argc, argv, i, ngl_) == 1) {
                     return 1;
                 }
+            } else if (options_parsing && strcmp(argv[i], "--temperature") == 0) {
+                if (handle_option_with_value(argc, argv, i, temperature_) == 1) {
+                    return 1;
+                }
             } else if (options_parsing &&
                        (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) {
                 verbose_ = true;
@@ -142,6 +179,8 @@ class Opt {
             "      Context size (default: %d)\n"
             "  -n, --ngl <value>\n"
             "      Number of GPU layers (default: %d)\n"
+            "  --temperature <value>\n"
+            "      Temperature (default: %.1f)\n"
             "  -v, --verbose, --log-verbose\n"
             "      Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n"
             "  -h, --help\n"
@@ -170,7 +209,7 @@ class Opt {
             "  llama-run file://some-file3.gguf\n"
             "  llama-run --ngl 999 some-file4.gguf\n"
             "  llama-run --ngl 999 some-file5.gguf Hello World\n",
-            llama_context_default_params().n_batch, llama_model_default_params().n_gpu_layers);
+            context_size_default, ngl_default, temperature_default);
     }
 };
@@ -495,12 +534,12 @@ class LlamaData {
             return 1;
         }

-        context = initialize_context(model, opt.context_size_);
+        context = initialize_context(model, opt);
         if (!context) {
             return 1;
         }

-        sampler = initialize_sampler();
+        sampler = initialize_sampler(opt);
         return 0;
     }

@@ -619,14 +658,12 @@ class LlamaData {
     // Initializes the model and returns a unique pointer to it
     llama_model_ptr initialize_model(Opt & opt) {
         ggml_backend_load_all();
-        llama_model_params model_params = llama_model_default_params();
-        model_params.n_gpu_layers = opt.ngl_ >= 0 ? opt.ngl_ : model_params.n_gpu_layers;
         resolve_model(opt.model_);
         printe(
             "\r%*s"
             "\rLoading model",
             get_terminal_width(), " ");
-        llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), model_params));
+        llama_model_ptr model(llama_load_model_from_file(opt.model_.c_str(), opt.model_params_));
         if (!model) {
             printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str());
         }
@@ -636,10 +673,8 @@ class LlamaData {
     }

     // Initializes the context with the specified parameters
-    llama_context_ptr initialize_context(const llama_model_ptr & model, const int n_ctx) {
-        llama_context_params ctx_params = llama_context_default_params();
-        ctx_params.n_ctx = ctx_params.n_batch = n_ctx >= 0 ? n_ctx : ctx_params.n_batch;
-        llama_context_ptr context(llama_new_context_with_model(model.get(), ctx_params));
+    llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
+        llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params_));
         if (!context) {
             printe("%s: error: failed to create the llama_context\n", __func__);
         }
@@ -648,10 +683,10 @@ class LlamaData {
     }

     // Initializes and configures the sampler
-    llama_sampler_ptr initialize_sampler() {
+    llama_sampler_ptr initialize_sampler(const Opt & opt) {
         llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1));
-        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(0.8f));
+        llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature_));
         llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

         return sampler;
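Taken together, the sampler chain built by `initialize_sampler()` now takes its temperature from `Opt`, which falls back to `common_params_sampling::temp` when `--temperature` is not given. A rough standalone equivalent, using only the llama.h C API calls that appear in this hunk (`make_sampler` is an illustrative name, and the 0.05f min-p value simply mirrors the patch):

// Sketch only: rebuilds the same chain as initialize_sampler(), with the
// temperature passed in by the caller instead of hard-coded 0.8f.
#include "llama.h"

static llama_sampler * make_sampler(float temperature) {
    llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(sampler, llama_sampler_init_min_p(0.05f, 1));    // keep min-p filtering
    llama_sampler_chain_add(sampler, llama_sampler_init_temp(temperature));  // user-controlled temperature
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
    return sampler; // caller releases it with llama_sampler_free()
}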