Skip to content

feat: refactoring of the conditional node #728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,29 @@
import os
import json
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperMultiCondGraph
from scrapegraphai.graphs import SmartScraperMultiGraph

load_dotenv()

# ************************************************
# Define the configuration for the graph
# ************************************************

groq_key = os.getenv("GROQ_APIKEY")

graph_config = {
"llm": {
"model": "groq/gemma-7b-it",
"api_key": groq_key,
"temperature": 0
"api_key": os.getenv("OPENAI_API_KEY"),
"model": "openai/gpt-4o",
},
"headless": False

"verbose": True,
"headless": False,
}

# *******************************************************
# Create the SmartScraperMultiGraph instance and run it
# *******************************************************

multiple_search_graph = SmartScraperMultiCondGraph(
multiple_search_graph = SmartScraperMultiGraph(
prompt="Who is Marco Perini?",
source=[
"https://perinim.github.io/",
Expand Down
1 change: 0 additions & 1 deletion scrapegraphai/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,4 @@
from .screenshot_scraper_graph import ScreenshotScraperGraph
from .smart_scraper_multi_concat_graph import SmartScraperMultiConcatGraph
from .code_generator_graph import CodeGeneratorGraph
from .smart_scraper_multi_cond_graph import SmartScraperMultiCondGraph
from .depth_search_graph import DepthSearchGraph
2 changes: 1 addition & 1 deletion scrapegraphai/graphs/markdown_scraper_multi_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class MDScraperMultiGraph(AbstractGraph):
>>> result = search_graph.run()
"""

def __init__(self, prompt: str, source: List[str],
def __init__(self, prompt: str, source: List[str],
config: dict, schema: Optional[BaseModel] = None):
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)
Expand Down
3 changes: 2 additions & 1 deletion scrapegraphai/graphs/smart_scraper_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
FetchNode,
ParseNode,
ReasoningNode,
GenerateAnswerNode
GenerateAnswerNode,
ConditionalNode
)

class SmartScraperGraph(AbstractGraph):
Expand Down
62 changes: 47 additions & 15 deletions scrapegraphai/graphs/smart_scraper_multi_concat_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
SmartScraperMultiGraph Module
"""
SmartScraperMultiCondGraph Module with ConditionalNode
"""
from copy import deepcopy
from typing import List, Optional
Expand All @@ -9,15 +9,16 @@
from .smart_scraper_graph import SmartScraperGraph
from ..nodes import (
GraphIteratorNode,
ConcatAnswersNode
MergeAnswersNode,
ConcatAnswersNode,
ConditionalNode
)
from ..utils.copy import safe_deepcopy

class SmartScraperMultiConcatGraph(AbstractGraph):
class SmartScraperMultiCondGraph(AbstractGraph):
"""
SmartScraperMultiGraph is a scraping pipeline that scrapes a
SmartScraperMultiConditionalGraph is a scraping pipeline that scrapes a
list of URLs and generates answers to a given prompt.
It only requires a user prompt and a list of URLs.

Attributes:
prompt (str): The user prompt to search the internet.
Expand All @@ -34,24 +35,26 @@ class SmartScraperMultiConcatGraph(AbstractGraph):
schema (Optional[BaseModel]): The schema for the graph output.

Example:
>>> search_graph = SmartScraperMultiConcatGraph(
>>> search_graph = MultipleSearchGraph(
... "What is Chioggia famous for?",
... {"llm": {"model": "openai/gpt-3.5-turbo"}}
... )
>>> result = search_graph.run()
"""

def __init__(self, prompt: str, source: List[str],
def __init__(self, prompt: str, source: List[str],
config: dict, schema: Optional[BaseModel] = None):
self.copy_config = safe_deepcopy(config)

self.max_results = config.get("max_results", 3)
self.copy_config = safe_deepcopy(config)
self.copy_schema = deepcopy(schema)

super().__init__(prompt, config, source, schema)

def _create_graph(self) -> BaseGraph:
"""
Creates the graph of nodes representing the workflow for web scraping and searching.
Creates the graph of nodes representing the workflow for web scraping and searching,
including a ConditionalNode to decide between merging or concatenating the results.

Returns:
BaseGraph: A graph instance representing the web scraping and searching workflow.
Expand All @@ -65,20 +68,49 @@ def _create_graph(self) -> BaseGraph:
"scraper_config": self.copy_config,
},
schema=self.copy_schema,
node_name="GraphIteratorNode"
)

conditional_node = ConditionalNode(
input="results",
output=["results"],
node_name="ConditionalNode",
node_config={
'key_name': 'results',
'condition': 'len(results) > 2'
}
)

merge_answers_node = MergeAnswersNode(
input="user_prompt & results",
output=["answer"],
node_config={
"llm_model": self.llm_model,
"schema": self.copy_schema
},
node_name="MergeAnswersNode"
)

concat_answers_node = ConcatAnswersNode(
concat_node = ConcatAnswersNode(
input="results",
output=["answer"]
output=["answer"],
node_config={},
node_name="ConcatNode"
)

return BaseGraph(
nodes=[
graph_iterator_node,
concat_answers_node,
conditional_node,
merge_answers_node,
concat_node,
],
edges=[
(graph_iterator_node, concat_answers_node),
(graph_iterator_node, conditional_node),
# True node (len(results) > 2)
(conditional_node, merge_answers_node),
# False node (len(results) <= 2)
(conditional_node, concat_node)
],
entry_point=graph_iterator_node,
graph_name=self.__class__.__name__
Expand Down
130 changes: 0 additions & 130 deletions scrapegraphai/graphs/smart_scraper_multi_cond_graph.py

This file was deleted.

15 changes: 5 additions & 10 deletions scrapegraphai/nodes/conditional_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,15 @@ def __init__(self,
Initializes an empty ConditionalNode.
"""
super().__init__(node_name, "conditional_node", input, output, 2, node_config)

try:
self.key_name = self.node_config["key_name"]
except:
raise NotImplementedError("You need to provide key_name inside the node config")

self.true_node_name = None
self.false_node_name = None

self.condition = self.node_config.get("condition", None)

self.eval_instance = EvalWithCompoundTypes()
self.eval_instance.functions = {'len': len}

Expand All @@ -65,21 +63,18 @@ def execute(self, state: dict) -> dict:

if self.true_node_name is None or self.false_node_name is None:
raise ValueError("ConditionalNode's next nodes are not set properly.")

# Evaluate the condition

if self.condition:
condition_result = self._evaluate_condition(state, self.condition)
else:
# Default behavior: check existence and non-emptiness of key_name
value = state.get(self.key_name)
condition_result = value is not None and value != ''

# Return the appropriate next node name
if condition_result:
return self.true_node_name
else:
return self.false_node_name

def _evaluate_condition(self, state: dict, condition: str) -> bool:
"""
Parses and evaluates the condition expression against the state.
Expand All @@ -104,4 +99,4 @@ def _evaluate_condition(self, state: dict, condition: str) -> bool:
)
return bool(result)
except Exception as e:
raise ValueError(f"Error evaluating condition '{condition}' in {self.node_name}: {e}")
raise ValueError(f"Error evaluating condition '{condition}' in {self.node_name}: {e}")
Loading