
Commit 612c644

feat: implement ScrapeGraph class for scraping-only automation
1 parent e0fc457 commit 612c644

File tree

1 file changed: +98, -0 lines


scrapegraphai/graphs/scrape_graph.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
"""
ScrapeGraph Module
"""
from typing import Optional
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from ..nodes import (
    FetchNode,
    ParseNode,
)


class ScrapeGraph(AbstractGraph):
    """
    ScrapeGraph is a scraping pipeline that automates the process of
    extracting information from web pages.

    Attributes:
        prompt (str): The prompt for the graph.
        source (str): The source of the graph (a URL or a local directory).
        config (dict): Configuration parameters for the graph.
        schema (BaseModel): The schema for the graph output.
        verbose (bool): A flag indicating whether to show print statements during execution.
        headless (bool): A flag indicating whether to run the graph in headless mode.

    Args:
        prompt (str): The prompt for the graph.
        source (str): The source of the graph.
        config (dict): Configuration parameters for the graph.
        schema (BaseModel): The schema for the graph output.

    Example:
        >>> scraper = ScrapeGraph(
        ...     "https://en.wikipedia.org/wiki/Chioggia",
        ...     {"llm": {"model": "openai/gpt-3.5-turbo"}}
        ... )
        >>> result = scraper.run()
    """

    def __init__(self, source: str, config: dict, prompt: str = "", schema: Optional[BaseModel] = None):
        super().__init__(prompt, config, source, schema)

        # Route the source to the matching FetchNode input key:
        # remote URLs are fetched over HTTP, anything else is read locally.
        self.input_key = "url" if source.startswith("http") else "local_dir"

    def _create_graph(self) -> BaseGraph:
        """
        Creates the graph of nodes representing the workflow for web scraping.

        Returns:
            BaseGraph: A graph instance representing the web scraping workflow.
        """
        fetch_node = FetchNode(
            input="url | local_dir",
            output=["doc"],
            node_config={
                "llm_model": self.llm_model,
                "force": self.config.get("force", False),
                "cut": self.config.get("cut", True),
                "loader_kwargs": self.config.get("loader_kwargs", {}),
                "browser_base": self.config.get("browser_base"),
                "scrape_do": self.config.get("scrape_do")
            }
        )

        parse_node = ParseNode(
            input="doc",
            output=["parsed_doc"],
            node_config={
                "llm_model": self.llm_model,
                "chunk_size": self.model_token
            }
        )

        # A two-node pipeline: fetch the document, then parse it into chunks.
        return BaseGraph(
            nodes=[
                fetch_node,
                parse_node,
            ],
            edges=[
                (fetch_node, parse_node),
            ],
            entry_point=fetch_node,
            graph_name=self.__class__.__name__
        )

    def run(self) -> str:
        """
        Executes the scraping process and returns the parsed content.

        Returns:
            str: The parsed document content.
        """
        inputs = {"user_prompt": self.prompt, self.input_key: self.source}
        self.final_state, self.execution_info = self.graph.execute(inputs)

        return self.final_state.get("parsed_doc", "No document found.")
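For illustration, a minimal usage sketch of the class added in this commit, mirroring the docstring example above. The import path assumes ScrapeGraph is re-exported from scrapegraphai.graphs (otherwise import it from scrapegraphai.graphs.scrape_graph), and the URL, model name, and config keys are placeholders taken from the docstring rather than confirmed by this commit alone.

    # Hypothetical usage sketch; import path and config keys are assumptions
    # based on the docstring example, not verified against the full package.
    from scrapegraphai.graphs import ScrapeGraph

    scraper = ScrapeGraph(
        source="https://en.wikipedia.org/wiki/Chioggia",
        config={"llm": {"model": "openai/gpt-3.5-turbo"}},
    )

    # Runs FetchNode -> ParseNode and returns the parsed document,
    # or "No document found." if nothing was parsed.
    result = scraper.run()
    print(result)

Note the design choice visible in the diff: unlike the LLM-driven graphs in this package, the pipeline stops after parsing. There is no answer-generation node, so run() returns the parsed document itself, which matches the scraping-only intent of the commit message.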
