Skip to content

Commit c72c077

Browse files
committed
refactoring of the nodes
1 parent 103c21c commit c72c077

File tree

3 files changed

+30
-25
lines changed

3 files changed

+30
-25
lines changed

scrapegraphai/graphs/screenshot_scraper_graph.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,32 @@
1+
"""
2+
ScreenshotScraperGraph Module
13
"""
2-
ScreenshotScraperGraph Module
3-
"""
4-
54
from typing import Optional
65
import logging
76
from pydantic import BaseModel
87
from .base_graph import BaseGraph
98
from .abstract_graph import AbstractGraph
9+
from ..nodes import ( FetchScreenNode, GenerateAnswerFromImageNode, )
1010

11-
from ..nodes import (
12-
FetchScreenNode,
13-
GenerateAnswerFromImageNode,
14-
)
11+
class ScreenshotScraperGraph(AbstractGraph):
12+
"""
13+
A graph instance representing the web scraping workflow for images.
1514
16-
class ScreenshotScraperGraph(AbstractGraph):
17-
"""
18-
smart_scraper.run()
19-
)
15+
Attributes:
16+
prompt (str): The input text to be scraped.
17+
config (dict): Configuration parameters for the graph.
18+
source (str): The source URL or image link to scrape from.
19+
20+
Methods:
21+
__init__(prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None)
22+
Initializes the ScreenshotScraperGraph instance with the given prompt,
23+
source, and configuration parameters.
24+
25+
_create_graph()
26+
Creates a graph of nodes representing the web scraping workflow for images.
27+
28+
run()
29+
Executes the scraping process and returns the answer to the prompt.
2030
"""
2131

2232
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
@@ -25,10 +35,10 @@ def __init__(self, prompt: str, source: str, config: dict, schema: Optional[Base
2535

2636
def _create_graph(self) -> BaseGraph:
2737
"""
28-
Creates the graph of nodes representing the workflow for web scraping.
38+
Creates the graph of nodes representing the workflow for web scraping with images.
2939
3040
Returns:
31-
BaseGraph: A graph instance representing the web scraping workflow.
41+
BaseGraph: A graph instance representing the web scraping workflow for images.
3242
"""
3343
fetch_screen_node = FetchScreenNode(
3444
input="url",
@@ -38,8 +48,8 @@ def _create_graph(self) -> BaseGraph:
3848
}
3949
)
4050
generate_answer_from_image_node = GenerateAnswerFromImageNode(
41-
input="doc",
42-
output=["parsed_doc"],
51+
input="imgs",
52+
output=["answer"],
4353
node_config={
4454
"config": self.config
4555
}

scrapegraphai/nodes/fetch_screen_node.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ def execute(self, state: dict) -> dict:
2525
Captures screenshots from the input URL and stores them in the state dictionary as bytes.
2626
"""
2727

28-
screenshots = []
29-
3028
with sync_playwright() as p:
3129
browser = p.chromium.launch()
3230
page = browser.new_page()
@@ -49,10 +47,7 @@ def capture_screenshot(scroll_position, counter):
4947

5048
browser.close()
5149

52-
for screenshot_data in screenshot_data_list:
53-
screenshots.append(screenshot_data)
54-
5550
state["link"] = self.url
56-
state['screenshots'] = screenshots
51+
state['screenshots'] = screenshot_data_list
5752

5853
return state

scrapegraphai/nodes/generate_answer_from_image_node.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ def execute(self, state: dict) -> dict:
3131

3232
api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")
3333

34-
supported_models = ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"]
34+
supported_models = ("gpt-4o", "gpt-4o-mini", "gpt-4-turbo")
3535

3636
if self.node_config["config"]["llm"]["model"] not in supported_models:
37-
raise ValueError(f"""Model '{self.node_config['config']['llm']['model']}'
37+
raise ValueError(f"""Model '{self.node_config['config']['llm']['model']}'
3838
is not supported. Supported models are:
3939
{', '.join(supported_models)}.""")
4040

@@ -47,7 +47,7 @@ def execute(self, state: dict) -> dict:
4747
}
4848

4949
payload = {
50-
"model": "gpt-4o-mini",
50+
"model": self.node_config["config"]["llm"]["model"],
5151
"messages": [
5252
{
5353
"role": "user",
@@ -72,7 +72,7 @@ def execute(self, state: dict) -> dict:
7272
response = requests.post("https://api.openai.com/v1/chat/completions",
7373
headers=headers,
7474
json=payload,
75-
timeout=10 )
75+
timeout=10)
7676
result = response.json()
7777

7878
response_text = result.get('choices',

0 commit comments

Comments
 (0)