@@ -136,7 +136,6 @@ def _create_llm(self, llm_config: dict) -> object:
                 raise KeyError("model_tokens not specified") from exc
             return llm_params["model_instance"]
 
-        # Instantiate the language model based on the model name (models that use the common interface)
         def handle_model(model_name, provider, token_key, default_token=8192):
             try:
                 self.model_token = models_tokens[provider][token_key]
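The hunk cuts `handle_model` off right after the token lookup, so the rest of the helper is not shown here. Judging from the branches that call it, it presumably falls back to `default_token` on a missing lookup and then builds the model through LangChain's provider-agnostic factory. A minimal sketch of that assumed shape, not the literal committed body (the `init_chat_model` call and the fallback print are assumptions):

    from langchain.chat_models import init_chat_model

    models_tokens = {"openai": {"gpt-4o": 128000}}  # stand-in for the real lookup table

    def handle_model(llm_params, model_name, provider, token_key, default_token=8192):
        # Resolve the context-window size for the chosen provider/model pair.
        try:
            model_token = models_tokens[provider][token_key]
        except KeyError:
            # Assumed fallback, mirroring the deepseek/ernie branches later in this diff.
            print(f"Model not found, using default token size ({default_token})")
            model_token = default_token
        # Hand everything to LangChain's common chat-model interface.
        params = dict(llm_params, model=model_name, model_provider=provider)
        return model_token, init_chat_model(**params)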
@@ -153,84 +152,74 @@ def handle_model(model_name, provider, token_key, default_token=8192):
             model_name = llm_params["model"].split("/")[-1]
             return handle_model(model_name, "azure_openai", model_name)
 
-        if "gpt-" in llm_params["model"]:
-            return handle_model(llm_params["model"], "openai", llm_params["model"])
-
-        if "fireworks" in llm_params["model"]:
+        elif "fireworks" in llm_params["model"]:
             model_name = "/".join(llm_params["model"].split("/")[1:])
             token_key = llm_params["model"].split("/")[-1]
             return handle_model(model_name, "fireworks", token_key)
 
-        if "gemini" in llm_params["model"]:
+        elif "gemini" in llm_params["model"]:
             model_name = llm_params["model"].split("/")[-1]
             return handle_model(model_name, "google_genai", model_name)
 
-        if llm_params["model"].startswith("claude"):
+        elif llm_params["model"].startswith("claude"):
             model_name = llm_params["model"].split("/")[-1]
             return handle_model(model_name, "anthropic", model_name)
 
-        if llm_params["model"].startswith("vertexai"):
+        elif llm_params["model"].startswith("vertexai"):
             return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
 
-        if "ollama" in llm_params["model"]:
+        elif "gpt-" in llm_params["model"]:
+            return handle_model(llm_params["model"], "openai", llm_params["model"])
+
+        elif "ollama" in llm_params["model"]:
             model_name = llm_params["model"].split("ollama/")[-1]
             token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
             return handle_model(model_name, "ollama", token_key)
 
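For reference, the string surgery in these branches decomposes a provider-prefixed model name as follows (the model strings below are hypothetical examples, not values from this commit):

    # "fireworks/..." keeps everything after the first segment as the model id,
    # and uses the last segment as the key into models_tokens.
    model = "fireworks/accounts/fireworks/models/mixtral-8x7b"
    model_name = "/".join(model.split("/")[1:])  # "accounts/fireworks/models/mixtral-8x7b"
    token_key = model.split("/")[-1]             # "mixtral-8x7b"

    # "ollama/..." simply drops the prefix.
    model = "ollama/llama3"
    model_name = model.split("ollama/")[-1]      # "llama3"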
-        if "hugging_face" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "hugging_face", model_name)
-
-        if "groq" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "groq", model_name)
-
-        if "bedrock" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "bedrock", model_name)
-
-        if "claude-3-" in llm_params["model"]:
+        elif "claude-3-" in llm_params["model"]:
             return handle_model(llm_params["model"], "anthropic", "claude3")
-
-        if llm_params["model"].startswith("mistral"):
+
+        elif llm_params["model"].startswith("mistral"):
             model_name = llm_params["model"].split("/")[-1]
             return handle_model(model_name, "mistralai", model_name)
 
         # Instantiate the language model based on the model name (models that do not use the common interface)
-        if "deepseek" in llm_params["model"]:
+        elif "deepseek" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["deepseek"][llm_params["model"]]
             except KeyError:
                 print("model not found, using default token size (8192)")
                 self.model_token = 8192
             return DeepSeek(llm_params)
 
-        if "ernie" in llm_params["model"]:
+        elif "ernie" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["ernie"][llm_params["model"]]
             except KeyError:
                 print("model not found, using default token size (8192)")
                 self.model_token = 8192
             return ErnieBotChat(llm_params)
-
-        if "oneapi" in llm_params["model"]:
+
+        elif "oneapi" in llm_params["model"]:
             # take the model after the last dash
             llm_params["model"] = llm_params["model"].split("/")[-1]
             try:
                 self.model_token = models_tokens["oneapi"][llm_params["model"]]
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
             return OneApi(llm_params)
-
-        if "nvidia" in llm_params["model"]:
+
+        elif "nvidia" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
                 llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
             return ChatNVIDIA(llm_params)
+        else:
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, llm_params["model"], model_name)
 
-        # Raise an error if the model did not match any of the previous cases
         raise ValueError("Model provided by the configuration not supported")
 
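Net effect of the new `else` fallback: providers whose dedicated branches were deleted above (hugging_face, groq, bedrock) now reach `handle_model` through it. A hypothetical trace for a groq-style model string, following the code exactly as committed:

    # llm_params = {"model": "groq/llama3-70b-8192"} matches no elif guard above,
    # so the else branch runs:
    model_name = "groq/llama3-70b-8192".split("/")[-1]  # "llama3-70b-8192"
    # handle_model("llama3-70b-8192", "groq/llama3-70b-8192", "llama3-70b-8192")
    # i.e. the full prefixed string is passed as the provider argument, which is
    # the key used in the models_tokens[provider][token_key] lookup shown at the
    # top of this diff.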