#332 pydantic schema validation #341

Merged 3 commits on Jun 5, 2024

63 changes: 63 additions & 0 deletions examples/openai/search_graph_schema_openai.py
@@ -0,0 +1,63 @@
"""
Example of Search Graph
"""

import os
from dotenv import load_dotenv
load_dotenv()

from scrapegraphai.graphs import SearchGraph
from scrapegraphai.utils import convert_to_csv, convert_to_json, prettify_exec_info

from pydantic import BaseModel, Field
from typing import List

# ************************************************
# Define the output schema for the graph
# ************************************************

class Dish(BaseModel):
    name: str = Field(description="The name of the dish")
    description: str = Field(description="The description of the dish")

class Dishes(BaseModel):
    dishes: List[Dish]

# ************************************************
# Define the configuration for the graph
# ************************************************

openai_key = os.getenv("OPENAI_APIKEY")

graph_config = {
    "llm": {
        "api_key": openai_key,
        "model": "gpt-3.5-turbo",
    },
    "max_results": 2,
    "verbose": True,
}

# ************************************************
# Create the SearchGraph instance and run it
# ************************************************

search_graph = SearchGraph(
    prompt="List me Chioggia's famous dishes",
    config=graph_config,
    schema=Dishes
)

result = search_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = search_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))

# Save to json and csv
convert_to_csv(result, "result")
convert_to_json(result, "result")
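
For reference, here is a minimal standalone sketch of the validation behaviour the `Dishes` schema above provides (the dish values below are placeholders, not actual graph output); the same pattern applies to the `Projects` schema in the next file:

```python
from typing import List
from pydantic import BaseModel, Field, ValidationError

class Dish(BaseModel):
    name: str = Field(description="The name of the dish")
    description: str = Field(description="The description of the dish")

class Dishes(BaseModel):
    dishes: List[Dish]

# Well-formed data of the expected shape passes validation.
valid = Dishes(dishes=[{"name": "placeholder dish", "description": "placeholder description"}])

# Data missing a required field is rejected with a ValidationError.
try:
    Dishes(dishes=[{"name": "placeholder dish"}])
except ValidationError as exc:
    print(exc)
```
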
29 changes: 11 additions & 18 deletions examples/openai/smart_scraper_schema_openai.py
@@ -4,6 +4,9 @@

import os, json
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from typing import List

from scrapegraphai.graphs import SmartScraperGraph

load_dotenv()
@@ -12,22 +15,12 @@
# Define the output schema for the graph
# ************************************************

schema= """
{
"Projects": [
"Project #":
{
"title": "...",
"description": "...",
},
"Project #":
{
"title": "...",
"description": "...",
}
]
}
"""
class Project(BaseModel):
    title: str = Field(description="The title of the project")
    description: str = Field(description="The description of the project")

class Projects(BaseModel):
    projects: List[Project]

# ************************************************
# Define the configuration for the graph
@@ -51,9 +44,9 @@
smart_scraper_graph = SmartScraperGraph(
    prompt="List me all the projects with their description",
    source="https://perinim.github.io/projects/",
    schema=schema,
    schema=Projects,
    config=graph_config
)

result = smart_scraper_graph.run()
print(json.dumps(result, indent=4))
print(result)
15 changes: 13 additions & 2 deletions scrapegraphai/graphs/abstract_graph.py
@@ -3,8 +3,9 @@
"""

from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Union
import uuid
from pydantic import BaseModel

from langchain_aws import BedrockEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
@@ -62,7 +63,7 @@ class AbstractGraph(ABC):
"""

    def __init__(self, prompt: str, config: dict,
                 source: Optional[str] = None, schema: Optional[str] = None):
                 source: Optional[str] = None, schema: Optional[BaseModel] = None):

        self.prompt = prompt
        self.source = source
@@ -352,6 +353,16 @@ def get_state(self, key=None) -> dict:
            return self.final_state[key]
        return self.final_state

    def append_node(self, node):
        """
        Add a node to the graph.

        Args:
            node (BaseNode): The node to add to the graph.
        """

        self.graph.append_node(node)

    def get_execution_info(self):
        """
        Returns the execution information of the graph.
24 changes: 23 additions & 1 deletion scrapegraphai/graphs/base_graph.py
@@ -49,6 +49,7 @@ class BaseGraph:
    def __init__(self, nodes: list, edges: list, entry_point: str, use_burr: bool = False, burr_config: dict = None):

        self.nodes = nodes
        self.raw_edges = edges
        self.edges = self._create_edges({e for e in edges})
        self.entry_point = entry_point.node_name
        self.initial_state = {}
@@ -168,4 +169,25 @@ def execute(self, initial_state: dict) -> Tuple[dict, list]:
            result = bridge.execute(initial_state)
            return (result["_state"], [])
        else:
            return self._execute_standard(initial_state)
            return self._execute_standard(initial_state)

    def append_node(self, node):
        """
        Adds a node to the graph.

        Args:
            node (BaseNode): The node instance to add to the graph.
        """

        # if node name already exists in the graph, raise an exception
        if node.node_name in {n.node_name for n in self.nodes}:
            raise ValueError(f"Node with name '{node.node_name}' already exists in the graph. You can change it by setting the 'node_name' attribute.")

        # get the last node in the list
        last_node = self.nodes[-1]
        # add the edge connecting the last node to the new node
        self.raw_edges.append((last_node, node))
        # add the node to the list of nodes
        self.nodes.append(node)
        # update the edges connecting the last node to the new node
        self.edges = self._create_edges({e for e in self.raw_edges})
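
As an illustration only, here is a self-contained sketch of the wiring pattern `append_node` follows, using simplified stand-ins for `BaseGraph`, its nodes, and `_create_edges` (the real classes are the ones in the diff above and are not reproduced here):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    node_name: str  # stand-in for scrapegraphai's BaseNode

class MiniGraph:
    """Simplified stand-in mirroring BaseGraph's append_node bookkeeping."""
    def __init__(self, nodes, edges):
        self.nodes = list(nodes)
        self.raw_edges = list(edges)       # kept so edges can be rebuilt later
        self.edges = dict(self.raw_edges)  # stand-in for _create_edges

    def append_node(self, node):
        # reject duplicate node names, as the PR does
        if node.node_name in {n.node_name for n in self.nodes}:
            raise ValueError(f"Node with name '{node.node_name}' already exists in the graph.")
        last_node = self.nodes[-1]             # new node is chained after the current tail
        self.raw_edges.append((last_node, node))
        self.nodes.append(node)
        self.edges = dict(self.raw_edges)      # rebuild the edge map from the raw edges

fetch, parse, answer = Node("fetch"), Node("parse"), Node("answer")
graph = MiniGraph([fetch, parse], [(fetch, parse)])
graph.append_node(answer)
print([n.node_name for n in graph.nodes])  # ['fetch', 'parse', 'answer']
```

Keeping `raw_edges` alongside the processed `edges` is what lets `append_node` rebuild the edge map after each insertion, which is why the constructor change at the top of this file stores both.
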
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/csv_scraper_graph.py
@@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
@@ -20,7 +21,7 @@ class CSVScraperGraph(AbstractGraph):
information from web pages using a natural language model to interpret and answer prompts.
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
"""
Initializes the CSVScraperGraph with a prompt, source, and configuration.
"""
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/deep_scraper_graph.py
@@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
@@ -56,7 +57,7 @@ class DeepScraperGraph(AbstractGraph):
)
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):

super().__init__(prompt, config, source, schema)

3 changes: 2 additions & 1 deletion scrapegraphai/graphs/json_scraper_graph.py
@@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
@@ -44,7 +45,7 @@ class JSONScraperGraph(AbstractGraph):
>>> result = json_scraper.run()
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
super().__init__(prompt, config, source, schema)

self.input_key = "json" if source.endswith("json") else "json_dir"
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/omni_scraper_graph.py
@@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
@@ -52,7 +53,7 @@ class OmniScraperGraph(AbstractGraph):
)
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):

self.max_images = 5 if config is None else config.get("max_images", 5)

3 changes: 2 additions & 1 deletion scrapegraphai/graphs/omni_search_graph.py
@@ -4,6 +4,7 @@

from copy import copy, deepcopy
from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
@@ -43,7 +44,7 @@ class OmniSearchGraph(AbstractGraph):
>>> result = search_graph.run()
"""

def __init__(self, prompt: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):

self.max_results = config.get("max_results", 3)

3 changes: 2 additions & 1 deletion scrapegraphai/graphs/pdf_scraper_graph.py
@@ -4,6 +4,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
@@ -47,7 +48,7 @@ class PDFScraperGraph(AbstractGraph):
>>> result = pdf_scraper.run()
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
super().__init__(prompt, config, source, schema)

self.input_key = "pdf" if source.endswith("pdf") else "pdf_dir"
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/script_creator_graph.py
@@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
@@ -46,7 +47,7 @@ class ScriptCreatorGraph(AbstractGraph):
>>> result = script_creator.run()
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):

self.library = config['library']

8 changes: 6 additions & 2 deletions scrapegraphai/graphs/search_graph.py
@@ -4,6 +4,7 @@

from copy import copy, deepcopy
from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
@@ -42,14 +43,16 @@ class SearchGraph(AbstractGraph):
>>> result = search_graph.run()
"""

    def __init__(self, prompt: str, config: dict, schema: Optional[str] = None):
    def __init__(self, prompt: str, config: dict, schema: Optional[BaseModel] = None):

        self.max_results = config.get("max_results", 3)

        if all(isinstance(value, str) for value in config.values()):
            self.copy_config = copy(config)
        else:
            self.copy_config = deepcopy(config)

        self.copy_schema = deepcopy(schema)

        super().__init__(prompt, config, schema)

@@ -68,7 +71,8 @@ def _create_graph(self) -> BaseGraph:
        smart_scraper_instance = SmartScraperGraph(
            prompt="",
            source="",
            config=self.copy_config
            config=self.copy_config,
            schema=self.copy_schema
        )

# ************************************************
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/smart_scraper_graph.py
@@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
@@ -48,7 +49,7 @@ class SmartScraperGraph(AbstractGraph):
)
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
super().__init__(prompt, config, source, schema)

self.input_key = "url" if source.startswith("http") else "local_dir"
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/smart_scraper_multi_graph.py
@@ -4,6 +4,7 @@

from copy import copy, deepcopy
from typing import List, Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
@@ -42,7 +43,7 @@ class SmartScraperMultiGraph(AbstractGraph):
>>> result = search_graph.run()
"""

def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):

self.max_results = config.get("max_results", 3)

3 changes: 2 additions & 1 deletion scrapegraphai/graphs/speech_graph.py
@@ -3,6 +3,7 @@
"""

from typing import Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
@@ -47,7 +48,7 @@ class SpeechGraph(AbstractGraph):
... {"llm": {"model": "gpt-3.5-turbo"}}
"""

def __init__(self, prompt: str, source: str, config: dict, schema: Optional[str] = None):
def __init__(self, prompt: str, source: str, config: dict, schema: Optional[BaseModel] = None):
super().__init__(prompt, config, source, schema)

self.input_key = "url" if source.startswith("http") else "local_dir"