
Commit a7443a7

Merge pull request #341 from VinciGit00/332-pydantic-schema-validation
#332 pydantic schema validation
2 parents ac8e7c1 + 74fd530

24 files changed: +199 -125 lines changed
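In short, this change replaces the JSON-string "schema" argument with a Pydantic model class across every graph constructor and example. A condensed sketch of the new calling pattern, drawn from the example diffs below (the API key value is a placeholder; the examples read it from the environment):

from typing import List
from pydantic import BaseModel, Field
from scrapegraphai.graphs import SmartScraperGraph

class Project(BaseModel):
    title: str = Field(description="The title of the project")
    description: str = Field(description="The description of the project")

class Projects(BaseModel):
    projects: List[Project]

graph_config = {
    "llm": {
        "api_key": "sk-...",        # placeholder; the examples use os.getenv("OPENAI_APIKEY")
        "model": "gpt-3.5-turbo",
    },
    "verbose": True,
}

smart_scraper_graph = SmartScraperGraph(
    prompt="List me all the projects with their description",
    source="https://perinim.github.io/projects/",
    schema=Projects,                # the model class itself, no longer a JSON-like string
    config=graph_config,
)
result = smart_scraper_graph.run()
print(result)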
Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+"""
+Example of Search Graph
+"""
+
+import os
+from dotenv import load_dotenv
+load_dotenv()
+
+from scrapegraphai.graphs import SearchGraph
+from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info
+
+from pydantic import BaseModel, Field
+from typing import List
+
+# ************************************************
+# Define the output schema for the graph
+# ************************************************
+
+class Dish(BaseModel):
+    name: str = Field(description="The name of the dish")
+    description: str = Field(description="The description of the dish")
+
+class Dishes(BaseModel):
+    dishes: List[Dish]
+
+# ************************************************
+# Define the configuration for the graph
+# ************************************************
+
+openai_key = os.getenv("OPENAI_APIKEY")
+
+graph_config = {
+    "llm": {
+        "api_key": openai_key,
+        "model": "gpt-3.5-turbo",
+    },
+    "max_results": 2,
+    "verbose": True,
+}
+
+# ************************************************
+# Create the SearchGraph instance and run it
+# ************************************************
+
+search_graph = SearchGraph(
+    prompt="List me Chioggia's famous dishes",
+    config=graph_config,
+    schema=Dishes
+)
+
+result = search_graph.run()
+print(result)
+
+# ************************************************
+# Get graph execution info
+# ************************************************
+
+graph_exec_info = search_graph.get_execution_info()
+print(prettify_exec_info(graph_exec_info))
+
+# Save to json and csv
+convert_to_csv(result, "result")
+convert_to_json(result, "result")
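Because the graph is constructed with schema=Dishes, the dictionary returned by run() can be loaded back into the same model for validated, typed access. A hedged sketch, not part of the PR, assuming the result keys match the model fields (the exact output shape depends on the graph's parsing nodes, which are not shown in this diff):

dishes = Dishes(**result)   # raises pydantic.ValidationError if the structure does not match
for dish in dishes.dishes:
    print(f"{dish.name}: {dish.description}")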

examples/openai/smart_scraper_schema_openai.py

Lines changed: 11 additions & 18 deletions
@@ -4,6 +4,9 @@
 
 import os, json
 from dotenv import load_dotenv
+from pydantic import BaseModel, Field
+from typing import List
+
 from scrapegraphai.graphs import SmartScraperGraph
 
 load_dotenv()
@@ -12,22 +15,12 @@
 # Define the output schema for the graph
 # ************************************************
 
-schema= """
-{
-    "Projects": [
-        "Project #":
-            {
-                "title": "...",
-                "description": "...",
-            },
-        "Project #":
-            {
-                "title": "...",
-                "description": "...",
-            }
-    ]
-}
-"""
+class Project(BaseModel):
+    title: str = Field(description="The title of the project")
+    description: str = Field(description="The description of the project")
+
+class Projects(BaseModel):
+    projects: List[Project]
 
 # ************************************************
 # Define the configuration for the graph
@@ -51,9 +44,9 @@
 smart_scraper_graph = SmartScraperGraph(
     prompt="List me all the projects with their description",
     source="https://perinim.github.io/projects/",
-    schema=schema,
+    schema=Projects,
     config=graph_config
 )
 
 result = smart_scraper_graph.run()
-print(json.dumps(result, indent=4))
+print(result)

scrapegraphai/graphs/abstract_graph.py

Lines changed: 13 additions & 2 deletions
@@ -3,8 +3,9 @@
 """
 
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Optional, Union
 import uuid
+from pydantic import BaseModel
 
 from langchain_aws import BedrockEmbeddings
 from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
@@ -62,7 +63,7 @@ class AbstractGraph(ABC):
     """
 
     def __init__(self, prompt: str, config: dict,
-                 source: Optional[str] = None, schema: Optional[str] = None):
+                 source: Optional[str] = None, schema: Optional[BaseModel] = None):
 
         self.prompt = prompt
         self.source = source
@@ -352,6 +353,16 @@ def get_state(self, key=None) -> dict:
             return self.final_state[key]
         return self.final_state
 
+    def append_node(self, node):
+        """
+        Add a node to the graph.
+
+        Args:
+            node (BaseNode): The node to add to the graph.
+        """
+
+        self.graph.append_node(node)
+
     def get_execution_info(self):
         """
         Returns the execution information of the graph.
scrapegraphai/graphs/base_graph.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class BaseGraph:
4949
def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool = False, burr_config: dict = None):
5050

5151
self.nodes = nodes
52+
self.raw_edges = edges
5253
self.edges = self._create_edges({e for e in edges})
5354
self.entry_point = entry_point.node_name
5455
self.initial_state = {}
@@ -168,4 +169,25 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
168169
result = bridge.execute(initial_state)
169170
return (result["_state"], [])
170171
else:
171-
return self._execute_standard(initial_state)
172+
return self._execute_standard(initial_state)
173+
174+
def append_node(self, node):
175+
"""
176+
Adds a node to the graph.
177+
178+
Args:
179+
node (BaseNode): The node instance to add to the graph.
180+
"""
181+
182+
# if node name already exists in the graph, raise an exception
183+
if node.node_name in {n.node_name for n in self.nodes}:
184+
raise ValueError(f"Node with name '{node.node_name}' already exists in the graph. You can change it by setting the 'node_name' attribute.")
185+
186+
# get the last node in the list
187+
last_node = self.nodes[-1]
188+
# add the edge connecting the last node to the new node
189+
self.raw_edges.append((last_node, node))
190+
# add the node to the list of nodes
191+
self.nodes.append(node)
192+
# update the edges connecting the last node to the new node
193+
self.edges = self._create_edges({e for e in self.raw_edges})
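append_node links the new node after the current last node and rebuilds the edge map; the collision check only relies on each node exposing a unique node_name. A minimal illustration of that contract with a stand-in object (StubNode is not the library's BaseNode, and graph stands for an already-built BaseGraph instance):

class StubNode:
    # stand-in for a real node: append_node only inspects node_name here
    def __init__(self, node_name):
        self.node_name = node_name

# graph.append_node(StubNode("extra_step"))   # appended after the current last node
# graph.append_node(StubNode("fetch_node"))   # ValueError if "fetch_node" already exists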

scrapegraphai/graphs/csv_scraper_graph.py

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 """
 
 from typing import Optional
+from pydantic import BaseModel
 
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
@@ -20,7 +21,7 @@ class CSVScraperGraph(AbstractGraph):
     information from web pages using a natural language model to interpret and answer prompts.
     """
 
-    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
+    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
         """
         Initializes the CSVScraperGraph with a prompt, source, and configuration.
         """

scrapegraphai/graphs/deep_scraper_graph.py

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 """
 
 from typing import Optional
+from pydantic import BaseModel
 
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
@@ -56,7 +57,7 @@ class DeepScraperGraph(AbstractGraph):
     )
     """
 
-    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
+    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
 
         super().__init__(prompt, config, source, schema)
 
scrapegraphai/graphs/json_scraper_graph.py

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 """
 
 from typing import Optional
+from pydantic import BaseModel
 
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
@@ -44,7 +45,7 @@ class JSONScraperGraph(AbstractGraph):
     >>> result = json_scraper.run()
     """
 
-    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
+    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
         super().__init__(prompt, config, source, schema)
 
         self.input_key = "json" if source.endswith("json") else "json_dir"

scrapegraphai/graphs/omni_scraper_graph.py

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 """
 
 from typing import Optional
+from pydantic import BaseModel
 
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
@@ -52,7 +53,7 @@ class OmniScraperGraph(AbstractGraph):
     )
     """
 
-    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
+    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
 
         self.max_images = 5 if config is None else config.get("max_images", 5)
 
scrapegraphai/graphs/omni_search_graph.py

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 from copy import copy, deepcopy
 from typing import Optional
+from pydantic import BaseModel
 
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
@@ -43,7 +44,7 @@ class OmniSearchGraph(AbstractGraph):
     >>> result = search_graph.run()
     """
 
-    def __init__(self, prompt: str, config: dict, schema: Optional[str] = None):
+    def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
 
         self.max_results = config.get("max_results", 3)
 
scrapegraphai/graphs/pdf_scraper_graph.py

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
 """
 
 from typing import Optional
+from pydantic import BaseModel
 
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
@@ -47,7 +48,7 @@ class PDFScraperGraph(AbstractGraph):
     >>> result = pdf_scraper.run()
     """
 
-    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
+    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
         super().__init__(prompt, config, source, schema)
 
         self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"

scrapegraphai/graphs/script_creator_graph.py

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 """
 
 from typing import Optional
+from pydantic import BaseModel
 
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
@@ -46,7 +47,7 @@ class ScriptCreatorGraph(AbstractGraph):
     >>> result = script_creator.run()
     """
 
-    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
+    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
 
         self.library = config['library']
 
scrapegraphai/graphs/search_graph.py

Lines changed: 6 additions & 2 deletions
@@ -4,6 +4,7 @@
 
 from copy import copy, deepcopy
 from typing import Optional
+from pydantic import BaseModel
 
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
@@ -42,14 +43,16 @@ class SearchGraph(AbstractGraph):
     >>> result = search_graph.run()
     """
 
-    def __init__(self, prompt: str, config: dict, schema: Optional[str] = None):
+    def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
 
         self.max_results = config.get("max_results", 3)
 
         if all(isinstance(value, str) for value in config.values()):
             self.copy_config = copy(config)
         else:
             self.copy_config = deepcopy(config)
+
+        self.copy_schema = deepcopy(schema)
 
         super().__init__(prompt, config, schema)
 
@@ -68,7 +71,8 @@ def _create_graph(self) -> BaseGraph:
         smart_scraper_instance = SmartScraperGraph(
            prompt="",
            source="",
-           config=self.copy_config
+           config=self.copy_config,
+           schema=self.copy_schema
         )
 
         # ************************************************
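Forwarding copy_schema here means every per-page SmartScraperGraph that SearchGraph spins up parses its page against the same model. A sketch of the equivalent direct call, reusing the Dishes model and graph_config from the new example at the top of this commit (the URL is a placeholder, not taken from the PR):

from scrapegraphai.graphs import SmartScraperGraph

page_scraper = SmartScraperGraph(
    prompt="List me Chioggia's famous dishes",
    source="https://example.com/chioggia",   # placeholder URL
    config=graph_config,
    schema=Dishes,
)
page_result = page_scraper.run()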

scrapegraphai/graphs/smart_scraper_graph.py

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 """
 
 from typing import Optional
+from pydantic import BaseModel
 
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
@@ -48,7 +49,7 @@ class SmartScraperGraph(AbstractGraph):
     )
     """
 
-    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
+    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
         super().__init__(prompt, config, source, schema)
 
         self.input_key = "url" if source.startswith("http") else "local_dir"

scrapegraphai/graphs/smart_scraper_multi_graph.py

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 from copy import copy, deepcopy
 from typing import List, Optional
+from pydantic import BaseModel
 
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
@@ -42,7 +43,7 @@ class SmartScraperMultiGraph(AbstractGraph):
     >>> result = search_graph.run()
     """
 
-    def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[str] = None):
+    def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):
 
         self.max_results = config.get("max_results", 3)
 
scrapegraphai/graphs/speech_graph.py

Lines changed: 2 additions & 1 deletion
@@ -3,6 +3,7 @@
 """
 
 from typing import Optional
+from pydantic import BaseModel
 
 from .base_graph import BaseGraph
 from .abstract_graph import AbstractGraph
@@ -47,7 +48,7 @@ class SpeechGraph(AbstractGraph):
     ... {"llm": {"model": "gpt-3.5-turbo"}}
     """
 
-    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
+    def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
         super().__init__(prompt, config, source, schema)
 
         self.input_key = "url" if source.startswith("http") else "local_dir"
