@@ -341,7 +341,7 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
341
341
break ;
342
342
}
343
343
const auto sampler_names = string_split (argv[i], ' ;' );
344
- sparams.samplers_sequence = sampler_types_from_names (sampler_names);
344
+ sparams.samplers_sequence = sampler_types_from_names (sampler_names, true );
345
345
} else if (arg == " --sampling-seq" ) {
346
346
if (++i >= argc) {
347
347
invalid_param = true ;
@@ -964,7 +964,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
964
964
printf (" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n " , params.n_predict );
965
965
printf (" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n " , params.n_ctx );
966
966
printf (" -b N, --batch-size N batch size for prompt processing (default: %d)\n " , params.n_batch );
967
- printf (" --samplers samplers that will be used for generation in the order, separated by \' ;\' (default: %s)\n " , sampler_type_names.c_str ());
967
+ printf (" --samplers samplers that will be used for generation in the order, separated by \' ;\'\n " );
968
+ printf (" (default: %s)\n " , sampler_type_names.c_str ());
968
969
printf (" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n " , sampler_type_chars.c_str ());
969
970
printf (" --top-k N top-k sampling (default: %d, 0 = disabled)\n " , sparams.top_k );
970
971
printf (" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n " , (double )sparams.top_p );
@@ -1133,34 +1134,50 @@ std::vector<std::string> string_split(std::string input, char separator) {
1133
1134
return parts;
1134
1135
}
1135
1136
1136
- std::vector<llama_sampler_type> sampler_types_from_names (const std::vector<std::string> & names) {
1137
+ std::vector<llama_sampler_type> sampler_types_from_names (const std::vector<std::string> & names, bool allow_alt_names) {
1138
+ std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map {
1139
+ {" top_k" , llama_sampler_type::TOP_K},
1140
+ {" top_p" , llama_sampler_type::TOP_P},
1141
+ {" typical_p" , llama_sampler_type::TYPICAL_P},
1142
+ {" min_p" , llama_sampler_type::MIN_P},
1143
+ {" tfs_z" , llama_sampler_type::TFS_Z},
1144
+ {" temperature" , llama_sampler_type::TEMPERATURE}
1145
+ };
1146
+
1137
1147
// since samplers names are written multiple ways
1138
1148
// make it ready for both system names and input names
1139
- std::unordered_map<std::string, llama_sampler_type> sampler_name_map {
1140
- {" top_k" , llama_sampler_type::TOP_K},
1149
+ std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map {
1141
1150
{" top-k" , llama_sampler_type::TOP_K},
1142
- {" top_p" , llama_sampler_type::TOP_P},
1143
1151
{" top-p" , llama_sampler_type::TOP_P},
1144
1152
{" nucleus" , llama_sampler_type::TOP_P},
1145
- {" typical_p" , llama_sampler_type::TYPICAL_P},
1146
1153
{" typical-p" , llama_sampler_type::TYPICAL_P},
1147
1154
{" typical" , llama_sampler_type::TYPICAL_P},
1148
- {" min_p" , llama_sampler_type::MIN_P},
1149
1155
{" min-p" , llama_sampler_type::MIN_P},
1150
- {" tfs_z" , llama_sampler_type::TFS_Z},
1151
1156
{" tfs-z" , llama_sampler_type::TFS_Z},
1152
1157
{" tfs" , llama_sampler_type::TFS_Z},
1153
- {" temp" , llama_sampler_type::TEMP},
1154
- {" temperature" , llama_sampler_type::TEMP}
1158
+ {" temp" , llama_sampler_type::TEMPERATURE}
1155
1159
};
1156
1160
1157
1161
std::vector<llama_sampler_type> sampler_types;
1158
1162
sampler_types.reserve (names.size ());
1159
- for (const auto & name : names) {
1160
- const auto sampler_item = sampler_name_map.find (name);
1161
- if (sampler_item != sampler_name_map.end ()) {
1163
+ for (const auto & name : names)
1164
+ {
1165
+ auto sampler_item = sampler_canonical_name_map.find (name);
1166
+ if (sampler_item != sampler_canonical_name_map.end ())
1167
+ {
1162
1168
sampler_types.push_back (sampler_item->second );
1163
1169
}
1170
+ else
1171
+ {
1172
+ if (allow_alt_names)
1173
+ {
1174
+ sampler_item = sampler_alt_name_map.find (name);
1175
+ if (sampler_item != sampler_alt_name_map.end ())
1176
+ {
1177
+ sampler_types.push_back (sampler_item->second );
1178
+ }
1179
+ }
1180
+ }
1164
1181
}
1165
1182
return sampler_types;
1166
1183
}
@@ -1172,7 +1189,7 @@ std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & nam
1172
1189
{' y' , llama_sampler_type::TYPICAL_P},
1173
1190
{' m' , llama_sampler_type::MIN_P},
1174
1191
{' f' , llama_sampler_type::TFS_Z},
1175
- {' t' , llama_sampler_type::TEMP }
1192
+ {' t' , llama_sampler_type::TEMPERATURE }
1176
1193
};
1177
1194
1178
1195
std::vector<llama_sampler_type> sampler_types;
@@ -1188,12 +1205,12 @@ std::vector<llama_sampler_type> sampler_types_from_chars(const std::string & nam
1188
1205
1189
1206
// Map a sampler type to its name string.
//
// sampler_type - the sampler type to describe
//
// Returns the name for the six sampler types handled below, or an empty
// string for any other value.
std::string sampler_type_to_name_string(llama_sampler_type sampler_type) {
    switch (sampler_type) {
        case llama_sampler_type::TOP_K:       return "top_k";
        case llama_sampler_type::TFS_Z:       return "tfs_z";
        case llama_sampler_type::TYPICAL_P:   return "typical_p";
        case llama_sampler_type::TOP_P:       return "top_p";
        case llama_sampler_type::MIN_P:       return "min_p";
        case llama_sampler_type::TEMPERATURE: return "temperature";
        default:                              return "";
    }
}
0 commit comments