Commit 026a70b

fix: bugs

Parent: 257f393

File tree

3 files changed: +40 -6 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -60,6 +60,7 @@ keywords = [
     "web scraping tool",
     "webscraping",
     "graph",
+    "llm"
 ]
 classifiers = [
     "Intended Audience :: Developers",

scrapegraphai/nodes/generate_answer_csv_node.py

Lines changed: 2 additions & 2 deletions

@@ -60,7 +60,7 @@ def __init__(
 
         self.additional_info = node_config.get("additional_info")
 
-    def execute(self, state):
+    async def execute(self, state):
         """
         Generates an answer by constructing a prompt from the user's input and the scraped
         content, querying the language model, and parsing its response.
@@ -157,7 +157,7 @@ def execute(self, state):
         )
 
         merge_chain = merge_prompt | self.llm_model | output_parser
-        answer = merge_chain.ainvoke({"context": batch_results, "question": user_prompt})
+        answer = await merge_chain.ainvoke({"context": batch_results, "question": user_prompt})
 
         state.update({self.output[0]: answer})
         return state
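
The substance of both changes is the same bug: Runnable.ainvoke is a coroutine function, so calling it from a synchronous execute without await stored a pending coroutine object in the graph state instead of the model's answer. Below is a minimal, standalone sketch of the failure mode and the fix, using plain asyncio; the ainvoke function here is an illustrative stand-in for the real chain call, not ScrapeGraphAI code.

import asyncio

async def ainvoke(inputs: dict) -> str:
    # Stand-in for Runnable.ainvoke: any coroutine function behaves the same way.
    return f"answer to {inputs['question']!r}"

async def buggy_execute(state: dict) -> dict:
    # Missing await: the state now holds a coroutine object, not a string.
    state["answer"] = ainvoke({"question": state["question"]})
    return state

async def fixed_execute(state: dict) -> dict:
    # With await, the coroutine actually runs and its result is stored.
    state["answer"] = await ainvoke({"question": state["question"]})
    return state

async def main():
    buggy = await buggy_execute({"question": "q"})
    print(type(buggy["answer"]))  # <class 'coroutine'>
    buggy["answer"].close()       # avoid the "never awaited" warning

    fixed = await fixed_execute({"question": "q"})
    print(fixed["answer"])        # answer to 'q'

asyncio.run(main())

Making execute itself async def is what allows the await; it also means whatever drives the graph must now await execute rather than call it directly.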

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 37 additions & 4 deletions

@@ -1,3 +1,6 @@
+"""
+GenerateAnswerNode Module
+"""
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
@@ -15,6 +18,26 @@
 )
 
 class GenerateAnswerNode(BaseNode):
+    """
+    Initializes the GenerateAnswerNode class.
+
+    Args:
+        input (str): The input data type for the node.
+        output (List[str]): The output data type(s) for the node.
+        node_config (Optional[dict]): Configuration dictionary for the node,
+            which includes the LLM model, verbosity, schema, and other settings.
+            Defaults to None.
+        node_name (str): The name of the node. Defaults to "GenerateAnswer".
+
+    Attributes:
+        llm_model: The language model specified in the node configuration.
+        verbose (bool): Whether verbose mode is enabled.
+        force (bool): Whether to force certain behaviors, overriding defaults.
+        script_creator (bool): Whether the node is in script creation mode.
+        is_md_scraper (bool): Whether the node is scraping markdown data.
+        additional_info (Optional[str]): Any additional information to be
+            included in the prompt templates.
+    """
     def __init__(
         self,
         input: str,
@@ -34,7 +57,17 @@ def __init__(
         self.is_md_scraper = node_config.get("is_md_scraper", False)
         self.additional_info = node_config.get("additional_info")
 
-    def execute(self, state: dict) -> dict:
+    async def execute(self, state: dict) -> dict:
+        """
+        Executes the GenerateAnswerNode.
+
+        Args:
+            state (dict): The current state of the graph. The input keys will be used
+                to fetch the correct data from the state.
+
+        Returns:
+            dict: The updated state with the output key containing the generated answer.
+        """
         self.logger.info(f"--- Executing {self.node_name} Node ---")
 
         input_keys = self.get_input_keys(state)
@@ -90,7 +123,7 @@ def execute(self, state: dict) -> dict:
         chain = prompt | self.llm_model
         if output_parser:
             chain = chain | output_parser
-        answer = chain.ainvoke({"question": user_prompt})
+        answer = await chain.ainvoke({"question": user_prompt})
 
         state.update({self.output[0]: answer})
         return state
@@ -110,7 +143,7 @@ def execute(self, state: dict) -> dict:
             chains_dict[chain_name] = chains_dict[chain_name] | output_parser
 
         async_runner = RunnableParallel(**chains_dict)
-        batch_results = async_runner.invoke({"question": user_prompt})
+        batch_results = await async_runner.ainvoke({"question": user_prompt})
 
         merge_prompt = PromptTemplate(
             template=template_merge_prompt,
@@ -121,7 +154,7 @@ def execute(self, state: dict) -> dict:
         merge_chain = merge_prompt | self.llm_model
         if output_parser:
             merge_chain = merge_chain | output_parser
-        answer = merge_chain.ainvoke({"context": batch_results, "question": user_prompt})
+        answer = await merge_chain.ainvoke({"context": batch_results, "question": user_prompt})
 
         state.update({self.output[0]: answer})
         return state
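
The remaining hunks add the missing awaits across the fan-out/merge path: the single-chunk chain, the parallel runner (whose blocking invoke becomes an awaited ainvoke), and the merge chain. A minimal sketch of that pattern with langchain-core's RunnableParallel follows; the RunnableLambda stand-ins below replace the real prompt | llm | output_parser chains and are illustrative only.

import asyncio
from langchain_core.runnables import RunnableLambda, RunnableParallel

# Stand-ins for the per-chunk `prompt | llm | output_parser` chains.
chains_dict = {
    "chunk0": RunnableLambda(lambda x: f"summary 0 for {x['question']}"),
    "chunk1": RunnableLambda(lambda x: f"summary 1 for {x['question']}"),
}

# Stand-in for the `merge_prompt | llm | output_parser` merge chain.
merge_chain = RunnableLambda(lambda x: f"merged {x['context']} -> {x['question']}")

async def answer(user_prompt: str) -> str:
    # Fan out: run every chunk chain and collect the results into a dict
    # keyed by chain name, which is the shape the merge step consumes.
    async_runner = RunnableParallel(**chains_dict)
    batch_results = await async_runner.ainvoke({"question": user_prompt})
    # Merge: a second awaited step combines the partial answers.
    return await merge_chain.ainvoke(
        {"context": batch_results, "question": user_prompt}
    )

print(asyncio.run(answer("What does the page say?")))

RunnableParallel.ainvoke returns the per-chain results as a dict, matching the batch_results the node feeds into its merge prompt.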
