|
19 | 19 | from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
|
20 | 20 | from langchain_fireworks import FireworksEmbeddings, ChatFireworks
|
21 | 21 | from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
|
22 |
| -from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings |
| 22 | +from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA |
| 23 | +from langchain_community.chat_models import ErnieBotChat |
23 | 24 | from ..helpers import models_tokens
|
24 | 25 | from ..models import (
|
25 | 26 | OneApi,
|
26 |
| - Nvidia, |
27 | 27 | DeepSeek
|
28 | 28 | )
|
29 |
| -from ..models.ernie import Ernie |
| 29 | + |
30 | 30 | from langchain.chat_models import init_chat_model
|
31 | 31 |
|
32 | 32 | from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
|
@@ -192,7 +192,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
|
192 | 192 | llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
|
193 | 193 | except KeyError as exc:
|
194 | 194 | raise KeyError("Model not supported") from exc
|
195 |
| - return Nvidia(llm_params) |
| 195 | + return ChatNVIDIA(llm_params) |
196 | 196 | elif "gemini" in llm_params["model"]:
|
197 | 197 | llm_params["model"] = llm_params["model"].split("/")[-1]
|
198 | 198 | try:
|
@@ -289,7 +289,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
|
289 | 289 | except KeyError:
|
290 | 290 | print("model not found, using default token size (8192)")
|
291 | 291 | self.model_token = 8192
|
292 |
| - return Ernie(llm_params) |
| 292 | + return ErnieBotChat(llm_params) |
293 | 293 | else:
|
294 | 294 | raise ValueError("Model provided by the configuration not supported")
|
295 | 295 |
|
@@ -320,7 +320,7 @@ def _create_default_embedder(self, llm_config=None) -> object:
|
320 | 320 | return AzureOpenAIEmbeddings()
|
321 | 321 | elif isinstance(self.llm_model, ChatFireworks):
|
322 | 322 | return FireworksEmbeddings(model=self.llm_model.model_name)
|
323 |
| - elif isinstance(self.llm_model, Nvidia): |
| 323 | + elif isinstance(self.llm_model, ChatNVIDIA): |
324 | 324 | return NVIDIAEmbeddings(model=self.llm_model.model_name)
|
325 | 325 | elif isinstance(self.llm_model, ChatOllama):
|
326 | 326 | # unwrap the kwargs from the model whihc is a dict
|
|
0 commit comments