Skip to content

Commit f73343f

Browse files
committed
fix(AbstractGraph): correct and simplify instancing logic
1 parent 22ab45f commit f73343f

File tree

2 files changed

+39
-95
lines changed

2 files changed

+39
-95
lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 37 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -125,103 +125,47 @@ def _create_llm(self, llm_config: dict) -> object:
125125
self.model_token = llm_params["model_tokens"]
126126
except KeyError as exc:
127127
raise KeyError("model_tokens not specified") from exc
128-
return llm_params["model_instance"]
129-
130-
def handle_model(model_name, provider, token_key, default_token=8192):
131-
try:
132-
self.model_token = models_tokens[provider][token_key]
133-
except KeyError:
134-
print(f"Model not found, using default token size ({default_token})")
135-
self.model_token = default_token
136-
llm_params["model_provider"] = provider
137-
llm_params["model"] = model_name
138-
with warnings.catch_warnings():
139-
warnings.simplefilter("ignore")
140-
return init_chat_model(**llm_params)
141-
142-
known_models = {"chatgpt","gpt","openai", "azure_openai", "google_genai",
143-
"ollama", "oneapi", "nvidia", "groq", "google_vertexai",
144-
"bedrock", "mistralai", "hugging_face", "deepseek", "ernie",
145-
"fireworks", "claude-3-"}
146-
147-
if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models:
148-
raise ValueError(f"Model '{llm_params['model']}' is not supported")
149-
128+
return llm_params["model_instance"]
129+
130+
known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
131+
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
132+
"hugging_face", "deepseek", "ernie", "fireworks"}
133+
134+
split_model_provider = llm_params["model"].split("/")
135+
llm_params["model_provider"] = split_model_provider[0]
136+
llm_params["model"] = split_model_provider[1:]
137+
138+
if llm_params["model_provider"] not in known_providers:
139+
raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
140+
150141
try:
151-
if "fireworks" in llm_params["model"]:
152-
model_name = "/".join(llm_params["model"].split("/")[1:])
153-
token_key = llm_params["model"].split("/")[-1]
154-
return handle_model(model_name, "fireworks", token_key)
155-
156-
elif "gemini" in llm_params["model"]:
157-
model_name = llm_params["model"].split("/")[-1]
158-
return handle_model(model_name, "google_genai", model_name)
159-
160-
elif llm_params["model"].startswith("claude"):
161-
model_name = llm_params["model"].split("/")[-1]
162-
return handle_model(model_name, "anthropic", model_name)
163-
164-
elif llm_params["model"].startswith("vertexai"):
165-
return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
166-
167-
elif "gpt-" in llm_params["model"]:
168-
return handle_model(llm_params["model"], "openai", llm_params["model"])
169-
170-
elif "ollama" in llm_params["model"]:
171-
model_name = llm_params["model"].split("ollama/")[-1]
172-
token_key = model_name if "model_tokens" not in llm_params else None
173-
model_tokens = 8192 if "model_tokens" not in llm_params else llm_params["model_tokens"]
174-
return handle_model(model_name, "ollama", token_key, model_tokens)
175-
176-
elif "claude-3-" in llm_params["model"]:
177-
return handle_model(llm_params["model"], "anthropic", "claude3")
178-
179-
elif llm_params["model"].startswith("mistral"):
180-
model_name = llm_params["model"].split("/")[-1]
181-
return handle_model(model_name, "mistralai", model_name)
182-
183-
elif "deepseek" in llm_params["model"]:
184-
try:
185-
self.model_token = models_tokens["deepseek"][llm_params["model"]]
186-
except KeyError:
187-
print("model not found, using default token size (8192)")
188-
self.model_token = 8192
189-
return DeepSeek(llm_params)
190-
191-
elif "ernie" in llm_params["model"]:
192-
from langchain_community.chat_models import ErnieBotChat
193-
194-
try:
195-
self.model_token = models_tokens["ernie"][llm_params["model"]]
196-
except KeyError:
197-
print("model not found, using default token size (8192)")
198-
self.model_token = 8192
199-
return ErnieBotChat(llm_params)
200-
201-
elif "oneapi" in llm_params["model"]:
202-
llm_params["model"] = llm_params["model"].split("/")[-1]
203-
try:
204-
self.model_token = models_tokens["oneapi"][llm_params["model"]]
205-
except KeyError:
206-
raise KeyError("Model not supported")
207-
return OneApi(llm_params)
208-
209-
elif "nvidia" in llm_params["model"]:
210-
from langchain_nvidia_ai_endpoints import ChatNVIDIA
211-
212-
try:
213-
self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
214-
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
215-
except KeyError:
216-
raise KeyError("Model not supported")
217-
return ChatNVIDIA(llm_params)
142+
            self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
143+
except KeyError:
144+
print("Model not found, using default token size (8192)")
145+
self.model_token = 8192
218146

147+
try:
148+
if llm_params["model_provider"] not in {"oneapi", "nvidia", "ernie", "deepseek"}:
149+
with warnings.catch_warnings():
150+
warnings.simplefilter("ignore")
151+
return init_chat_model(**llm_params)
219152
else:
220-
model_name = llm_params["model"].split("/")[-1]
221-
return handle_model(model_name, llm_params["model"], model_name)
153+
if "deepseek" in llm_params["model"]:
154+
return DeepSeek(**llm_params)
155+
156+
if "ernie" in llm_params["model"]:
157+
from langchain_community.chat_models import ErnieBotChat
158+
return ErnieBotChat(**llm_params)
159+
160+
if "oneapi" in llm_params["model"]:
161+
return OneApi(**llm_params)
162+
163+
if "nvidia" in llm_params["model"]:
164+
from langchain_nvidia_ai_endpoints import ChatNVIDIA
165+
return ChatNVIDIA(**llm_params)
222166

223-
except KeyError as e:
224-
print(f"Model not supported: {e}")
167+
except Exception as e:
168+
print(f"Error instancing model: {e}")
225169

226170

227171
def get_state(self, key=None) -> dict:

scrapegraphai/helpers/models_tokens.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102
"oneapi": {
103103
"qwen-turbo": 6000,
104104
},
105-
"nvdia": {
105+
"nvidia": {
106106
"meta/llama3-70b-instruct": 419,
107107
"meta/llama3-8b-instruct": 419,
108108
"nemotron-4-340b-instruct": 1024,
@@ -127,7 +127,7 @@
127127
"gemma-7b-it": 8192,
128128
"claude-3-haiku-20240307'": 8192,
129129
},
130-
"claude": {
130+
"anthropic": {
131131
"claude_instant": 100000,
132132
"claude2": 9000,
133133
"claude2.1": 200000,

0 commit comments

Comments
 (0)