
Commit 5eb3cff

feat: refactoring of the code
1 parent: 8e3d5de

5 files changed: +82 additions, -30 deletions


examples/openai/screenshot_scraper.py

Lines changed: 2 additions & 2 deletions

@@ -29,8 +29,8 @@
 # ************************************************
 
 smart_scraper_graph = ScreenshotScraperGraph(
-    prompt="List me the email of the company",
-    source="https://scrapegraphai.com/",
+    prompt="List me all the projects",
+    source="https://perinim.github.io/projects/",
     config=graph_config
 )
 
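
For reference, the updated example can be run end to end roughly as follows. This is a minimal sketch: the graph_config mirrors the configuration used by the new test added later in this commit (OpenAI key from the environment, gpt-4o), not code taken from the example file itself.

    import os
    import json

    from dotenv import load_dotenv
    from scrapegraphai.graphs import ScreenshotScraperGraph

    # Load OPENAI_API_KEY from a local .env file
    load_dotenv()

    graph_config = {
        "llm": {
            "api_key": os.getenv("OPENAI_API_KEY"),
            "model": "gpt-4o",
        },
        "verbose": True,
    }

    smart_scraper_graph = ScreenshotScraperGraph(
        prompt="List me all the projects",
        source="https://perinim.github.io/projects/",
        config=graph_config
    )

    result = smart_scraper_graph.run()
    print(json.dumps(result, indent=4))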

scrapegraphai/nodes/fetch_screen_node.py

Lines changed: 11 additions & 9 deletions

@@ -1,3 +1,6 @@
+"""
+fetch_screen_node module
+"""
 from typing import List, Optional
 from playwright.sync_api import sync_playwright
 from .base_node import BaseNode
@@ -18,8 +21,10 @@ def __init__(
         self.url = node_config.get("link")
 
     def execute(self, state: dict) -> dict:
-        """Captures screenshots from the input URL and stores them in the state dictionary as bytes."""
-
+        """
+        Captures screenshots from the input URL and stores them in the state dictionary as bytes.
+        """
+
         screenshots = []
 
         with sync_playwright() as p:
@@ -29,28 +34,25 @@ def execute(self, state: dict) -> dict:
 
             viewport_height = page.viewport_size["height"]
 
-            # Initialize screenshot counter
             screenshot_counter = 1
 
-            # List to keep track of screenshot data
             screenshot_data_list = []
 
-            # Function to capture screenshots
             def capture_screenshot(scroll_position, counter):
                 page.evaluate(f"window.scrollTo(0, {scroll_position});")
                 screenshot_data = page.screenshot()
                 screenshot_data_list.append(screenshot_data)
 
-            # Capture screenshots
-            capture_screenshot(0, screenshot_counter)  # First screenshot
+            capture_screenshot(0, screenshot_counter)
             screenshot_counter += 1
-            capture_screenshot(viewport_height, screenshot_counter)  # Second screenshot
+            capture_screenshot(viewport_height, screenshot_counter)
 
             browser.close()
 
-        # Store screenshot data as bytes in the state dictionary
         for screenshot_data in screenshot_data_list:
            screenshots.append(screenshot_data)
+
        state["link"] = self.url
        state['screenshots'] = screenshots
+
        return state
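
The refactored execute method still takes exactly two screenshots: one at the top of the page and one scrolled down by a single viewport height. A standalone sketch of the same scroll-and-capture pattern with Playwright's sync API (the function name and the placeholder URL are illustrative only; the node itself reads its URL from node_config["link"]):

    from playwright.sync_api import sync_playwright

    def capture_two_screenshots(url: str) -> list[bytes]:
        """Return two page screenshots (top of page, one viewport down) as PNG bytes."""
        screenshots = []
        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            page = browser.new_page()
            page.goto(url)
            viewport_height = page.viewport_size["height"]
            # Screenshot at the top of the page, then one viewport height further down
            for scroll_position in (0, viewport_height):
                page.evaluate(f"window.scrollTo(0, {scroll_position});")
                screenshots.append(page.screenshot())
            browser.close()
        return screenshots

    # Example usage (placeholder URL):
    # images = capture_two_screenshots("https://perinim.github.io/projects/")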

scrapegraphai/nodes/generate_answer_from_image_node.py

Lines changed: 30 additions & 19 deletions

@@ -6,7 +6,7 @@
 class GenerateAnswerFromImageNode(BaseNode):
     """
     GenerateAnswerFromImageNode analyzes images from the state dictionary using the OpenAI API
-    and updates the state with the generated answers.
+    and updates the state with the consolidated answers.
     """
 
     def __init__(
@@ -19,20 +19,28 @@ def __init__(
         super().__init__(node_name, "node", input, output, 2, node_config)
 
     def execute(self, state: dict) -> dict:
-        """Processes images from the state, generates answers, and updates the state."""
-        # Retrieve the image data from the state dictionary
+        """
+        Processes images from the state, generates answers,
+        consolidates the results, and updates the state.
+        """
         images = state.get('screenshots', [])
-        results = []
+        analyses = []
+
+        api_key = self.node_config.get("config", {}).get("llm", {}).get("api_key", "")
+
+        supported_models = ["gpt-4o", "gpt-4o-mini", "gpt-4-turbo"]
+
+        if self.node_config["config"]["llm"]["model"] not in supported_models:
+            raise ValueError(f"""Model '{self.node_config['config']['llm']['model']}'
+                             is not supported. Supported models are:
+                             {', '.join(supported_models)}.""")
 
-        # OpenAI API Key
         for image_data in images:
-            # Encode the image data to base64
             base64_image = base64.b64encode(image_data).decode('utf-8')
 
-            # Prepare API request
             headers = {
                 "Content-Type": "application/json",
-                "Authorization": f"Bearer {self.node_config.get("config").get("llm").get("api_key")}"
+                "Authorization": f"Bearer {api_key}"
             }
 
             payload = {
@@ -43,7 +51,8 @@ def execute(self, state: dict) -> dict:
                 "content": [
                     {
                         "type": "text",
-                        "text": state.get("user_prompt", "Extract information from the image")
+                        "text": state.get("user_prompt",
+                                          "Extract information from the image")
                     },
                    {
                        "type": "image_url",
@@ -57,18 +66,20 @@ def execute(self, state: dict) -> dict:
                "max_tokens": 300
            }
 
-            # Make the API request
-            response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
+            response = requests.post("https://api.openai.com/v1/chat/completions",
+                                     headers=headers,
+                                     json=payload,
+                                     timeout=10)
             result = response.json()
 
-            # Extract the response text
-            response_text = result.get('choices', [{}])[0].get('message', {}).get('content', 'No response')
+            response_text = result.get('choices',
+                                       [{}])[0].get('message', {}).get('content', 'No response')
+            analyses.append(response_text)
+
+        consolidated_analysis = " ".join(analyses)
 
-            # Append the result to the results list
-            results.append({
-                "analysis": response_text
-            })
+        state['answer'] = {
+            "consolidated_analysis": consolidated_analysis
+        }
 
-        # Update the state dictionary with the results
-        state['answer'] = results
         return state
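
After this change the node validates the configured model, sends each screenshot to the OpenAI chat-completions endpoint as a base64 data URL, and stores a single consolidated string under state['answer']['consolidated_analysis']. A sketch of the per-image request is below; the helper name analyze_image is hypothetical, and the parts of the payload not visible in the hunk ("model", "role", the image_url block, and the data-URL media type) are assumed to follow OpenAI's standard vision-message format.

    import base64
    import requests

    def analyze_image(image_data: bytes, prompt: str, api_key: str, model: str = "gpt-4o") -> str:
        """Send one screenshot to the OpenAI chat-completions endpoint and return the text answer."""
        base64_image = base64.b64encode(image_data).decode('utf-8')
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }
        payload = {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            # Playwright screenshots are PNG by default; media type assumed here
                            "image_url": {"url": f"data:image/png;base64,{base64_image}"}
                        }
                    ]
                }
            ],
            "max_tokens": 300
        }
        response = requests.post("https://api.openai.com/v1/chat/completions",
                                 headers=headers, json=payload, timeout=10)
        result = response.json()
        return result.get('choices', [{}])[0].get('message', {}).get('content', 'No response')
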
New test file (path not shown in this view)

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+import os
+import pytest
+import json
+from scrapegraphai.graphs import ScreenshotScraperGraph
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+# Define a fixture for the graph configuration
+@pytest.fixture
+def graph_config():
+    """
+    Creation of the graph
+    """
+    return {
+        "llm": {
+            "api_key": os.getenv("OPENAI_API_KEY"),
+            "model": "gpt-4o",
+        },
+        "verbose": True,
+        "headless": False,
+    }
+
+def test_screenshot_scraper_graph(graph_config):
+    """
+    test
+    """
+    smart_scraper_graph = ScreenshotScraperGraph(
+        prompt="List me all the projects",
+        source="https://perinim.github.io/projects/",
+        config=graph_config
+    )
+
+    result = smart_scraper_graph.run()
+
+    assert result is not None, "The result should not be None"
+
+    print(json.dumps(result, indent=4))
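
The new test runs the whole graph against a live page with a real OpenAI key, so it behaves as an integration test. If the key may be absent in CI, a skip guard along these lines could sit above the test (a sketch, not part of this commit):

    import os
    import pytest

    requires_openai = pytest.mark.skipif(
        not os.getenv("OPENAI_API_KEY"),
        reason="OPENAI_API_KEY is not set"
    )

    @requires_openai
    def test_screenshot_scraper_graph(graph_config):
        ...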
