@@ -146,138 +146,61 @@ def _create_llm(self, llm_config: dict) -> object:
                 raise KeyError("model_tokens not specified") from exc
             return llm_params["model_instance"]
 
-        # Instantiate the language model based on the model name
-        if "gpt-" in llm_params["model"]:
+        # Instantiate the language model based on the model name (models that use the common interface)
+        def handle_model(model_name, provider, token_key, default_token=8192):
             try:
-                self.model_token = models_tokens["openai"][llm_params["model"]]
-                llm_params["model_provider"] = "openai"
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
+                self.model_token = models_tokens[provider][token_key]
+            except KeyError:
+                print(f"Model not found, using default token size ({default_token})")
+                self.model_token = default_token
+            llm_params["model_provider"] = provider
+            llm_params["model"] = model_name
             return init_chat_model(**llm_params)
 
-        if "oneapi" in llm_params["model"]:
-            # take the model after the last dash
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["oneapi"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return OneApi(llm_params)
+        if "gpt-" in llm_params["model"]:
+            return handle_model(llm_params["model"], "openai", llm_params["model"])
 
         if "fireworks" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]]
-                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "fireworks"
-            return init_chat_model(**llm_params)
+            model_name = "/".join(llm_params["model"].split("/")[1:])
+            token_key = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "fireworks", token_key)
 
         if "azure" in llm_params["model"]:
-            # take the model after the last dash
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["azure"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "azure_openai"
-            return init_chat_model(**llm_params)
-
-        if "nvidia" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
-                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return ChatNVIDIA(llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "azure_openai", model_name)
 
         if "gemini" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["gemini"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "google_genai"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "google_genai", model_name)
 
         if llm_params["model"].startswith("claude"):
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["claude"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "anthropic"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "anthropic", model_name)
 
         if llm_params["model"].startswith("vertexai"):
-            try:
-                self.model_token = models_tokens["vertexai"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "google_vertexai"
-            return init_chat_model(**llm_params)
+            return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
 
         if "ollama" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("ollama/")[-1]
-            llm_params["model_provider"] = "ollama"
-
-            # allow user to set model_tokens in config
-            try:
-                if "model_tokens" in llm_params:
-                    self.model_token = llm_params["model_tokens"]
-                elif llm_params["model"] in models_tokens["ollama"]:
-                    try:
-                        self.model_token = models_tokens["ollama"][llm_params["model"]]
-                    except KeyError as exc:
-                        print("model not found, using default token size (8192)")
-                        self.model_token = 8192
-                else:
-                    self.model_token = 8192
-            except AttributeError:
-                self.model_token = 8192
-
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("ollama/")[-1]
+            token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
+            return handle_model(model_name, "ollama", token_key)
 
         if "hugging_face" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["hugging_face"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "hugging_face"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "hugging_face", model_name)
 
         if "groq" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-
-            try:
-                self.model_token = models_tokens["groq"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "groq"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "groq", model_name)
 
         if "bedrock" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["bedrock"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "bedrock"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "bedrock", model_name)
 
         if "claude-3-" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["claude"]["claude3"]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "anthropic"
-            return init_chat_model(**llm_params)
+            return handle_model(llm_params["model"], "anthropic", "claude3")
 
+        # Instantiate the language model based on the model name (models that do not use the common interface)
         if "deepseek" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["deepseek"][llm_params["model"]]
@@ -293,7 +216,25 @@ def _create_llm(self, llm_config: dict) -> object:
                 print("model not found, using default token size (8192)")
                 self.model_token = 8192
             return ErnieBotChat(llm_params)
+
+        if "oneapi" in llm_params["model"]:
+            # take the model after the last dash
+            llm_params["model"] = llm_params["model"].split("/")[-1]
+            try:
+                self.model_token = models_tokens["oneapi"][llm_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return OneApi(llm_params)
+
+        if "nvidia" in llm_params["model"]:
+            try:
+                self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
+                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return ChatNVIDIA(llm_params)
 
+        # Raise an error if the model did not match any of the previous cases
         raise ValueError("Model provided by the configuration not supported")
 
     def _create_default_embedder(self, llm_config=None) -> object:
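Note for readers of this diff: the snippet below is a simplified, standalone sketch of the `handle_model` dispatch pattern the change introduces (look up the context-window size in a token table, fall back to a default when the model is unknown, then delegate to a single constructor). The `models_tokens` table and `build_chat_model` helper here are stand-ins, not the project's real objects, and the provider branches are reduced to two cases.

```python
# Illustrative sketch only -- `models_tokens` and `build_chat_model` are simplified
# stand-ins for the project's token table and LangChain's init_chat_model.
models_tokens = {
    "openai": {"gpt-4o": 128_000},
    "ollama": {"llama3": 8_192},
}

def build_chat_model(**params):
    # Placeholder constructor; just returns the parameters it was given.
    return params

def create_llm(llm_params: dict, default_token: int = 8192):
    def handle_model(model_name, provider, token_key):
        # Resolve the context window; unknown models fall back to the default size.
        try:
            model_token = models_tokens[provider][token_key]
        except KeyError:
            print(f"Model not found, using default token size ({default_token})")
            model_token = default_token
        llm_params["model_provider"] = provider
        llm_params["model"] = model_name
        return build_chat_model(**llm_params), model_token

    model = llm_params["model"]
    if "gpt-" in model:
        return handle_model(model, "openai", model)
    if "ollama" in model:
        name = model.split("ollama/")[-1]
        return handle_model(name, "ollama", name)
    raise ValueError("Model provided by the configuration not supported")

print(create_llm({"model": "ollama/llama3"}))
# -> ({'model': 'llama3', 'model_provider': 'ollama'}, 8192)
```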