Screenshot scraper integration #558
@@ -0,0 +1,38 @@

"""
Basic example of a scraping pipeline using ScreenshotScraperGraph
"""

import os
import json
from dotenv import load_dotenv
from scrapegraphai.graphs import ScreenshotScraperGraph

load_dotenv()

# ************************************************
# Define the configuration for the graph
# ************************************************

graph_config = {
    "llm": {
        "api_key": os.getenv("OPENAI_API_KEY"),
        "model": "gpt-4o",
    },
    "verbose": True,
    "headless": False,
}

# ************************************************
# Create the ScreenshotScraperGraph instance and run it
# ************************************************

smart_scraper_graph = ScreenshotScraperGraph(
    prompt="List me all the projects",
    source="https://perinim.github.io/projects/",
    config=graph_config
)

result = smart_scraper_graph.run()
print(json.dumps(result, indent=4))
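
The constructor also accepts an optional pydantic schema (see the class definition in the next file). A minimal sketch of passing one, reusing the graph_config above; note that neither node introduced in this PR consumes the schema yet, so this is illustrative only:

    from typing import List
    from pydantic import BaseModel

    class Project(BaseModel):
        title: str
        description: str

    class Projects(BaseModel):
        projects: List[Project]

    # The schema is forwarded to AbstractGraph via super().__init__,
    # but the two nodes added in this PR do not read it yet.
    smart_scraper_graph = ScreenshotScraperGraph(
        prompt="List me all the projects",
        source="https://perinim.github.io/projects/",
        config=graph_config,
        schema=Projects,
    )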
@@ -0,0 +1,82 @@

"""
ScreenshotScraperGraph Module
"""
from typing import Optional
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from ..nodes import (
    FetchScreenNode,
    GenerateAnswerFromImageNode,
)


class ScreenshotScraperGraph(AbstractGraph):
    """
    A graph instance representing the web scraping workflow for images.

    Attributes:
        prompt (str): The input text to be scraped.
        config (dict): Configuration parameters for the graph.
        source (str): The source URL or image link to scrape from.

    Methods:
        __init__(prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None)
            Initializes the ScreenshotScraperGraph instance with the given prompt,
            source, and configuration parameters.

        _create_graph()
            Creates a graph of nodes representing the web scraping workflow for images.

        run()
            Executes the scraping process and returns the answer to the prompt.
    """

    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
        super().__init__(prompt, config, source, schema)

    def _create_graph(self) -> BaseGraph:
        """
        Creates the graph of nodes representing the workflow for web scraping with images.

        Returns:
            BaseGraph: A graph instance representing the web scraping workflow for images.
        """
        fetch_screen_node = FetchScreenNode(
            input="url",
            output=["imgs"],
            node_config={
                "link": self.source
            }
        )
        generate_answer_from_image_node = GenerateAnswerFromImageNode(
            input="imgs",
            output=["answer"],
            node_config={
                "config": self.config
            }
        )

        return BaseGraph(
            nodes=[
                fetch_screen_node,
                generate_answer_from_image_node,
            ],
            edges=[
                (fetch_screen_node, generate_answer_from_image_node),
            ],
            entry_point=fetch_screen_node,
            graph_name=self.__class__.__name__
        )

    def run(self) -> str:
        """
        Executes the scraping process and returns the answer to the prompt.

        Returns:
            str: The answer to the prompt.
        """
        inputs = {"user_prompt": self.prompt}
        self.final_state, self.execution_info = self.graph.execute(inputs)

        return self.final_state.get("answer", "No answer found.")
@@ -0,0 +1,55 @@

"""
fetch_screen_node module
"""
from typing import List, Optional
from playwright.sync_api import sync_playwright
from .base_node import BaseNode


class FetchScreenNode(BaseNode):
    """
    FetchScreenNode captures screenshots from a given URL and stores the image data as bytes.
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "FetchScreenNode",
    ):
        super().__init__(node_name, "node", input, output, 2, node_config)
        self.url = node_config.get("link")

    def execute(self, state: dict) -> dict:
        """
        Captures screenshots from the input URL and stores them in the state dictionary as bytes.
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        with sync_playwright() as p:
            browser = p.chromium.launch()
            page = browser.new_page()
            page.goto(self.url)

            viewport_height = page.viewport_size["height"]

            screenshot_counter = 1

            screenshot_data_list = []

            def capture_screenshot(scroll_position, counter):
                # Scroll to the given offset and capture the visible viewport;
                # page.screenshot() returns PNG bytes by default.
                page.evaluate(f"window.scrollTo(0, {scroll_position});")
                screenshot_data = page.screenshot()
                screenshot_data_list.append(screenshot_data)

            # Capture the top of the page, then one viewport further down.
            capture_screenshot(0, screenshot_counter)
            screenshot_counter += 1
            capture_screenshot(viewport_height, screenshot_counter)
Review comment:

    Infinite-scrolling web pages (with dynamic JS rendering) may have undefined
    behavior with the viewport height, as that is bound to change as you scroll
    through the page in the browser. Looks fine for now, but I would maybe add a
    disclaimer in the docs saying screenshot capturing might not work as well for
    those. Collecting only two screenshots is fine as well, but that number might
    better be a function of the viewport height, so that we don't miss parts of
    the page content in case a page is very, very long (although not
    infinite-scrolling).
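A minimal sketch of the reviewer's suggestion (deriving the number of captures from the rendered page height), assuming document.body.scrollHeight reflects the final height, which is not guaranteed on infinite-scrolling pages; the helper name and the cap are illustrative, not part of the PR:

    def capture_full_page(page, max_screenshots=20):
        """Capture ceil(scrollHeight / viewport_height) screenshots, capped."""
        viewport_height = page.viewport_size["height"]
        total_height = page.evaluate("document.body.scrollHeight")
        # Ceiling division without math.ceil; cap to avoid unbounded captures.
        num_shots = min(-(-total_height // viewport_height), max_screenshots)
        shots = []
        for i in range(num_shots):
            page.evaluate(f"window.scrollTo(0, {i * viewport_height});")
            shots.append(page.screenshot())  # PNG bytes by default
        return shots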
            browser.close()

        state["link"] = self.url
        state["screenshots"] = screenshot_data_list

        return state
@@ -0,0 +1,99 @@

import base64
import asyncio
from typing import List, Optional
import aiohttp
from .base_node import BaseNode


class GenerateAnswerFromImageNode(BaseNode):
    """
    GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API
    and updates the state with the consolidated answers.
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "GenerateAnswerFromImageNode",
    ):
        super().__init__(node_name, "node", input, output, 2, node_config)

    async def process_image(self, session, api_key, image_data, user_prompt):
        # Convert the raw screenshot bytes to base64 for the data URL
        base64_image = base64.b64encode(image_data).decode('utf-8')

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        payload = {
            "model": self.node_config["config"]["llm"]["model"],
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": user_prompt
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                # NB: page.screenshot() produces PNG bytes by default,
                                # so "image/png" would be the more accurate MIME type.
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            "max_tokens": 300
        }

        async with session.post("https://api.openai.com/v1/chat/completions",
                                headers=headers, json=payload) as response:
            result = await response.json()
            return result.get('choices', [{}])[0].get('message', {}).get('content', 'No response')
    async def execute_async(self, state: dict) -> dict:
        """
        Processes images from the state, generates answers,
        consolidates the results, and updates the state asynchronously.
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        images = state.get("screenshots", [])
        analyses = []

        supported_models = ("gpt-4o", "gpt-4o-mini", "gpt-4-turbo")

        if self.node_config["config"]["llm"]["model"] not in supported_models:
            raise ValueError(f"""Model '{self.node_config['config']['llm']['model']}'
                             is not supported. Supported models are:
                             {', '.join(supported_models)}.""")

        api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")

        async with aiohttp.ClientSession() as session:
            tasks = [
                self.process_image(session, api_key, image_data,
                                   state.get("user_prompt", "Extract information from the image"))
                for image_data in images
            ]

            analyses = await asyncio.gather(*tasks)

        consolidated_analysis = " ".join(analyses)

        state["answer"] = {
            "consolidated_analysis": consolidated_analysis
        }

        return state
    def execute(self, state: dict) -> dict:
        """
        Wrapper to run the asynchronous execute_async function in a synchronous context.
        """
        return asyncio.run(self.execute_async(state))

Review comment:

    Wrap the function with a safer guard on the running event loop, e.g.:

        def execute(self, state: dict) -> dict:
            """
            Wrapper to run the asynchronous execute_async function in a synchronous context.
            """
            try:
                eventloop = asyncio.get_event_loop()
            except RuntimeError:
                eventloop = None

            if eventloop and eventloop.is_running():
                state = eventloop.run_until_complete(self.execute_async(state))
            else:
                state = asyncio.run(self.execute_async(state))

            return state
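One caveat worth flagging in the suggested guard: run_until_complete() raises RuntimeError ("This event loop is already running") when called on a live loop, so the branch intended to handle a running loop would fail at exactly the moment it triggers. A sketch of an alternative that runs the coroutine on a fresh loop in a worker thread when a loop is already running (the helper name is illustrative, not part of the PR):

    import asyncio
    import threading

    def run_coroutine_sync(coro):
        """Run `coro` to completion whether or not an event loop is already running."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: asyncio.run is safe.
            return asyncio.run(coro)

        # A loop is already running: execute the coroutine on its own
        # event loop in a separate thread and block for the result.
        result = {}

        def _worker():
            result["value"] = asyncio.run(coro)

        thread = threading.Thread(target=_worker)
        thread.start()
        thread.join()
        return result["value"]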
@@ -0,0 +1,39 @@

import os
import json
import pytest
from dotenv import load_dotenv
from scrapegraphai.graphs import ScreenshotScraperGraph

# Load environment variables
load_dotenv()


# Define a fixture for the graph configuration
@pytest.fixture
def graph_config():
    """
    Creation of the graph configuration
    """
    return {
        "llm": {
            "api_key": os.getenv("OPENAI_API_KEY"),
            "model": "gpt-4o",
        },
        "verbose": True,
        "headless": False,
    }


def test_screenshot_scraper_graph(graph_config):
    """
    End-to-end test of the ScreenshotScraperGraph pipeline
    """
    smart_scraper_graph = ScreenshotScraperGraph(
        prompt="List me all the projects",
        source="https://perinim.github.io/projects/",
        config=graph_config
    )

    result = smart_scraper_graph.run()

    assert result is not None, "The result should not be None"

    print(json.dumps(result, indent=4))
Review comment:

    Why put the key here if the node doesn't use it and hardcodes it as
    "screenshots"?
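The mismatch the reviewer is pointing at: FetchScreenNode declares output=["imgs"] but writes state["screenshots"], and GenerateAnswerFromImageNode declares input="imgs" yet also reads state["screenshots"]. A minimal, self-contained sketch of keeping declared keys and state writes in sync; the toy classes below are illustrative, and whether BaseNode actually exposes the declared keys as self.input / self.output is an assumption:

    from typing import List

    class DeclaredKeyNode:
        """Toy stand-in for BaseNode that records its declared state keys."""

        def __init__(self, input: str, output: List[str]):
            self.input = input    # key this node reads from the state
            self.output = output  # keys this node writes into the state

    class ToyFetchNode(DeclaredKeyNode):
        def execute(self, state: dict) -> dict:
            # Write under the declared key ("imgs") instead of hardcoding
            # a different literal such as "screenshots".
            state[self.output[0]] = [b"fake-png-bytes"]
            return state

    class ToyAnswerNode(DeclaredKeyNode):
        def execute(self, state: dict) -> dict:
            images = state.get(self.input, [])  # same declared key: "imgs"
            state[self.output[0]] = f"analyzed {len(images)} screenshot(s)"
            return state

    state = ToyFetchNode("url", ["imgs"]).execute({"url": "https://example.com"})
    state = ToyAnswerNode("imgs", ["answer"]).execute(state)
    print(state["answer"])  # analyzed 1 screenshot(s)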