Skip to content

GraphIteratorNode and MergeAnswersNode #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions examples/openai/custom_graph_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import os
from dotenv import load_dotenv

from langchain_openai import OpenAIEmbeddings
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph
from scrapegraphai.nodes import FetchNode, ParseNode, RAGNode, GenerateAnswerNode, RobotsNode
Expand All @@ -20,7 +22,7 @@
"api_key": openai_key,
"model": "gpt-3.5-turbo",
"temperature": 0,
"streaming": True
"streaming": False
},
}

Expand All @@ -29,33 +31,50 @@
# ************************************************

llm_model = OpenAI(graph_config["llm"])
embedder = OpenAIEmbeddings(api_key=llm_model.openai_api_key)

# define the nodes for the graph
robot_node = RobotsNode(
input="url",
output=["is_scrapable"],
node_config={"llm_model": llm_model}
node_config={
"llm_model": llm_model,
"verbose": True,
}
)

fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
node_config={"headless": True, "verbose": True}
node_config={
"verbose": True,
"headless": True,
}
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
node_config={"chunk_size": 4096}
node_config={
"chunk_size": 4096,
"verbose": True,
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
node_config={"llm_model": llm_model},
node_config={
"llm_model": llm_model,
"embedder_model": embedder,
"verbose": True,
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",
output=["answer"],
node_config={"llm_model": llm_model},
node_config={
"llm_model": llm_model,
"verbose": True,
}
)

# ************************************************
Expand Down
98 changes: 98 additions & 0 deletions examples/openai/search_graph_multi.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't you use the SmartScraperGraph here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm passing it inside the GraphIteratorNode configuration with an empty prompt and source, since both are mandatory for a SmartScraperGraph. Inside GraphIteratorNode they are replaced with the correct values.

Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Example of a custom graph, built from existing nodes, that searches the web, scrapes each result and merges the answers
"""

import os
from dotenv import load_dotenv
from langchain_openai import OpenAIEmbeddings
from scrapegraphai.models import OpenAI
from scrapegraphai.graphs import BaseGraph, SmartScraperGraph
from scrapegraphai.nodes import SearchInternetNode, GraphIteratorNode, MergeAnswersNode

# Load environment variables (OPENAI_APIKEY is read below) from a .env file.
load_dotenv()

# ************************************************
# Define the configuration for the graph
# ************************************************

openai_key = os.getenv("OPENAI_APIKEY")

graph_config = {
    "llm": {
        "api_key": openai_key,
        "model": "gpt-3.5-turbo",
    },
}

# ************************************************
# Create a SmartScraperGraph instance
# ************************************************

# prompt and source are mandatory constructor arguments, so empty
# placeholders are passed here. NOTE(review): per the PR discussion,
# GraphIteratorNode replaces them for each URL -- confirm in that node.
smart_scraper_graph = SmartScraperGraph(
    prompt="",
    source="",
    config=graph_config,
)

# ************************************************
# Define the graph nodes
# ************************************************

llm_model = OpenAI(graph_config["llm"])

# Step 1: read "user_prompt" from the state and write "urls".
search_internet_node = SearchInternetNode(
    input="user_prompt",
    output=["urls"],
    node_config={
        "llm_model": llm_model,
        "max_results": 5,  # num of search results to fetch
        "verbose": True,
    },
)

# Step 2: read "user_prompt" and "urls", write "results", using the
# SmartScraperGraph instance created above.
graph_iterator_node = GraphIteratorNode(
    input="user_prompt & urls",
    output=["results"],
    node_config={
        "graph_instance": smart_scraper_graph,
        "verbose": True,
    },
)

# Step 3: read "user_prompt" and "results", write the final "answer".
merge_answers_node = MergeAnswersNode(
    input="user_prompt & results",
    output=["answer"],
    node_config={
        "llm_model": llm_model,
        "verbose": True,
    },
)

# ************************************************
# Create the graph by defining the connections
# ************************************************

graph = BaseGraph(
    nodes=[
        search_internet_node,
        graph_iterator_node,
        merge_answers_node,
    ],
    edges=[
        (search_internet_node, graph_iterator_node),
        (graph_iterator_node, merge_answers_node),
    ],
    entry_point=search_internet_node,
)

# ************************************************
# Execute the graph
# ************************************************

# execute() returns the final state plus execution info; only the state
# is needed in this example.
final_state, _ = graph.execute({
    "user_prompt": "List me all the typical Chioggia dishes."
})

# get the answer from the result
answer = final_state.get("answer", "No answer found.")
print(answer)
4 changes: 3 additions & 1 deletion examples/openai/search_graph_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@
"api_key": openai_key,
"model": "gpt-3.5-turbo",
},
"max_results": 5,
"verbose": True,
}

# ************************************************
# Create the SearchGraph instance and run it
# ************************************************

search_graph = SearchGraph(
prompt="List me top 5 eyeliner products for a gift.",
prompt="List me the best escursions near Trento",
config=graph_config
)

Expand Down
2 changes: 1 addition & 1 deletion examples/openai/smart_scraper_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"api_key": openai_key,
"model": "gpt-3.5-turbo",
},
"verbose": True,
"verbose": False,
}

# ************************************************
Expand Down
1 change: 1 addition & 0 deletions scrapegraphai/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
__init__.py file for graphs folder
"""

from .abstract_graph import AbstractGraph
from .base_graph import BaseGraph
from .smart_scraper_graph import SmartScraperGraph
from .speech_graph import SpeechGraph
Expand Down
2 changes: 1 addition & 1 deletion scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.execution_info = None

# Set common configuration parameters
self.verbose = True if config is None else config.get("verbose", False)
self.verbose = False if config is None else config.get("verbose", False)
self.headless = True if config is None else config.get(
"headless", True)
common_params = {"headless": self.headless,
Expand Down
70 changes: 37 additions & 33 deletions scrapegraphai/graphs/search_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
from .base_graph import BaseGraph
from ..nodes import (
SearchInternetNode,
FetchNode,
ParseNode,
RAGNode,
GenerateAnswerNode
GraphIteratorNode,
MergeAnswersNode
)
from .abstract_graph import AbstractGraph
from .smart_scraper_graph import SmartScraperGraph


class SearchGraph(AbstractGraph):
Expand Down Expand Up @@ -38,6 +37,11 @@ class SearchGraph(AbstractGraph):
>>> result = search_graph.run()
"""

def __init__(self, prompt: str, config: dict):
    """
    Initializes the SearchGraph and stores how many search results to use.

    Args:
        prompt (str): The user prompt, forwarded to the base initializer.
        config (dict): The graph configuration. The optional "max_results"
            key (default 3) is stored as ``self.max_results``.
    """

    # Set before calling the base initializer.
    # NOTE(review): presumably the base __init__ builds the graph via
    # _create_graph, which reads self.max_results -- confirm.
    # The None guard mirrors how the base class reads "verbose" and
    # "headless" from a possibly-missing config.
    self.max_results = 3 if config is None else config.get("max_results", 3)
    super().__init__(prompt, config)

def _create_graph(self) -> BaseGraph:
"""
Creates the graph of nodes representing the workflow for web scraping and searching.
Expand All @@ -46,53 +50,53 @@ def _create_graph(self) -> BaseGraph:
BaseGraph: A graph instance representing the web scraping and searching workflow.
"""

# ************************************************
# Create a SmartScraperGraph instance
# ************************************************

smart_scraper_instance = SmartScraperGraph(
prompt="",
source="",
config=self.config
)

# ************************************************
# Define the graph nodes
# ************************************************

search_internet_node = SearchInternetNode(
input="user_prompt",
output=["url"],
node_config={
"llm_model": self.llm_model
}
)
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"]
)
parse_node = ParseNode(
input="doc",
output=["parsed_doc"],
output=["urls"],
node_config={
"chunk_size": self.model_token
"llm_model": self.llm_model,
"max_results": self.max_results
}
)
rag_node = RAGNode(
input="user_prompt & (parsed_doc | doc)",
output=["relevant_chunks"],
graph_iterator_node = GraphIteratorNode(
input="user_prompt & urls",
output=["results"],
node_config={
"llm_model": self.llm_model,
"embedder_model": self.embedder_model
"graph_instance": smart_scraper_instance,
}
)
generate_answer_node = GenerateAnswerNode(
input="user_prompt & (relevant_chunks | parsed_doc | doc)",

merge_answers_node = MergeAnswersNode(
input="user_prompt & results",
output=["answer"],
node_config={
"llm_model": self.llm_model
"llm_model": self.llm_model,
}
)

return BaseGraph(
nodes=[
search_internet_node,
fetch_node,
parse_node,
rag_node,
generate_answer_node,
graph_iterator_node,
merge_answers_node
],
edges=[
(search_internet_node, fetch_node),
(fetch_node, parse_node),
(parse_node, rag_node),
(rag_node, generate_answer_node)
(search_internet_node, graph_iterator_node),
(graph_iterator_node, merge_answers_node)
],
entry_point=search_internet_node
)
Expand Down
2 changes: 2 additions & 0 deletions scrapegraphai/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@
from .robots_node import RobotsNode
from .generate_answer_csv_node import GenerateAnswerCSVNode
from .generate_answer_pdf_node import GenerateAnswerPDFNode
from .graph_iterator_node import GraphIteratorNode
from .merge_answers_node import MergeAnswersNode
4 changes: 2 additions & 2 deletions scrapegraphai/nodes/generate_answer_csv_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Module for generating the answer node
"""
# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm

# Imports from Langchain
Expand Down Expand Up @@ -39,7 +39,7 @@ class GenerateAnswerCSVNode(BaseNode):
updating the state with the generated answer under the 'answer' key.
"""

def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
node_name: str = "GenerateAnswer"):
"""
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
Expand Down
4 changes: 2 additions & 2 deletions scrapegraphai/nodes/generate_answer_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm

# Imports from Langchain
Expand Down Expand Up @@ -33,7 +33,7 @@ class GenerateAnswerNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
"""

def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
node_name: str = "GenerateAnswer"):
super().__init__(node_name, "node", input, output, 2, node_config)

Expand Down
4 changes: 2 additions & 2 deletions scrapegraphai/nodes/generate_answer_node_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Module for generating the answer node
"""
# Imports from standard library
from typing import List
from typing import List, Optional
from tqdm import tqdm

# Imports from Langchain
Expand Down Expand Up @@ -39,7 +39,7 @@ class GenerateAnswerCSVNode(BaseNode):
updating the state with the generated answer under the 'answer' key.
"""

def __init__(self, input: str, output: List[str], node_config: dict,
def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
node_name: str = "GenerateAnswer"):
"""
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
Expand Down
Loading