Skip to content

Commit f837dc1

Browse files
committed
feat: conditional_node
1 parent 154ca4c commit f837dc1

File tree

7 files changed

+260
-6
lines changed

7 files changed

+260
-6
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
"""
2+
Basic example of scraping pipeline using SmartScraperMultiConcatGraph with Groq
3+
"""
4+
5+
import os
6+
import json
7+
from dotenv import load_dotenv
8+
from scrapegraphai.graphs import SmartScraperMultiCondGraph
9+
10+
load_dotenv()
11+
12+
# ************************************************
13+
# Define the configuration for the graph
14+
# ************************************************
15+
16+
groq_key = os.getenv("GROQ_APIKEY")
17+
18+
graph_config = {
19+
"llm": {
20+
"model": "groq/gemma-7b-it",
21+
"api_key": groq_key,
22+
"temperature": 0
23+
},
24+
"headless": False
25+
}
26+
27+
# *******************************************************
28+
# Create the SmartScraperMultiCondGraph instance and run it
29+
# *******************************************************
30+
31+
multiple_search_graph = SmartScraperMultiCondGraph(
32+
prompt="Who is Marco Perini?",
33+
source=[
34+
"https://perinim.github.io/",
35+
"https://perinim.github.io/cv/"
36+
],
37+
schema=None,
38+
config=graph_config
39+
)
40+
41+
result = multiple_search_graph.run()
42+
print(json.dumps(result, indent=4))

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ undetected-playwright>=0.3.0
1818
google>=3.0.0
1919
semchunk>=1.0.1
2020
langchain-ollama>=0.1.3
21+
simpleeval>=0.9.13

scrapegraphai/graphs/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@
2626
from .screenshot_scraper_graph import ScreenshotScraperGraph
2727
from .smart_scraper_multi_concat_graph import SmartScraperMultiConcatGraph
2828
from .code_generator_graph import CodeGeneratorGraph
29+
from .smart_scraper_multi_cond_graph import SmartScraperMultiCondGraph

scrapegraphai/graphs/base_graph.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def __init__(self, nodes: list, edges: list, entry_point: str,
5959
# raise a warning if the entry point is not the first node in the list
6060
warnings.warn(
6161
"Careful! The entry point node is different from the first node in the graph.")
62+
63+
self._set_conditional_node_edges()
6264

6365
# Burr configuration
6466
self.use_burr = use_burr
@@ -77,9 +79,24 @@ def _create_edges(self, edges: list) -> dict:
7779

7880
edge_dict = {}
7981
for from_node, to_node in edges:
80-
edge_dict[from_node.node_name] = to_node.node_name
82+
if from_node.node_type != 'conditional_node':
83+
edge_dict[from_node.node_name] = to_node.node_name
8184
return edge_dict
8285

86+
def _set_conditional_node_edges(self):
87+
"""
88+
Sets the true_node_name and false_node_name for each ConditionalNode.
89+
"""
90+
for node in self.nodes:
91+
if node.node_type == 'conditional_node':
92+
# Find outgoing edges from this ConditionalNode
93+
outgoing_edges = [(from_node, to_node) for from_node, to_node in self.raw_edges if from_node.node_name == node.node_name]
94+
if len(outgoing_edges) != 2:
95+
raise ValueError(f"ConditionalNode '{node.node_name}' must have exactly two outgoing edges.")
96+
# Assign true_node_name and false_node_name
97+
node.true_node_name = outgoing_edges[0][1].node_name
98+
node.false_node_name = outgoing_edges[1][1].node_name
99+
83100
def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
84101
"""
85102
Executes the graph by traversing nodes starting from the
@@ -201,7 +218,12 @@ def _execute_standard(self, initial_state: dict) -> Tuple[dict, list]:
201218
cb_total["total_cost_USD"] += cb_data["total_cost_USD"]
202219

203220
if current_node.node_type == "conditional_node":
204-
current_node_name = result
221+
node_names = {node.node_name for node in self.nodes}
222+
if result in node_names:
223+
current_node_name = result
224+
else:
225+
raise ValueError(f"Conditional Node returned a node name '{result}' that does not exist in the graph")
226+
205227
elif current_node_name in self.edges:
206228
current_node_name = self.edges[current_node_name]
207229
else:
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""
2+
SmartScraperMultiCondGraph Module with ConditionalNode
3+
"""
4+
from copy import deepcopy
5+
from typing import List, Optional
6+
from pydantic import BaseModel
7+
from .base_graph import BaseGraph
8+
from .abstract_graph import AbstractGraph
9+
from .smart_scraper_graph import SmartScraperGraph
10+
from ..nodes import (
11+
GraphIteratorNode,
12+
MergeAnswersNode,
13+
ConcatAnswersNode,
14+
ConditionalNode
15+
)
16+
from ..utils.copy import safe_deepcopy
17+
18+
class SmartScraperMultiCondGraph(AbstractGraph):
19+
"""
20+
SmartScraperMultiConditionalGraph is a scraping pipeline that scrapes a
21+
list of URLs and generates answers to a given prompt.
22+
23+
Attributes:
24+
prompt (str): The user prompt to search the internet.
25+
llm_model (dict): The configuration for the language model.
26+
embedder_model (dict): The configuration for the embedder model.
27+
headless (bool): A flag to run the browser in headless mode.
28+
verbose (bool): A flag to display the execution information.
29+
model_token (int): The token limit for the language model.
30+
31+
Args:
32+
prompt (str): The user prompt to search the internet.
33+
source (List[str]): The source of the graph.
34+
config (dict): Configuration parameters for the graph.
35+
schema (Optional[BaseModel]): The schema for the graph output.
36+
37+
Example:
38+
>>> search_graph = MultipleSearchGraph(
39+
... "What is Chioggia famous for?",
40+
... {"llm": {"model": "openai/gpt-3.5-turbo"}}
41+
... )
42+
>>> result = search_graph.run()
43+
"""
44+
45+
def __init__(self, prompt: str, source: List[str],
46+
config: dict, schema: Optional[BaseModel] = None):
47+
48+
self.max_results = config.get("max_results", 3)
49+
self.copy_config = safe_deepcopy(config)
50+
self.copy_schema = deepcopy(schema)
51+
52+
super().__init__(prompt, config, source, schema)
53+
54+
def _create_graph(self) -> BaseGraph:
55+
"""
56+
Creates the graph of nodes representing the workflow for web scraping and searching,
57+
including a ConditionalNode to decide between merging or concatenating the results.
58+
59+
Returns:
60+
BaseGraph: A graph instance representing the web scraping and searching workflow.
61+
"""
62+
63+
# Node that iterates over the URLs and collects results
64+
graph_iterator_node = GraphIteratorNode(
65+
input="user_prompt & urls",
66+
output=["results"],
67+
node_config={
68+
"graph_instance": SmartScraperGraph,
69+
"scraper_config": self.copy_config,
70+
},
71+
schema=self.copy_schema,
72+
node_name="GraphIteratorNode"
73+
)
74+
75+
# ConditionalNode to check if len(results) > 2
76+
conditional_node = ConditionalNode(
77+
input="results",
78+
output=["results"],
79+
node_name="ConditionalNode",
80+
node_config={
81+
'key_name': 'results',
82+
'condition': 'len(results) > 2'
83+
}
84+
)
85+
86+
merge_answers_node = MergeAnswersNode(
87+
input="user_prompt & results",
88+
output=["answer"],
89+
node_config={
90+
"llm_model": self.llm_model,
91+
"schema": self.copy_schema
92+
},
93+
node_name="MergeAnswersNode"
94+
)
95+
96+
concat_node = ConcatAnswersNode(
97+
input="results",
98+
output=["answer"],
99+
node_config={},
100+
node_name="ConcatNode"
101+
)
102+
103+
# Build the graph
104+
return BaseGraph(
105+
nodes=[
106+
graph_iterator_node,
107+
conditional_node,
108+
merge_answers_node,
109+
concat_node,
110+
],
111+
edges=[
112+
(graph_iterator_node, conditional_node),
113+
(conditional_node, merge_answers_node), # True node (len(results) > 2)
114+
(conditional_node, concat_node), # False node (len(results) <= 2)
115+
],
116+
entry_point=graph_iterator_node,
117+
graph_name=self.__class__.__name__
118+
)
119+
120+
def run(self) -> str:
121+
"""
122+
Executes the web scraping and searching process.
123+
124+
Returns:
125+
str: The answer to the prompt.
126+
"""
127+
inputs = {"user_prompt": self.prompt, "urls": self.source}
128+
self.final_state, self.execution_info = self.graph.execute(inputs)
129+
130+
return self.final_state.get("answer", "No answer found.")

scrapegraphai/nodes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@
2727
from .html_analyzer_node import HtmlAnalyzerNode
2828
from .generate_code_node import GenerateCodeNode
2929
from .search_node_with_context import SearchLinksWithContext
30+
from .conditional_node import ConditionalNode

scrapegraphai/nodes/conditional_node.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
from typing import Optional, List
55
from .base_node import BaseNode
6+
from simpleeval import simple_eval, EvalWithCompoundTypes
67

78
class ConditionalNode(BaseNode):
89
"""
@@ -28,13 +29,28 @@ class ConditionalNode(BaseNode):
2829
2930
"""
3031

31-
def __init__(self):
32+
def __init__(self,
33+
input: str,
34+
output: List[str],
35+
node_config: Optional[dict] = None,
36+
node_name: str = "Cond",):
3237
"""
3338
Initializes an empty ConditionalNode.
3439
"""
35-
#super().__init__(node_name, "node", input, output, 2, node_config)
36-
pass
40+
super().__init__(node_name, "conditional_node", input, output, 2, node_config)
41+
42+
try:
43+
self.key_name = self.node_config["key_name"]
44+
except:
45+
raise NotImplementedError("You need to provide key_name inside the node config")
46+
47+
self.true_node_name = None
48+
self.false_node_name = None
3749

50+
self.condition = self.node_config.get("condition", None)
51+
52+
self.eval_instance = EvalWithCompoundTypes()
53+
self.eval_instance.functions = {'len': len}
3854

3955
def execute(self, state: dict) -> dict:
4056
"""
@@ -47,4 +63,45 @@ def execute(self, state: dict) -> dict:
4763
str: The name of the next node to execute based on the presence of the key.
4864
"""
4965

50-
pass
66+
if self.true_node_name is None or self.false_node_name is None:
67+
raise ValueError("ConditionalNode's next nodes are not set properly.")
68+
69+
# Evaluate the condition
70+
if self.condition:
71+
condition_result = self._evaluate_condition(state, self.condition)
72+
else:
73+
# Default behavior: check existence and non-emptiness of key_name
74+
value = state.get(self.key_name)
75+
condition_result = value is not None and value != ''
76+
77+
# Return the appropriate next node name
78+
if condition_result:
79+
return self.true_node_name
80+
else:
81+
return self.false_node_name
82+
83+
def _evaluate_condition(self, state: dict, condition: str) -> bool:
84+
"""
85+
Parses and evaluates the condition expression against the state.
86+
87+
Args:
88+
state (dict): The current state of the graph.
89+
condition (str): The condition expression to evaluate.
90+
91+
Returns:
92+
bool: The result of the condition evaluation.
93+
"""
94+
# Combine state and allowed functions for evaluation context
95+
eval_globals = self.eval_instance.functions.copy()
96+
eval_globals.update(state)
97+
98+
try:
99+
result = simple_eval(
100+
condition,
101+
names=eval_globals,
102+
functions=self.eval_instance.functions,
103+
operators=self.eval_instance.operators
104+
)
105+
return bool(result)
106+
except Exception as e:
107+
raise ValueError(f"Error evaluating condition '{condition}' in {self.node_name}: {e}")

0 commit comments

Comments
 (0)