
Commit 2ef6d67

Merge pull request #346 from duke147/ernie
support ernie
2 parents 49cdadf + 4e16c9a commit 2ef6d67

File tree

4 files changed: +85 -0 lines changed

scrapegraphai/builders/graph_builder.py

Lines changed: 3 additions & 0 deletions

@@ -6,6 +6,7 @@
 from langchain.chains import create_extraction_chain
 from ..models import OpenAI, Gemini
 from ..helpers import nodes_metadata, graph_schema
+from ..models.ernie import Ernie


 class GraphBuilder:
@@ -73,6 +74,8 @@ def _create_llm(self, llm_config: dict):
             return OpenAI(llm_params)
         elif "gemini" in llm_params["model"]:
             return Gemini(llm_params)
+        elif "ernie" in llm_params["model"]:
+            return Ernie(llm_params)
         raise ValueError("Model not supported")

     def _generate_nodes_description(self):
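The new branch follows the same substring dispatch the builder already uses for OpenAI and Gemini: inspect llm_params["model"] and instantiate the matching wrapper. A minimal, self-contained sketch of that pattern, using stub classes in place of the real model wrappers; only the ernie branch is taken verbatim from this diff, the other conditions are paraphrased assumptions:

class _Stub:
    # Placeholder standing in for the OpenAI / Gemini / Ernie wrapper classes.
    def __init__(self, name: str, params: dict):
        self.name, self.params = name, params

def create_llm_sketch(llm_params: dict):
    model = llm_params["model"]
    if "gpt" in model:                      # paraphrased OpenAI condition (assumption)
        return _Stub("OpenAI", llm_params)
    elif "gemini" in model:
        return _Stub("Gemini", llm_params)
    elif "ernie" in model:                  # branch added by this commit
        return _Stub("Ernie", llm_params)
    raise ValueError("Model not supported")

assert create_llm_sketch({"model": "ernie-bot-turbo"}).name == "Ernie"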

scrapegraphai/graphs/abstract_graph.py

Lines changed: 8 additions & 0 deletions

@@ -24,6 +24,7 @@
     OpenAI,
     OneApi
 )
+from ..models.ernie import Ernie
 from ..utils.logging import set_verbosity_debug, set_verbosity_warning

 from ..helpers import models_tokens
@@ -272,6 +273,13 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                 print("model not found, using default token size (8192)")
                 self.model_token = 8192
             return DeepSeek(llm_params)
+        elif "ernie" in llm_params["model"]:
+            try:
+                self.model_token = models_tokens["ernie"][llm_params["model"]]
+            except KeyError:
+                print("model not found, using default token size (8192)")
+                self.model_token = 8192
+            return Ernie(llm_params)
         else:
             raise ValueError("Model provided by the configuration not supported")
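The ernie branch reuses the token-size pattern of the other providers: look the model up in models_tokens["ernie"] and fall back to 8192 if it is missing. A small sketch of that fallback, with an assumed placeholder mapping standing in for the real scrapegraphai.helpers.models_tokens:

# Placeholder mapping; the real one lives in scrapegraphai.helpers.models_tokens.
models_tokens = {"ernie": {"ernie-bot-turbo": 8192}}

def resolve_ernie_token(model_name: str, default: int = 8192) -> int:
    try:
        return models_tokens["ernie"][model_name]
    except KeyError:
        print("model not found, using default token size (8192)")
        return default

print(resolve_ernie_token("ernie-bot-turbo"))  # found in the mapping
print(resolve_ernie_token("ernie-unknown"))    # KeyError path -> default 8192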

scrapegraphai/models/ernie.py

Lines changed: 17 additions & 0 deletions

@@ -0,0 +1,17 @@
+"""
+Ernie Module
+"""
+from langchain_community.chat_models import ErnieBotChat
+
+
+class Ernie(ErnieBotChat):
+    """
+    A wrapper for the ErnieBotChat class that provides default configuration
+    and could be extended with additional methods if needed.
+
+    Args:
+        llm_config (dict): Configuration parameters for the language model.
+    """
+
+    def __init__(self, llm_config: dict):
+        super().__init__(**llm_config)
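Because the wrapper simply unpacks its config dict into ErnieBotChat, it can be constructed directly from a plain dictionary. A hedged usage sketch, assuming the keyword names accepted by langchain_community's ErnieBotChat; the credential values are placeholders and none of this is taken from the commit itself:

from scrapegraphai.models.ernie import Ernie

llm = Ernie({
    "model_name": "ERNIE-Bot-turbo",                  # assumed ErnieBotChat field
    "ernie_client_id": "<ernie_client_id>",           # placeholder credential
    "ernie_client_secret": "<ernie_client_secret>",   # placeholder credential
    "temperature": 0.1,
})

# Ernie is a thin subclass, so everything else is inherited from ErnieBotChat.
print(isinstance(llm, Ernie))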

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
+"""
+Module for testing the smart scraper class
+"""
+import pytest
+from scrapegraphai.graphs import SmartScraperGraph
+
+
+@pytest.fixture
+def graph_config():
+    """
+    Configuration of the graph
+    """
+    return {
+        "llm": {
+            "model": "ernie-bot-turbo",
+            "ernie_client_id": "<ernie_client_id>",
+            "ernie_client_secret": "<ernie_client_secret>",
+            "temperature": 0.1
+        },
+        "embeddings": {
+            "model": "ollama/nomic-embed-text",
+            "temperature": 0,
+            "base_url": "http://localhost:11434",
+        }
+    }
+
+
+def test_scraping_pipeline(graph_config: dict):
+    """
+    Start of the scraping pipeline
+    """
+    smart_scraper_graph = SmartScraperGraph(
+        prompt="List me all the news with their description.",
+        source="https://perinim.github.io/projects",
+        config=graph_config
+    )
+
+    result = smart_scraper_graph.run()
+
+    assert result is not None
+
+
+def test_get_execution_info(graph_config: dict):
+    """
+    Get the execution info
+    """
+    smart_scraper_graph = SmartScraperGraph(
+        prompt="List me all the news with their description.",
+        source="https://perinim.github.io/projects",
+        config=graph_config
+    )
+
+    smart_scraper_graph.run()
+
+    graph_exec_info = smart_scraper_graph.get_execution_info()
+
+    assert graph_exec_info is not None
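Both tests talk to a live ERNIE endpoint and to a local Ollama server for embeddings, so the placeholder credentials must be replaced before they can pass. One hedged way to do that is to read the values from environment variables; the variable names below are assumptions, not something this commit defines:

import os

graph_config = {
    "llm": {
        "model": "ernie-bot-turbo",
        # Hypothetical variable names; fall back to the fixture's placeholders.
        "ernie_client_id": os.environ.get("ERNIE_CLIENT_ID", "<ernie_client_id>"),
        "ernie_client_secret": os.environ.get("ERNIE_CLIENT_SECRET", "<ernie_client_secret>"),
        "temperature": 0.1,
    },
    "embeddings": {
        "model": "ollama/nomic-embed-text",
        "temperature": 0,
        "base_url": "http://localhost:11434",   # local Ollama instance, as in the fixture
    },
}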
