@@ -125,103 +125,47 @@ def _create_llm(self, llm_config: dict) -> object:
                self.model_token = llm_params["model_tokens"]
            except KeyError as exc:
                raise KeyError("model_tokens not specified") from exc
-            return llm_params["model_instance"]
-
-        def handle_model(model_name, provider, token_key, default_token=8192):
-            try:
-                self.model_token = models_tokens[provider][token_key]
-            except KeyError:
-                print(f"Model not found, using default token size ({default_token})")
-                self.model_token = default_token
-            llm_params["model_provider"] = provider
-            llm_params["model"] = model_name
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")
-                return init_chat_model(**llm_params)
-
-        known_models = {"chatgpt", "gpt", "openai", "azure_openai", "google_genai",
-                        "ollama", "oneapi", "nvidia", "groq", "google_vertexai",
-                        "bedrock", "mistralai", "hugging_face", "deepseek", "ernie",
-                        "fireworks", "claude-3-"}
-
-        if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models:
-            raise ValueError(f"Model '{llm_params['model']}' is not supported")
-
+            return llm_params["model_instance"]
+
+        known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
+                           "ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock",
+                           "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"}
+
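+        # The "model" entry now takes the form "<provider>/<model>", e.g.
+        # "openai/gpt-4o-mini" (an illustrative name); the prefix is checked
+        # against known_providers below.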
+        split_model_provider = llm_params["model"].split("/", 1)
+        llm_params["model_provider"] = split_model_provider[0]
+        llm_params["model"] = split_model_provider[-1]
+
+        if llm_params["model_provider"] not in known_providers:
+            raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
+
        try:
- if "fireworks" in llm_params ["model" ]:
152
- model_name = "/" .join (llm_params ["model" ].split ("/" )[1 :])
153
- token_key = llm_params ["model" ].split ("/" )[- 1 ]
154
- return handle_model (model_name , "fireworks" , token_key )
155
-
156
- elif "gemini" in llm_params ["model" ]:
157
- model_name = llm_params ["model" ].split ("/" )[- 1 ]
158
- return handle_model (model_name , "google_genai" , model_name )
159
-
160
- elif llm_params ["model" ].startswith ("claude" ):
161
- model_name = llm_params ["model" ].split ("/" )[- 1 ]
162
- return handle_model (model_name , "anthropic" , model_name )
163
-
164
- elif llm_params ["model" ].startswith ("vertexai" ):
165
- return handle_model (llm_params ["model" ], "google_vertexai" , llm_params ["model" ])
166
-
167
- elif "gpt-" in llm_params ["model" ]:
168
- return handle_model (llm_params ["model" ], "openai" , llm_params ["model" ])
169
-
170
- elif "ollama" in llm_params ["model" ]:
171
- model_name = llm_params ["model" ].split ("ollama/" )[- 1 ]
172
- token_key = model_name if "model_tokens" not in llm_params else None
173
- model_tokens = 8192 if "model_tokens" not in llm_params else llm_params ["model_tokens" ]
174
- return handle_model (model_name , "ollama" , token_key , model_tokens )
175
-
176
- elif "claude-3-" in llm_params ["model" ]:
177
- return handle_model (llm_params ["model" ], "anthropic" , "claude3" )
178
-
179
- elif llm_params ["model" ].startswith ("mistral" ):
180
- model_name = llm_params ["model" ].split ("/" )[- 1 ]
181
- return handle_model (model_name , "mistralai" , model_name )
182
-
183
- elif "deepseek" in llm_params ["model" ]:
184
- try :
185
- self .model_token = models_tokens ["deepseek" ][llm_params ["model" ]]
186
- except KeyError :
187
- print ("model not found, using default token size (8192)" )
188
- self .model_token = 8192
189
- return DeepSeek (llm_params )
190
-
191
- elif "ernie" in llm_params ["model" ]:
192
- from langchain_community .chat_models import ErnieBotChat
193
-
194
- try :
195
- self .model_token = models_tokens ["ernie" ][llm_params ["model" ]]
196
- except KeyError :
197
- print ("model not found, using default token size (8192)" )
198
- self .model_token = 8192
199
- return ErnieBotChat (llm_params )
200
-
201
- elif "oneapi" in llm_params ["model" ]:
202
- llm_params ["model" ] = llm_params ["model" ].split ("/" )[- 1 ]
203
- try :
204
- self .model_token = models_tokens ["oneapi" ][llm_params ["model" ]]
205
- except KeyError :
206
- raise KeyError ("Model not supported" )
207
- return OneApi (llm_params )
208
-
209
- elif "nvidia" in llm_params ["model" ]:
210
- from langchain_nvidia_ai_endpoints import ChatNVIDIA
211
-
212
- try :
213
- self .model_token = models_tokens ["nvidia" ][llm_params ["model" ].split ("/" )[- 1 ]]
214
- llm_params ["model" ] = "/" .join (llm_params ["model" ].split ("/" )[1 :])
215
- except KeyError :
216
- raise KeyError ("Model not supported" )
217
- return ChatNVIDIA (llm_params )
142
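+            # models_tokens is assumed to be a nested map (provider -> model name
+            # -> max token count); unknown models fall back to a default of 8192.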
+            self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
+        except KeyError:
+            print("Model not found, using default token size (8192)")
+            self.model_token = 8192
+
+        try:
+            if llm_params["model_provider"] not in {"oneapi", "nvidia", "ernie", "deepseek"}:
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+                    return init_chat_model(**llm_params)
            else:
-                model_name = llm_params["model"].split("/")[-1]
-                return handle_model(model_name, llm_params["model"], model_name)
+ if "deepseek" in llm_params ["model" ]:
154
+ return DeepSeek (** llm_params )
155
+
156
+ if "ernie" in llm_params ["model" ]:
157
+ from langchain_community .chat_models import ErnieBotChat
158
+ return ErnieBotChat (** llm_params )
159
+
160
+ if "oneapi" in llm_params ["model" ]:
161
+ return OneApi (** llm_params )
162
+
163
+ if "nvidia" in llm_params ["model" ]:
164
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA
165
+ return ChatNVIDIA (** llm_params )
222
166
223
-        except KeyError as e:
-            print(f"Model not supported: {e}")
+        except Exception as e:
+            print(f"Error instantiating model: {e}")
    def get_state(self, key=None) -> dict:
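
With this refactor, the provider is chosen by prefixing the model name in the "llm" config. A minimal usage sketch, assuming scrapegraph-ai's SmartScraperGraph entry point (the model name, API key, and URL below are placeholders):

    from scrapegraphai.graphs import SmartScraperGraph

    graph_config = {
        "llm": {
            "model": "openai/gpt-4o-mini",  # "<provider>/<model>"; placeholder name
            "api_key": "YOUR_API_KEY",      # placeholder
        },
    }

    # The graph forwards the "llm" section to _create_llm, which now routes all
    # known providers through init_chat_model except oneapi/nvidia/ernie/deepseek.
    graph = SmartScraperGraph(
        prompt="List the page's article titles",
        source="https://example.com",
        config=graph_config,
    )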