@@ -4884,6 +4884,44 @@ static void llm_load_vocab(
                 attrib |= LLAMA_TOKEN_ATTRIB_BYTE * (data.type == LLAMA_TOKEN_TYPE_BYTE);
                 data.attribs = (llama_token_attrib) attrib;
             }
+
+            // set attributes by model name
+            std::string model_name;
+            if (ml.get_key(LLM_KV_GENERAL_NAME, model_name, false)) {
+                std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+                    [] (const std::string::value_type x) {
+                        return std::tolower(x);
+                    }
+                );
+
+                auto _contains_any = [&model_name] (const std::vector<std::string> &substrs) -> bool {
+                    for (auto substr : substrs) {
+                        if (model_name.find(substr) < std::string::npos) {
+                            return true;
+                        }
+                    }
+                    return false;
+                };
+
+                auto _set_token_attrib = [&vocab] (const std::string & token, llama_token_attrib attrib, bool value) {
+                    llama_vocab::id id = vocab.token_to_id.at(token);
+                    uint32_t attribs = vocab.id_to_token[id].attribs;
+                    attribs = value ? (attribs | attrib) : (attribs & ~attrib);
+                    vocab.id_to_token[id].attribs = (llama_token_attrib) attribs;
+                };
+
+                if (_contains_any({"phi-3", "phi3"})) {
+                    for (auto token : vocab.cache_token_to_piece_special) {
+                        _set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
+                    }
+                    for (auto token : {"</s>"}) {
+                        _set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
+                    }
+                    for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
+                        _set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, false);
+                    }
+                }
+            }
         }
     }