Skip to content

Commit 420c71b

Browse files
committed
feat: refactoring of the conditional node
1 parent ea9ed1a commit 420c71b

File tree

7 files changed

+62
-166
lines changed

7 files changed

+62
-166
lines changed

examples/groq/smart_scraper_multi_cond_groq.py renamed to examples/extras/conditional_usage.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,29 @@
55
import os
66
import json
77
from dotenv import load_dotenv
8-
from scrapegraphai.graphs import SmartScraperMultiCondGraph
8+
from scrapegraphai.graphs import SmartScraperMultiGraph
99

1010
load_dotenv()
1111

1212
# ************************************************
1313
# Define the configuration for the graph
1414
# ************************************************
1515

16-
groq_key = os.getenv("GROQ_APIKEY")
17-
1816
graph_config = {
1917
"llm": {
20-
"model": "groq/gemma-7b-it",
21-
"api_key": groq_key,
22-
"temperature": 0
18+
"api_key": os.getenv("OPENAI_API_KEY"),
19+
"model": "openai/gpt-4o",
2320
},
24-
"headless": False
21+
22+
"verbose": True,
23+
"headless": False,
2524
}
2625

2726
# *******************************************************
2827
# Create the SmartScraperMultiCondGraph instance and run it
2928
# *******************************************************
3029

31-
multiple_search_graph = SmartScraperMultiCondGraph(
30+
multiple_search_graph = SmartScraperMultiGraph(
3231
prompt="Who is Marco Perini?",
3332
source=[
3433
"https://perinim.github.io/",

scrapegraphai/graphs/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,4 @@
2626
from .screenshot_scraper_graph import ScreenshotScraperGraph
2727
from .smart_scraper_multi_concat_graph import SmartScraperMultiConcatGraph
2828
from .code_generator_graph import CodeGeneratorGraph
29-
from .smart_scraper_multi_cond_graph import SmartScraperMultiCondGraph
3029
from .depth_search_graph import DepthSearchGraph

scrapegraphai/graphs/markdown_scraper_multi_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class MDScraperMultiGraph(AbstractGraph):
4141
>>> result = search_graph.run()
4242
"""
4343

44-
def __init__(self, prompt: str, source: List[str],
44+
def __init__(self, prompt: str, source: List[str],
4545
config: dict, schema: Optional[BaseModel] = None):
4646
self.copy_config = safe_deepcopy(config)
4747
self.copy_schema = deepcopy(schema)

scrapegraphai/graphs/smart_scraper_graph.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
FetchNode,
1111
ParseNode,
1212
ReasoningNode,
13-
GenerateAnswerNode
13+
GenerateAnswerNode,
14+
ConditionalNode
1415
)
1516

1617
class SmartScraperGraph(AbstractGraph):

scrapegraphai/graphs/smart_scraper_multi_concat_graph.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
"""
2-
SmartScraperMultiGraph Module
1+
"""
2+
SmartScraperMultiCondGraph Module with ConditionalNode
33
"""
44
from copy import deepcopy
55
from typing import List, Optional
@@ -9,15 +9,16 @@
99
from .smart_scraper_graph import SmartScraperGraph
1010
from ..nodes import (
1111
GraphIteratorNode,
12-
ConcatAnswersNode
12+
MergeAnswersNode,
13+
ConcatAnswersNode,
14+
ConditionalNode
1315
)
1416
from ..utils.copy import safe_deepcopy
1517

16-
class SmartScraperMultiConcatGraph(AbstractGraph):
18+
class SmartScraperMultiCondGraph(AbstractGraph):
1719
"""
18-
SmartScraperMultiGraph is a scraping pipeline that scrapes a
20+
SmartScraperMultiConditionalGraph is a scraping pipeline that scrapes a
1921
list of URLs and generates answers to a given prompt.
20-
It only requires a user prompt and a list of URLs.
2122
2223
Attributes:
2324
prompt (str): The user prompt to search the internet.
@@ -34,24 +35,26 @@ class SmartScraperMultiConcatGraph(AbstractGraph):
3435
schema (Optional[BaseModel]): The schema for the graph output.
3536
3637
Example:
37-
>>> search_graph = SmartScraperMultiConcatGraph(
38+
>>> search_graph = MultipleSearchGraph(
3839
... "What is Chioggia famous for?",
3940
... {"llm": {"model": "openai/gpt-3.5-turbo"}}
4041
... )
4142
>>> result = search_graph.run()
4243
"""
43-
44-
def __init__(self, prompt: str, source: List[str],
44+
45+
def __init__(self, prompt: str, source: List[str],
4546
config: dict, schema: Optional[BaseModel] = None):
46-
self.copy_config = safe_deepcopy(config)
4747

48+
self.max_results = config.get("max_results", 3)
49+
self.copy_config = safe_deepcopy(config)
4850
self.copy_schema = deepcopy(schema)
4951

5052
super().__init__(prompt, config, source, schema)
5153

5254
def _create_graph(self) -> BaseGraph:
5355
"""
54-
Creates the graph of nodes representing the workflow for web scraping and searching.
56+
Creates the graph of nodes representing the workflow for web scraping and searching,
57+
including a ConditionalNode to decide between merging or concatenating the results.
5558
5659
Returns:
5760
BaseGraph: A graph instance representing the web scraping and searching workflow.
@@ -65,20 +68,49 @@ def _create_graph(self) -> BaseGraph:
6568
"scraper_config": self.copy_config,
6669
},
6770
schema=self.copy_schema,
71+
node_name="GraphIteratorNode"
72+
)
73+
74+
conditional_node = ConditionalNode(
75+
input="results",
76+
output=["results"],
77+
node_name="ConditionalNode",
78+
node_config={
79+
'key_name': 'results',
80+
'condition': 'len(results) > 2'
81+
}
82+
)
83+
84+
merge_answers_node = MergeAnswersNode(
85+
input="user_prompt & results",
86+
output=["answer"],
87+
node_config={
88+
"llm_model": self.llm_model,
89+
"schema": self.copy_schema
90+
},
91+
node_name="MergeAnswersNode"
6892
)
6993

70-
concat_answers_node = ConcatAnswersNode(
94+
concat_node = ConcatAnswersNode(
7195
input="results",
72-
output=["answer"]
96+
output=["answer"],
97+
node_config={},
98+
node_name="ConcatNode"
7399
)
74100

75101
return BaseGraph(
76102
nodes=[
77103
graph_iterator_node,
78-
concat_answers_node,
104+
conditional_node,
105+
merge_answers_node,
106+
concat_node,
79107
],
80108
edges=[
81-
(graph_iterator_node, concat_answers_node),
109+
(graph_iterator_node, conditional_node),
110+
# True node (len(results) > 2)
111+
(conditional_node, merge_answers_node),
112+
# False node (len(results) <= 2)
113+
(conditional_node, concat_node)
82114
],
83115
entry_point=graph_iterator_node,
84116
graph_name=self.__class__.__name__

scrapegraphai/graphs/smart_scraper_multi_cond_graph.py

Lines changed: 0 additions & 130 deletions
This file was deleted.

scrapegraphai/nodes/conditional_node.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,15 @@ def __init__(self,
3838
Initializes an empty ConditionalNode.
3939
"""
4040
super().__init__(node_name, "conditional_node", input, output, 2, node_config)
41-
41+
4242
try:
4343
self.key_name = self.node_config["key_name"]
4444
except:
4545
raise NotImplementedError("You need to provide key_name inside the node config")
46-
46+
4747
self.true_node_name = None
4848
self.false_node_name = None
49-
5049
self.condition = self.node_config.get("condition", None)
51-
5250
self.eval_instance = EvalWithCompoundTypes()
5351
self.eval_instance.functions = {'len': len}
5452

@@ -65,21 +63,18 @@ def execute(self, state: dict) -> dict:
6563

6664
if self.true_node_name is None or self.false_node_name is None:
6765
raise ValueError("ConditionalNode's next nodes are not set properly.")
68-
69-
# Evaluate the condition
66+
7067
if self.condition:
7168
condition_result = self._evaluate_condition(state, self.condition)
7269
else:
73-
# Default behavior: check existence and non-emptiness of key_name
7470
value = state.get(self.key_name)
7571
condition_result = value is not None and value != ''
7672

77-
# Return the appropriate next node name
7873
if condition_result:
7974
return self.true_node_name
8075
else:
8176
return self.false_node_name
82-
77+
8378
def _evaluate_condition(self, state: dict, condition: str) -> bool:
8479
"""
8580
Parses and evaluates the condition expression against the state.
@@ -104,4 +99,4 @@ def _evaluate_condition(self, state: dict, condition: str) -> bool:
10499
)
105100
return bool(result)
106101
except Exception as e:
107-
raise ValueError(f"Error evaluating condition '{condition}' in {self.node_name}: {e}")
102+
raise ValueError(f"Error evaluating condition '{condition}' in {self.node_name}: {e}")

0 commit comments

Comments
 (0)