Skip to content

Commit 4f120e2

Browse files
committed
fix(AbstractGraph): model selection bug
1 parent 4eccc76 commit 4f120e2

File tree

2 files changed

+78
-24
lines changed

2 files changed

+78
-24
lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,15 @@ def _create_llm(self, llm_config: dict) -> object:
131131
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
132132
"hugging_face", "deepseek", "ernie", "fireworks"}
133133

134-
split_model_provider = llm_params["model"].split("/")
134+
split_model_provider = llm_params["model"].split("/", 1)
135135
llm_params["model_provider"] = split_model_provider[0]
136-
llm_params["model"] = split_model_provider[1:]
136+
llm_params["model"] = split_model_provider[1]
137137

138138
if llm_params["model_provider"] not in known_providers:
139139
raise ValueError(f"Provider {llm_params['model_provider']} is not supported. If possible, try to use a model instance instead.")
140140

141141
try:
142-
self.model_token = models_tokens[llm_params["model_provider"]].get(llm_params["model"][0])
142+
self.model_token = models_tokens[llm_params["model_provider"]][llm_params["model"]]
143143
except KeyError:
144144
print("Model not found, using default token size (8192)")
145145
self.model_token = 8192
@@ -150,18 +150,21 @@ def _create_llm(self, llm_config: dict) -> object:
150150
warnings.simplefilter("ignore")
151151
return init_chat_model(**llm_params)
152152
else:
153-
if "deepseek" in llm_params["model"]:
153+
if llm_params["model_provider"] == "deepseek":
154154
return DeepSeek(**llm_params)
155155

156-
if "ernie" in llm_params["model"]:
156+
if llm_params["model_provider"] == "ernie":
157157
from langchain_community.chat_models import ErnieBotChat
158158
return ErnieBotChat(**llm_params)
159159

160-
if "oneapi" in llm_params["model"]:
160+
if llm_params["model_provider"] == "oneapi":
161161
return OneApi(**llm_params)
162162

163-
if "nvidia" in llm_params["model"]:
164-
from langchain_nvidia_ai_endpoints import ChatNVIDIA
163+
if llm_params["model_provider"] == "nvidia":
164+
try:
165+
from langchain_nvidia_ai_endpoints import ChatNVIDIA
166+
except ImportError:
167+
raise ImportError("The langchain_nvidia_ai_endpoints module is not installed. Please install it using `pip install langchain_nvidia_ai_endpoints`.")
165168
return ChatNVIDIA(**llm_params)
166169

167170
except Exception as e:

tests/graphs/abstract_graph_test.py

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,80 @@
33
"""
44
import pytest
55
from unittest.mock import patch
6-
from scrapegraphai.graphs import AbstractGraph
6+
from scrapegraphai.graphs import AbstractGraph, BaseGraph
7+
from scrapegraphai.nodes import (
8+
FetchNode,
9+
ParseNode
10+
)
11+
from scrapegraphai.models import OneApi, DeepSeek
12+
from langchain_openai import ChatOpenAI, AzureChatOpenAI
13+
from langchain_community.chat_models import ChatOllama
14+
from langchain_google_genai import ChatGoogleGenerativeAI
15+
16+
17+
18+
class TestGraph(AbstractGraph):
19+
def __init__(self, prompt: str, config: dict):
20+
super().__init__(prompt, config)
21+
22+
def _create_graph(self) -> BaseGraph:
23+
fetch_node = FetchNode(
24+
input="url| local_dir",
25+
output=["doc", "link_urls", "img_urls"],
26+
node_config={
27+
"llm_model": self.llm_model,
28+
"force": self.config.get("force", False),
29+
"cut": self.config.get("cut", True),
30+
"loader_kwargs": self.config.get("loader_kwargs", {}),
31+
"browser_base": self.config.get("browser_base")
32+
}
33+
)
34+
parse_node = ParseNode(
35+
input="doc",
36+
output=["parsed_doc"],
37+
node_config={
38+
"chunk_size": self.model_token
39+
}
40+
)
41+
return BaseGraph(
42+
nodes=[
43+
fetch_node,
44+
parse_node
45+
],
46+
edges=[
47+
(fetch_node, parse_node),
48+
],
49+
entry_point=fetch_node,
50+
graph_name=self.__class__.__name__
51+
)
52+
53+
def run(self) -> str:
54+
inputs = {"user_prompt": self.prompt, self.input_key: self.source}
55+
self.final_state, self.execution_info = self.graph.execute(inputs)
56+
57+
return self.final_state.get("answer", "No answer found.")
58+
759

860
class TestAbstractGraph:
961
@pytest.mark.parametrize("llm_config, expected_model", [
10-
({"model": "openai/gpt-3.5-turbo"}, "ChatOpenAI"),
11-
({"model": "azure_openai/gpt-3.5-turbo"}, "AzureChatOpenAI"),
12-
({"model": "google_genai/gemini-pro"}, "ChatGoogleGenerativeAI"),
13-
({"model": "google_vertexai/chat-bison"}, "ChatVertexAI"),
14-
({"model": "ollama/llama2"}, "Ollama"),
15-
({"model": "oneapi/text-davinci-003"}, "OneApi"),
16-
({"model": "nvidia/clara-instant-1-base"}, "ChatNVIDIA"),
17-
({"model": "deepseek/deepseek-coder-6.7b-instruct"}, "DeepSeek"),
18-
({"model": "ernie/ernie-bot"}, "ErnieBotChat"),
62+
({"model": "openai/gpt-3.5-turbo", "openai_api_key": "sk-randomtest001"}, ChatOpenAI),
63+
({
64+
"model": "azure_openai/gpt-3.5-turbo",
65+
"api_key": "random-api-key",
66+
"api_version": "no version",
67+
"azure_endpoint": "https://www.example.com/"},
68+
AzureChatOpenAI),
69+
({"model": "google_genai/gemini-pro", "google_api_key": "google-key-test"}, ChatGoogleGenerativeAI),
70+
({"model": "ollama/llama2"}, ChatOllama),
71+
({"model": "oneapi/qwen-turbo"}, OneApi),
72+
({"model": "deepseek/deepseek-coder"}, DeepSeek),
1973
])
74+
2075
def test_create_llm(self, llm_config, expected_model):
21-
graph = AbstractGraph("Test prompt", {"llm": llm_config})
76+
graph = TestGraph("Test prompt", {"llm": llm_config})
2277
assert isinstance(graph.llm_model, expected_model)
2378

2479
def test_create_llm_unknown_provider(self):
2580
with pytest.raises(ValueError):
26-
AbstractGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})
81+
TestGraph("Test prompt", {"llm": {"model": "unknown_provider/model"}})
2782

28-
def test_create_llm_error(self):
29-
with patch("your_module.init_chat_model", side_effect=Exception("Test error")):
30-
with pytest.raises(Exception):
31-
AbstractGraph("Test prompt", {"llm": {"model": "openai/gpt-3.5-turbo"}})

0 commit comments

Comments
 (0)