Commit bb73d91

refactor: reuse code for common interface models
1 parent b17756d commit bb73d91

File tree

1 file changed (+49 −108 lines)

scrapegraphai/graphs/abstract_graph.py

Lines changed: 49 additions & 108 deletions
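The "common interface" named in the commit message is the `init_chat_model` factory (LangChain's provider-agnostic chat-model constructor), which the new `handle_model` helper in the diff delegates to via `init_chat_model(**llm_params)`. As a rough usage sketch of that entry point — the model name, provider, and temperature below are placeholder values, not taken from this commit:

from langchain.chat_models import init_chat_model

# Illustrative call only: one factory covers any provider that supports the
# common interface, which is what lets the per-provider branches collapse.
llm = init_chat_model(model="gpt-4o-mini", model_provider="openai", temperature=0)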
@@ -146,138 +146,61 @@ def _create_llm(self, llm_config: dict) -> object:
                 raise KeyError("model_tokens not specified") from exc
             return llm_params["model_instance"]
 
-        # Instantiate the language model based on the model name
-        if "gpt-" in llm_params["model"]:
+        # Instantiate the language model based on the model name (models that use the common interface)
+        def handle_model(model_name, provider, token_key, default_token=8192):
             try:
-                self.model_token = models_tokens["openai"][llm_params["model"]]
-                llm_params["model_provider"] = "openai"
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
+                self.model_token = models_tokens[provider][token_key]
+            except KeyError:
+                print(f"Model not found, using default token size ({default_token})")
+                self.model_token = default_token
+            llm_params["model_provider"] = provider
+            llm_params["model"] = model_name
             return init_chat_model(**llm_params)
 
-        if "oneapi" in llm_params["model"]:
-            # take the model after the last dash
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["oneapi"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return OneApi(llm_params)
+        if "gpt-" in llm_params["model"]:
+            return handle_model(llm_params["model"], "openai", llm_params["model"])
 
         if "fireworks" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]]
-                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "fireworks"
-            return init_chat_model(**llm_params)
+            model_name = "/".join(llm_params["model"].split("/")[1:])
+            token_key = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "fireworks", token_key)
 
         if "azure" in llm_params["model"]:
-            # take the model after the last dash
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["azure"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "azure_openai"
-            return init_chat_model(**llm_params)
-
-        if "nvidia" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
-                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            return ChatNVIDIA(llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "azure_openai", model_name)
 
         if "gemini" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["gemini"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "google_genai "
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "google_genai", model_name)
 
         if llm_params["model"].startswith("claude"):
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["claude"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "anthropic"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "anthropic", model_name)
 
         if llm_params["model"].startswith("vertexai"):
-            try:
-                self.model_token = models_tokens["vertexai"][llm_params["model"]]
-            except KeyError as exc:
-                raise KeyError("Model not supported") from exc
-            llm_params["model_provider"] = "google_vertexai"
-            return init_chat_model(**llm_params)
+            return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])
 
         if "ollama" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("ollama/")[-1]
-            llm_params["model_provider"] = "ollama"
-
-            # allow user to set model_tokens in config
-            try:
-                if "model_tokens" in llm_params:
-                    self.model_token = llm_params["model_tokens"]
-                elif llm_params["model"] in models_tokens["ollama"]:
-                    try:
-                        self.model_token = models_tokens["ollama"][llm_params["model"]]
-                    except KeyError as exc:
-                        print("model not found, using default token size (8192)")
-                        self.model_token = 8192
-                else:
-                    self.model_token = 8192
-            except AttributeError:
-                self.model_token = 8192
-
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("ollama/")[-1]
+            token_key = model_name if "model_tokens" not in llm_params else llm_params["model_tokens"]
+            return handle_model(model_name, "ollama", token_key)
 
         if "hugging_face" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["hugging_face"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "hugging_face"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "hugging_face", model_name)
 
         if "groq" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-
-            try:
-                self.model_token = models_tokens["groq"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "groq"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "groq", model_name)
 
         if "bedrock" in llm_params["model"]:
-            llm_params["model"] = llm_params["model"].split("/")[-1]
-            try:
-                self.model_token = models_tokens["bedrock"][llm_params["model"]]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "bedrock"
-            return init_chat_model(**llm_params)
+            model_name = llm_params["model"].split("/")[-1]
+            return handle_model(model_name, "bedrock", model_name)
 
         if "claude-3-" in llm_params["model"]:
-            try:
-                self.model_token = models_tokens["claude"]["claude3"]
-            except KeyError:
-                print("model not found, using default token size (8192)")
-                self.model_token = 8192
-            llm_params["model_provider"] = "anthropic"
-            return init_chat_model(**llm_params)
+            return handle_model(llm_params["model"], "anthropic", "claude3")
 
+        # Instantiate the language model based on the model name (models that do not use the common interface)
         if "deepseek" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["deepseek"][llm_params["model"]]
@@ -293,7 +216,25 @@ def _create_llm(self, llm_config: dict) -> object:
                 print("model not found, using default token size (8192)")
                 self.model_token = 8192
             return ErnieBotChat(llm_params)
+
+        if "oneapi" in llm_params["model"]:
+            # take the model after the last dash
+            llm_params["model"] = llm_params["model"].split("/")[-1]
+            try:
+                self.model_token = models_tokens["oneapi"][llm_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return OneApi(llm_params)
+
+        if "nvidia" in llm_params["model"]:
+            try:
+                self.model_token = models_tokens["nvidia"][llm_params["model"].split("/")[-1]]
+                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return ChatNVIDIA(llm_params)
 
+        # Raise an error if the model did not match any of the previous cases
         raise ValueError("Model provided by the configuration not supported")
 
     def _create_default_embedder(self, llm_config=None) -> object:
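
For readers who want the refactor's shape outside the class, here is a minimal, self-contained sketch. `MODELS_TOKENS`, `build_chat_model`, and `create_llm` are hypothetical stand-ins for the project's `models_tokens` table, `init_chat_model`, and `_create_llm`; only the nested `handle_model` dispatch mirrors the commit itself.

# Standalone sketch (not project code): stand-in names throughout.
MODELS_TOKENS = {
    "openai": {"gpt-4o": 128_000},
    "anthropic": {"claude3": 200_000},
}

def build_chat_model(**params):
    # Stand-in for the shared constructor; just echoes the resolved parameters.
    return params

def create_llm(llm_params, default_token=8192):
    resolved = {}  # mirrors the self.model_token attribute in the real method

    def handle_model(model_name, provider, token_key):
        # One lookup path for every provider: fall back to the default token
        # size instead of raising, then defer to the shared constructor.
        resolved["model_token"] = MODELS_TOKENS.get(provider, {}).get(token_key, default_token)
        llm_params["model_provider"] = provider
        llm_params["model"] = model_name
        return build_chat_model(**llm_params)

    model = llm_params["model"]
    if "gpt-" in model:
        return handle_model(model, "openai", model)
    if model.startswith("claude"):
        return handle_model(model.split("/")[-1], "anthropic", "claude3")
    raise ValueError("Model provided by the configuration not supported")

print(create_llm({"model": "gpt-4o", "temperature": 0}))
# -> {'model': 'gpt-4o', 'temperature': 0, 'model_provider': 'openai'}

The payoff of this shape is that each provider branch shrinks to a single `handle_model` call, and the token-lookup fallback behaviour is defined in exactly one place.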
