Skip to content

Commit 376f758

Browse files
committed
feat(pydantic): added pydantic output schema
1 parent 1d217e4 commit 376f758

23 files changed

+165
-125
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
2+
Example of Search Graph
3+
"""
4+
5+
import os
6+
from dotenv import load_dotenv
7+
load_dotenv()
8+
9+
from scrapegraphai.graphs import SearchGraph
10+
from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info
11+
12+
from pydantic import BaseModel, Field
13+
from typing import List
14+
15+
# ************************************************
16+
# Define the output schema for the graph
17+
# ************************************************
18+
19+
class Dish(BaseModel):
20+
name: str = Field(description="The name of the dish")
21+
description: str = Field(description="The description of the dish")
22+
23+
class Dishes(BaseModel):
24+
dishes: List[Dish]
25+
26+
# ************************************************
27+
# Define the configuration for the graph
28+
# ************************************************
29+
30+
openai_key = os.getenv("OPENAI_APIKEY")
31+
32+
graph_config = {
33+
"llm": {
34+
"api_key": openai_key,
35+
"model": "gpt-3.5-turbo",
36+
},
37+
"max_results": 2,
38+
"verbose": True,
39+
}
40+
41+
# ************************************************
42+
# Create the SearchGraph instance and run it
43+
# ************************************************
44+
45+
search_graph = SearchGraph(
46+
prompt="List me Chioggia's famous dishes",
47+
config=graph_config,
48+
schema=Dishes
49+
)
50+
51+
result = search_graph.run()
52+
print(result)
53+
54+
# ************************************************
55+
# Get graph execution info
56+
# ************************************************
57+
58+
graph_exec_info = search_graph.get_execution_info()
59+
print(prettify_exec_info(graph_exec_info))
60+
61+
# Save to json and csv
62+
convert_to_csv(result, "result")
63+
convert_to_json(result, "result")

examples/openai/smart_scraper_schema_openai.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
import os, json
66
from dotenv import load_dotenv
7+
from pydantic import BaseModel, Field
8+
from typing import List
9+
710
from scrapegraphai.graphs import SmartScraperGraph
811

912
load_dotenv()
@@ -12,22 +15,12 @@
1215
# Define the output schema for the graph
1316
# ************************************************
1417

15-
schema= """
16-
{
17-
"Projects": [
18-
"Project #":
19-
{
20-
"title": "...",
21-
"description": "...",
22-
},
23-
"Project #":
24-
{
25-
"title": "...",
26-
"description": "...",
27-
}
28-
]
29-
}
30-
"""
18+
class Project(BaseModel):
19+
title: str = Field(description="The title of the project")
20+
description: str = Field(description="The description of the project")
21+
22+
class Projects(BaseModel):
23+
projects: List[Project]
3124

3225
# ************************************************
3326
# Define the configuration for the graph
@@ -51,9 +44,9 @@
5144
smart_scraper_graph = SmartScraperGraph(
5245
prompt="List me all the projects with their description",
5346
source="https://perinim.github.io/projects/",
54-
schema=schema,
47+
schema=Projects,
5548
config=graph_config
5649
)
5750

5851
result = smart_scraper_graph.run()
59-
print(json.dumps(result, indent=4))
52+
print(result)

scrapegraphai/graphs/abstract_graph.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
"""
44

55
from abc import ABC, abstractmethod
6-
from typing import Optional
6+
from typing import Optional, Union
77
import uuid
8+
from pydantic import BaseModel
89

910
from langchain_aws import BedrockEmbeddings
1011
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
@@ -62,7 +63,7 @@ class AbstractGraph(ABC):
6263
"""
6364

6465
def __init__(self, prompt: str, config: dict,
65-
source: Optional[str] = None, schema: Optional[str] = None):
66+
source: Optional[str] = None, schema: Optional[BaseModel] = None):
6667

6768
self.prompt = prompt
6869
self.source = source

scrapegraphai/graphs/csv_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -20,7 +21,7 @@ class CSVScraperGraph(AbstractGraph):
2021
information from web pages using a natural language model to interpret and answer prompts.
2122
"""
2223

23-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
24+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
2425
"""
2526
Initializes the CSVScraperGraph with a prompt, source, and configuration.
2627
"""

scrapegraphai/graphs/deep_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -56,7 +57,7 @@ class DeepScraperGraph(AbstractGraph):
5657
)
5758
"""
5859

59-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
60+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
6061

6162
super().__init__(prompt, config, source, schema)
6263

scrapegraphai/graphs/json_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -44,7 +45,7 @@ class JSONScraperGraph(AbstractGraph):
4445
>>> result = json_scraper.run()
4546
"""
4647

47-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
48+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
4849
super().__init__(prompt, config, source, schema)
4950

5051
self.input_key = "json" if source.endswith("json") else "json_dir"

scrapegraphai/graphs/omni_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -52,7 +53,7 @@ class OmniScraperGraph(AbstractGraph):
5253
)
5354
"""
5455

55-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
56+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
5657

5758
self.max_images = 5 if config is None else config.get("max_images", 5)
5859

scrapegraphai/graphs/omni_search_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from copy import copy, deepcopy
66
from typing import Optional
7+
from pydantic import BaseModel
78

89
from .base_graph import BaseGraph
910
from .abstract_graph import AbstractGraph
@@ -43,7 +44,7 @@ class OmniSearchGraph(AbstractGraph):
4344
>>> result = search_graph.run()
4445
"""
4546

46-
def __init__(self, prompt: str, config: dict, schema: Optional[str] = None):
47+
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
4748

4849
self.max_results = config.get("max_results", 3)
4950

scrapegraphai/graphs/pdf_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -46,7 +47,7 @@ class PDFScraperGraph(AbstractGraph):
4647
>>> result = pdf_scraper.run()
4748
"""
4849

49-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
50+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
5051
super().__init__(prompt, config, source, schema)
5152

5253
self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"

scrapegraphai/graphs/script_creator_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -46,7 +47,7 @@ class ScriptCreatorGraph(AbstractGraph):
4647
>>> result = script_creator.run()
4748
"""
4849

49-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
50+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
5051

5152
self.library = config['library']
5253

scrapegraphai/graphs/search_graph.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from copy import copy, deepcopy
66
from typing import Optional
7+
from pydantic import BaseModel
78

89
from .base_graph import BaseGraph
910
from .abstract_graph import AbstractGraph
@@ -42,14 +43,16 @@ class SearchGraph(AbstractGraph):
4243
>>> result = search_graph.run()
4344
"""
4445

45-
def __init__(self, prompt: str, config: dict, schema: Optional[str] = None):
46+
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):
4647

4748
self.max_results = config.get("max_results", 3)
4849

4950
if all(isinstance(value, str) for value in config.values()):
5051
self.copy_config = copy(config)
5152
else:
5253
self.copy_config = deepcopy(config)
54+
55+
self.copy_schema = deepcopy(schema)
5356

5457
super().__init__(prompt, config, schema)
5558

@@ -68,7 +71,8 @@ def _create_graph(self) -> BaseGraph:
6871
smart_scraper_instance = SmartScraperGraph(
6972
prompt="",
7073
source="",
71-
config=self.copy_config
74+
config=self.copy_config,
75+
schema=self.copy_schema
7276
)
7377

7478
# ************************************************

scrapegraphai/graphs/smart_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -48,7 +49,7 @@ class SmartScraperGraph(AbstractGraph):
4849
)
4950
"""
5051

51-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
52+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
5253
super().__init__(prompt, config, source, schema)
5354

5455
self.input_key = "url" if source.startswith("http") else "local_dir"

scrapegraphai/graphs/smart_scraper_multi_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from copy import copy, deepcopy
66
from typing import List, Optional
7+
from pydantic import BaseModel
78

89
from .base_graph import BaseGraph
910
from .abstract_graph import AbstractGraph
@@ -42,7 +43,7 @@ class SmartScraperMultiGraph(AbstractGraph):
4243
>>> result = search_graph.run()
4344
"""
4445

45-
def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[str] = None):
46+
def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):
4647

4748
self.max_results = config.get("max_results", 3)
4849

scrapegraphai/graphs/speech_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -47,7 +48,7 @@ class SpeechGraph(AbstractGraph):
4748
... {"llm": {"model": "gpt-3.5-turbo"}}
4849
"""
4950

50-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
51+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
5152
super().__init__(prompt, config, source, schema)
5253

5354
self.input_key = "url" if source.startswith("http") else "local_dir"

scrapegraphai/graphs/xml_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import Optional
6+
from pydantic import BaseModel
67

78
from .base_graph import BaseGraph
89
from .abstract_graph import AbstractGraph
@@ -46,7 +47,7 @@ class XMLScraperGraph(AbstractGraph):
4647
>>> result = xml_scraper.run()
4748
"""
4849

49-
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
50+
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
5051
super().__init__(prompt, config, source, schema)
5152

5253
self.input_key = "xml" if source.endswith("xml") else "xml_dir"

scrapegraphai/helpers/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .schemas import graph_schema
77
from .models_tokens import models_tokens
88
from .robots import robots_dictionary
9-
from .generate_answer_node_prompts import template_chunks, template_chunks_with_schema, template_no_chunks, template_no_chunks_with_schema, template_merge
9+
from .generate_answer_node_prompts import template_chunks, template_no_chunks, template_merge
1010
from .generate_answer_node_csv_prompts import template_chunks_csv, template_no_chunks_csv, template_merge_csv
11-
from .generate_answer_node_pdf_prompts import template_chunks_pdf, template_no_chunks_pdf, template_merge_pdf, template_chunks_pdf_with_schema, template_no_chunks_pdf_with_schema
11+
from .generate_answer_node_pdf_prompts import template_chunks_pdf, template_no_chunks_pdf, template_merge_pdf
1212
from .generate_answer_node_omni_prompts import template_chunks_omni, template_no_chunk_omni, template_merge_omni

0 commit comments

Comments
 (0)