Skip to content

Commit bc2c996

Browse files
committed
refactor: remove redundant wrappers for Ernie and Nvidia
1 parent 9275486 commit bc2c996

File tree

3 files changed

+6
-48
lines changed

3 files changed

+6
-48
lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@
1919
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
2020
from langchain_fireworks import FireworksEmbeddings, ChatFireworks
2121
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings, ChatOpenAI, AzureChatOpenAI
22-
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
22+
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings, ChatNVIDIA
23+
from langchain_community.chat_models import ErnieBotChat
2324
from ..helpers import models_tokens
2425
from ..models import (
2526
OneApi,
26-
Nvidia,
2727
DeepSeek
2828
)
29-
from ..models.ernie import Ernie
29+
3030
from langchain.chat_models import init_chat_model
3131

3232
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
@@ -192,7 +192,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
192192
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
193193
except KeyError as exc:
194194
raise KeyError("Model not supported") from exc
195-
return Nvidia(llm_params)
195+
return ChatNVIDIA(llm_params)
196196
elif "gemini" in llm_params["model"]:
197197
llm_params["model"] = llm_params["model"].split("/")[-1]
198198
try:
@@ -289,7 +289,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
289289
except KeyError:
290290
print("model not found, using default token size (8192)")
291291
self.model_token = 8192
292-
return Ernie(llm_params)
292+
return ErnieBotChat(llm_params)
293293
else:
294294
raise ValueError("Model provided by the configuration not supported")
295295

@@ -320,7 +320,7 @@ def _create_default_embedder(self, llm_config=None) -> object:
320320
return AzureOpenAIEmbeddings()
321321
elif isinstance(self.llm_model, ChatFireworks):
322322
return FireworksEmbeddings(model=self.llm_model.model_name)
323-
elif isinstance(self.llm_model, Nvidia):
323+
elif isinstance(self.llm_model, ChatNVIDIA):
324324
return NVIDIAEmbeddings(model=self.llm_model.model_name)
325325
elif isinstance(self.llm_model, ChatOllama):
326326
# unwrap the kwargs from the model whihc is a dict

scrapegraphai/models/ernie.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

scrapegraphai/models/nvidia.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments

Comments
 (0)