
Screenshot scraper integration #558


Merged
merged 10 commits on Aug 20, 2024
38 changes: 38 additions & 0 deletions examples/openai/screenshot_scraper.py
@@ -0,0 +1,38 @@
"""
Basic example of scraping pipeline using SmartScraper
"""

import os
import json
from dotenv import load_dotenv
from scrapegraphai.graphs import ScreenshotScraperGraph
from scrapegraphai.utils import prettify_exec_info

load_dotenv()

# ************************************************
# Define the configuration for the graph
# ************************************************


graph_config = {
    "llm": {
        "api_key": os.getenv("OPENAI_API_KEY"),
        "model": "gpt-4o",
    },
    "verbose": True,
    "headless": False,
}

# ************************************************
# Create the ScreenshotScraperGraph instance and run it
# ************************************************

smart_scraper_graph = ScreenshotScraperGraph(
    prompt="List me all the projects",
    source="https://perinim.github.io/projects/",
    config=graph_config
)

result = smart_scraper_graph.run()
print(json.dumps(result, indent=4))
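For reference, run() returns the "answer" entry of the graph's final state, which GenerateAnswerFromImageNode (shown later in this PR) populates as {"consolidated_analysis": "..."}; that dict is what json.dumps prints here.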
5 changes: 3 additions & 2 deletions examples/openai/smart_scraper_openai.py
@@ -4,9 +4,10 @@

import os
import json
+from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
-from dotenv import load_dotenv

load_dotenv()

# ************************************************
@@ -17,7 +18,7 @@
graph_config = {
    "llm": {
        "api_key": os.getenv("OPENAI_API_KEY"),
-        "model": "gpt-3.5-turbo",
+        "model": "gpt-4o",
    },
    "verbose": True,
    "headless": False,
1 change: 1 addition & 0 deletions scrapegraphai/graphs/__init__.py
@@ -24,3 +24,4 @@
from .markdown_scraper_graph import MDScraperGraph
from .markdown_scraper_multi_graph import MDScraperMultiGraph
from .search_link_graph import SearchLinkGraph
+from .screenshot_scraper_graph import ScreenshotScraperGraph
82 changes: 82 additions & 0 deletions scrapegraphai/graphs/screenshot_scraper_graph.py
@@ -0,0 +1,82 @@
"""
ScreenshotScraperGraph Module
"""
from typing import Optional
import logging
from pydantic import BaseModel
from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from ..nodes import ( FetchScreenNode, GenerateAnswerFromImageNode, )

class ScreenshotScraperGraph(AbstractGraph):
    """
    A graph instance representing the web scraping workflow for images.

    Attributes:
        prompt (str): The input text to be scraped.
        config (dict): Configuration parameters for the graph.
        source (str): The source URL or image link to scrape from.

    Methods:
        __init__(prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None)
            Initializes the ScreenshotScraperGraph instance with the given prompt,
            source, and configuration parameters.

        _create_graph()
            Creates a graph of nodes representing the web scraping workflow for images.

        run()
            Executes the scraping process and returns the answer to the prompt.
    """

    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
        super().__init__(prompt, config, source, schema)

    def _create_graph(self) -> BaseGraph:
        """
        Creates the graph of nodes representing the workflow for web scraping with images.

        Returns:
            BaseGraph: A graph instance representing the web scraping workflow for images.
        """
        fetch_screen_node = FetchScreenNode(
            input="url",
            output=["screenshots"],
            node_config={
                "link": self.source
            }
        )
        generate_answer_from_image_node = GenerateAnswerFromImageNode(
            input="screenshots",
            output=["answer"],
            node_config={
                "config": self.config
            }
        )

        return BaseGraph(
            nodes=[
                fetch_screen_node,
                generate_answer_from_image_node,
            ],
            edges=[
                (fetch_screen_node, generate_answer_from_image_node),
            ],
            entry_point=fetch_screen_node,
            graph_name=self.__class__.__name__
        )

    def run(self) -> str:
        """
        Executes the scraping process and returns the answer to the prompt.

        Returns:
            str: The answer to the prompt.
        """
        inputs = {"user_prompt": self.prompt}
        self.final_state, self.execution_info = self.graph.execute(inputs)

        return self.final_state.get("answer", "No answer found.")

4 changes: 3 additions & 1 deletion scrapegraphai/nodes/__init__.py
@@ -19,4 +19,6 @@
from .graph_iterator_node import GraphIteratorNode
from .merge_answers_node import MergeAnswersNode
from .generate_answer_omni_node import GenerateAnswerOmniNode
-from .merge_generated_scripts import MergeGeneratedScriptsNode
+from .merge_generated_scripts import MergeGeneratedScriptsNode
+from .fetch_screen_node import FetchScreenNode
+from .generate_answer_from_image_node import GenerateAnswerFromImageNode
55 changes: 55 additions & 0 deletions scrapegraphai/nodes/fetch_screen_node.py
@@ -0,0 +1,55 @@
"""
fetch_screen_node module
"""
from typing import List, Optional
from playwright.sync_api import sync_playwright
from .base_node import BaseNode
from ..utils.logging import get_logger

class FetchScreenNode(BaseNode):
    """
    FetchScreenNode captures screenshots from a given URL and stores the image data as bytes.
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "FetchScreenNode",
    ):
        super().__init__(node_name, "node", input, output, 2, node_config)
        self.url = node_config.get("link")

    def execute(self, state: dict) -> dict:
        """
        Captures screenshots from the input URL and stores them in the state dictionary as bytes.
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        with sync_playwright() as p:
            browser = p.chromium.launch()
            page = browser.new_page()
            page.goto(self.url)

            viewport_height = page.viewport_size["height"]

            screenshot_counter = 1
            screenshot_data_list = []

            def capture_screenshot(scroll_position, counter):
                # Scroll to the given offset, then capture the current viewport as bytes.
                page.evaluate(f"window.scrollTo(0, {scroll_position});")
                screenshot_data = page.screenshot()
                screenshot_data_list.append(screenshot_data)

            # Capture the top of the page, then one viewport further down.
            capture_screenshot(0, screenshot_counter)
            screenshot_counter += 1
            capture_screenshot(viewport_height, screenshot_counter)
Collaborator:
Infinite-scrolling web pages (with dynamic JS rendering) may have undefined behavior with the viewport height, as that is bound to change as you scroll through the page in the browser.

Looks fine for now, but I would maybe add a disclaimer in the docs saying screenshot capturing might not work as well for those; collecting only two screenshots is fine as well, but that number might better be a function of the viewport height, so that we don't miss parts of the page content in case a page is very long (although not infinite-scrolling).
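A minimal sketch of that suggestion (an illustration, not part of the merged PR), reusing page, viewport_height, and capture_screenshot from the node above and assuming document.body.scrollHeight reports the full height of a finite page:

    # Hypothetical variant: derive the capture count from the total page height.
    total_height = page.evaluate("document.body.scrollHeight")
    num_screenshots = max(1, -(-total_height // viewport_height))  # ceiling division
    for i in range(num_screenshots):
        capture_screenshot(i * viewport_height, i + 1)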


            browser.close()

        state["link"] = self.url
        state["screenshots"] = screenshot_data_list

        return state
115 changes: 115 additions & 0 deletions scrapegraphai/nodes/generate_answer_from_image_node.py
@@ -0,0 +1,115 @@
"""
GenerateAnswerFromImageNode Module
"""
import base64
import asyncio
from typing import List, Optional
import aiohttp
from .base_node import BaseNode
from ..utils.logging import get_logger

class GenerateAnswerFromImageNode(BaseNode):
    """
    GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API
    and updates the state with the consolidated answers.
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "GenerateAnswerFromImageNode",
    ):
        super().__init__(node_name, "node", input, output, 2, node_config)

    async def process_image(self, session, api_key, image_data, user_prompt):
        """
        Sends a single screenshot to the OpenAI chat completions endpoint
        and returns the model's answer text.
        """
        base64_image = base64.b64encode(image_data).decode('utf-8')

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        payload = {
            "model": self.node_config["config"]["llm"]["model"],
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": user_prompt
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                # Note: Playwright's page.screenshot() defaults to PNG,
                                # so the image/jpeg label here is nominal.
                                "url": f"data:image/jpeg;base64,{base64_image}"
                            }
                        }
                    ]
                }
            ],
            "max_tokens": 300
        }

        async with session.post("https://api.openai.com/v1/chat/completions",
                                headers=headers, json=payload) as response:
            result = await response.json()
            return result.get('choices', [{}])[0].get('message', {}).get('content', 'No response')

    async def execute_async(self, state: dict) -> dict:
        """
        Processes images from the state, generates answers,
        consolidates the results, and updates the state asynchronously.
        """
        self.logger.info(f"--- Executing {self.node_name} Node ---")

        images = state.get('screenshots', [])
        analyses = []

        supported_models = ("gpt-4o", "gpt-4o-mini", "gpt-4-turbo")

        if self.node_config["config"]["llm"]["model"] not in supported_models:
            raise ValueError(
                f"Model '{self.node_config['config']['llm']['model']}' is not supported. "
                f"Supported models are: {', '.join(supported_models)}."
            )

        api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")

        async with aiohttp.ClientSession() as session:
            tasks = [
                self.process_image(session, api_key, image_data,
                                   state.get("user_prompt", "Extract information from the image"))
                for image_data in images
            ]

            analyses = await asyncio.gather(*tasks)

        consolidated_analysis = " ".join(analyses)

        state['answer'] = {
            "consolidated_analysis": consolidated_analysis
        }

        return state

Collaborator:
wrap the function with a safer guard on the running event loop, e.g.:

def execute(self, state: dict) -> dict:
    """
    Wrapper to run the asynchronous execute_async function in a synchronous context.
    """
    try:
        eventloop = asyncio.get_event_loop()
    except RuntimeError:
        eventloop = None

    if eventloop and eventloop.is_running():
        state = eventloop.run_until_complete(self.execute_async(state))
    else:
        state = asyncio.run(self.execute_async(state))

    return state

"""
Wrapper to run the asynchronous execute_async function in a synchronous context.
"""
try:
eventloop = asyncio.get_event_loop()
except RuntimeError:
eventloop = None

if eventloop and eventloop.is_running():
task = eventloop.create_task(self.execute_async(state))
state = eventloop.run_until_complete(asyncio.gather(task))[0]
else:
state = asyncio.run(self.execute_async(state))

return state
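One caveat worth flagging (an editor's note, not part of the PR): run_until_complete raises RuntimeError when called on a loop that is already running, so the running-loop branch above can still fail in environments such as Jupyter. A minimal sketch of a thread-based fallback, assuming nothing else in the class changes:

    import asyncio
    import concurrent.futures

    def execute(self, state: dict) -> dict:
        """
        Run execute_async from synchronous code, even when a loop is already running.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop and loop.is_running():
            # A running loop cannot re-enter run_until_complete; run the
            # coroutine to completion on a fresh loop in a worker thread.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                state = pool.submit(asyncio.run, self.execute_async(state)).result()
        else:
            state = asyncio.run(self.execute_async(state))

        return state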
39 changes: 39 additions & 0 deletions tests/graphs/screenshot_scraper_test.py
@@ -0,0 +1,39 @@
import os
import json
import pytest
from dotenv import load_dotenv
from scrapegraphai.graphs import ScreenshotScraperGraph

# Load environment variables
load_dotenv()

# Define a fixture for the graph configuration
@pytest.fixture
def graph_config():
    """
    Graph configuration shared by the test below.
    """
    return {
        "llm": {
            "api_key": os.getenv("OPENAI_API_KEY"),
            "model": "gpt-4o",
        },
        "verbose": True,
        "headless": False,
    }

def test_screenshot_scraper_graph(graph_config):
    """
    End-to-end test: run ScreenshotScraperGraph against a live page
    and check that a result comes back.
    """
    smart_scraper_graph = ScreenshotScraperGraph(
        prompt="List me all the projects",
        source="https://perinim.github.io/projects/",
        config=graph_config
    )

    result = smart_scraper_graph.run()

    assert result is not None, "The result should not be None"

    print(json.dumps(result, indent=4))