@@ -146,90 +146,84 @@ def handle_model(model_name, provider, token_key, default_token=8192):
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
                 return init_chat_model(**llm_params)
-
-        if "azure" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "azure_openai", model_name)
-
-        if "gpt-" in llm_params["model"]:
-            return handle_model(llm_params["model"], "openai", llm_params["model"])
-
-        if "fireworks" in llm_params["model"]:
-            model_name = "/".join(llm_params["model"].split("/")[1:])
-            token_key = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "fireworks", token_key)
-
-        if "gemini" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "google_genai", model_name)
-
-        if llm_params["model"].startswith("claude"):
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "anthropic", model_name)
-
-        if llm_params["model"].startswith("vertexai"):
-            return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
 
-        if "ollama" in llm_params["model"]:
-            model_name = llm_params["model"].split("ollama/")[-1]
-            token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
-            return handle_model(model_name, "ollama", token_key)
-
-        if "hugging_face" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "hugging_face", model_name)
-
-        if "groq" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "groq", model_name)
-
-        if "bedrock" in llm_params["model"]:
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "bedrock", model_name)
-
-        if "claude-3-" in llm_params["model"]:
-            return handle_model(llm_params["model"], "anthropic", "claude3")
-
-        if llm_params["model"].startswith("mistral"):
-            model_name = llm_params["model"].split("/")[-1]
-            return handle_model(model_name, "mistralai", model_name)
-
-        # Instantiate the language model based on the model name (models that do not use the common interface)
-        if "deepseek" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["deepseek"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            return DeepSeek(llm_params)
-
-        if "ernie" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["ernie"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            return ErnieBotChat(**llm_params)
-
-        if "oneapi" in llm_params["model"]:
-            # take the model after the last dash
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["oneapi"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return OneApi(llm_params)
-
-        if "nvidia" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
-                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return ChatNVIDIA(**llm_params)
-
-        # Raise an error if the model did not match any of the previous cases
-        raise ValueError("Model provided by the configuration not supported")
+        known_models = ["openai", "azure_openai", "google_genai", "ollama", "oneapi", "nvidia", "groq", "google_vertexai", "bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"]
+
+        if llm_params["model"] not in known_models:
+            raise ValueError(f"Model '{llm_params['model']}' is not supported")
+
+        try:
+            if "fireworks" in llm_params["model"]:
+                model_name = "/".join(llm_params["model"].split("/")[1:])
+                token_key = llm_params["model"].split("/")[-1]
+                return handle_model(model_name, "fireworks", token_key)
+
+            elif "gemini" in llm_params["model"]:
+                model_name = llm_params["model"].split("/")[-1]
+                return handle_model(model_name, "google_genai", model_name)
+
+            elif llm_params["model"].startswith("claude"):
+                model_name = llm_params["model"].split("/")[-1]
+                return handle_model(model_name, "anthropic", model_name)
+
+            elif llm_params["model"].startswith("vertexai"):
+                return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
+
+            elif "gpt-" in llm_params["model"]:
+                return handle_model(llm_params["model"], "openai", llm_params["model"])
+
+            elif "ollama" in llm_params["model"]:
+                model_name = llm_params["model"].split("ollama/")[-1]
+                token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
+                return handle_model(model_name, "ollama", token_key)
+
+            elif "claude-3-" in llm_params["model"]:
+                return handle_model(llm_params["model"], "anthropic", "claude3")
+
+            elif llm_params["model"].startswith("mistral"):
+                model_name = llm_params["model"].split("/")[-1]
+                return handle_model(model_name, "mistralai", model_name)
+
+            # Instantiate the language model based on the model name (models that do not use the common interface)
+            elif "deepseek" in llm_params["model"]:
+                try:
+                    self.model_token = models_tokens["deepseek"][llm_params["model"]]
+                except KeyError:
+                    print("model not found, using default token size (8192)")
+                    self.model_token = 8192
+                return DeepSeek(llm_params)
+
+            elif "ernie" in llm_params["model"]:
+                try:
+                    self.model_token = models_tokens["ernie"][llm_params["model"]]
+                except KeyError:
+                    print("model not found, using default token size (8192)")
+                    self.model_token = 8192
+                return ErnieBotChat(llm_params)
+
+            elif "oneapi" in llm_params["model"]:
+                # take the model after the last dash
+                llm_params["model"] = llm_params["model"].split("/")[-1]
+                try:
+                    self.model_token = models_tokens["oneapi"][llm_params["model"]]
+                except KeyError:
+                    raise KeyError("Model not supported")
+                return OneApi(llm_params)
+
+            elif "nvidia" in llm_params["model"]:
+                try:
+                    self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
+                    llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
+                except KeyError:
+                    raise KeyError("Model not supported")
+                return ChatNVIDIA(llm_params)
+
+            else:
+                model_name = llm_params["model"].split("/")[-1]
+                return handle_model(model_name, llm_params["model"], model_name)
+
+        except KeyError as e:
+            print(f"Model not supported: {e}")
 
 
     def get_state(self, key=None) -> dict:
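
Note (not part of the commit): every branch above that calls handle_model(...) funnels into the helper whose closing lines appear at the top of this hunk. Below is a hedged, standalone sketch of that common-interface path; the token lookup and the exact params forwarded to LangChain's init_chat_model are assumptions, since only the helper's last three lines are visible in the diff.

import warnings

from langchain.chat_models import init_chat_model  # the helper the hunk's tail calls

# Illustrative stand-in for scrapegraphai's models_tokens table.
models_tokens = {"openai": {"gpt-4o": 128000}}

def handle_model(model_name, provider, token_key, default_token=8192):
    # Resolve the context-window size, falling back to default_token when the
    # provider/model pair is missing from the table (mirrors the branches above).
    try:
        model_token = models_tokens[provider][token_key]
    except KeyError:
        print(f"Model not found, using default token size ({default_token})")
        model_token = default_token
    # The real method mutates self.model_token and the shared llm_params dict;
    # this sketch builds a fresh dict so it can run on its own.
    llm_params = {"model": model_name, "model_provider": provider}
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # silence provider deprecation chatter
        return init_chat_model(**llm_params), model_token

Called as handle_model("gpt-4o", "openai", "gpt-4o"), this returns a ready chat model plus its token budget (assuming the relevant provider package and API key are available).
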
@@ -277,4 +271,4 @@ def _create_graph(self):
     def run(self) -> str:
         """
         Abstract method to execute the graph and return the result.
-        """
+        """