Skip to content

support ernie #346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions scrapegraphai/builders/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain.chains import create_extraction_chain
from ..models import OpenAI, Gemini
from ..helpers import nodes_metadata, graph_schema
from ..models.ernie import Ernie


class GraphBuilder:
Expand Down Expand Up @@ -73,6 +74,8 @@ def _create_llm(self, llm_config: dict):
return OpenAI(llm_params)
elif "gemini" in llm_params["model"]:
return Gemini(llm_params)
elif "ernie" in llm_params["model"]:
return Ernie(llm_params)
raise ValueError("Model not supported")

def _generate_nodes_description(self):
Expand Down
8 changes: 8 additions & 0 deletions scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
OpenAI,
OneApi
)
from ..models.ernie import Ernie
from ..utils.logging import set_verbosity_debug, set_verbosity_warning

from ..helpers import models_tokens
Expand Down Expand Up @@ -272,6 +273,13 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
print("model not found, using default token size (8192)")
self.model_token = 8192
return DeepSeek(llm_params)
elif "ernie" in llm_params["model"]:
try:
self.model_token = models_tokens["ernie"][llm_params["model"]]
except KeyError:
print("model not found, using default token size (8192)")
self.model_token = 8192
return Ernie(llm_params)
else:
raise ValueError("Model provided by the configuration not supported")

Expand Down
17 changes: 17 additions & 0 deletions scrapegraphai/models/ernie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Ollama Module
"""
from langchain_community.chat_models import ErnieBotChat


class Ernie(ErnieBotChat):
"""
A wrapper for the ErnieBotChat class that provides default configuration
and could be extended with additional methods if needed.

Args:
llm_config (dict): Configuration parameters for the language model.
"""

def __init__(self, llm_config: dict):
super().__init__(**llm_config)
57 changes: 57 additions & 0 deletions tests/graphs/smart_scraper_ernie_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Module for testing th smart scraper class
"""
import pytest
from scrapegraphai.graphs import SmartScraperGraph


@pytest.fixture
def graph_config():
    """
    Configuration of the graph
    """
    llm_settings = {
        "model": "ernie-bot-turbo",
        "ernie_client_id": "<ernie_client_id>",
        "ernie_client_secret": "<ernie_client_secret>",
        "temperature": 0.1,
    }
    embedder_settings = {
        "model": "ollama/nomic-embed-text",
        "temperature": 0,
        "base_url": "http://localhost:11434",
    }
    return {"llm": llm_settings, "embeddings": embedder_settings}


def test_scraping_pipeline(graph_config: dict):
    """
    Start of the scraping pipeline
    """
    # Build the graph with the shared fixture configuration and run it
    # end to end; the run must produce a non-None result.
    scraper = SmartScraperGraph(
        prompt="List me all the news with their description.",
        source="https://perinim.github.io/projects",
        config=graph_config,
    )

    assert scraper.run() is not None


def test_get_execution_info(graph_config: dict):
    """
    Get the execution info
    """
    # Run the graph first, then verify that execution statistics are
    # available afterwards.
    scraper = SmartScraperGraph(
        prompt="List me all the news with their description.",
        source="https://perinim.github.io/projects",
        config=graph_config,
    )
    scraper.run()

    assert scraper.get_execution_info() is not None
Loading