@@ -1,3 +1,6 @@
+"""
+GenerateAnswerNode Module
+"""
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
@@ -15,6 +18,26 @@
 )
 
 class GenerateAnswerNode(BaseNode):
+    """
+    Initializes the GenerateAnswerNode class.
+
+    Args:
+        input (str): The input data type for the node.
+        output (List[str]): The output data type(s) for the node.
+        node_config (Optional[dict]): Configuration dictionary for the node,
+            which includes the LLM model, verbosity, schema, and other settings.
+            Defaults to None.
+        node_name (str): The name of the node. Defaults to "GenerateAnswer".
+
+    Attributes:
+        llm_model: The language model specified in the node configuration.
+        verbose (bool): Whether verbose mode is enabled.
+        force (bool): Whether to force certain behaviors, overriding defaults.
+        script_creator (bool): Whether the node is in script creation mode.
+        is_md_scraper (bool): Whether the node is scraping markdown data.
+        additional_info (Optional[str]): Any additional information to be
+            included in the prompt templates.
+    """
     def __init__(
         self,
         input: str,
@@ -34,7 +57,17 @@ def __init__(
         self.is_md_scraper = node_config.get("is_md_scraper", False)
         self.additional_info = node_config.get("additional_info")
 
-    def execute(self, state: dict) -> dict:
+    async def execute(self, state: dict) -> dict:
+        """
+        Executes the GenerateAnswerNode.
+
+        Args:
+            state (dict): The current state of the graph. The input keys will be used
+                to fetch the correct data from the state.
+
+        Returns:
+            dict: The updated state with the output key containing the generated answer.
+        """
         self.logger.info(f"--- Executing {self.node_name} Node ---")
 
         input_keys = self.get_input_keys(state)
@@ -90,7 +123,7 @@ def execute(self, state: dict) -> dict:
             chain = prompt | self.llm_model
             if output_parser:
                 chain = chain | output_parser
-            answer = chain.ainvoke({"question": user_prompt})
+            answer = await chain.ainvoke({"question": user_prompt})
 
             state.update({self.output[0]: answer})
             return state
@@ -110,7 +143,7 @@ def execute(self, state: dict) -> dict:
                 chains_dict[chain_name] = chains_dict[chain_name] | output_parser
 
         async_runner = RunnableParallel(**chains_dict)
-        batch_results = async_runner.invoke({"question": user_prompt})
+        batch_results = await async_runner.ainvoke({"question": user_prompt})
 
         merge_prompt = PromptTemplate(
             template=template_merge_prompt,
@@ -121,7 +154,7 @@ def execute(self, state: dict) -> dict:
         merge_chain = merge_prompt | self.llm_model
         if output_parser:
             merge_chain = merge_chain | output_parser
-        answer = merge_chain.ainvoke({"context": batch_results, "question": user_prompt})
+        answer = await merge_chain.ainvoke({"context": batch_results, "question": user_prompt})
 
         state.update({self.output[0]: answer})
         return state
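
The functional change in this diff is the sync-to-async conversion: `execute` becomes a coroutine and every `ainvoke` call is awaited; previously the un-awaited `ainvoke`/`invoke` calls left coroutine objects (or blocked synchronously) where the answer should be. Below is a minimal, self-contained sketch of that pattern, using a `RunnableLambda` as a stand-in for the node's LLM so it runs without an API key. The names `fake_llm` and `main` are illustrative, not part of this PR.

```python
import asyncio

from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel

# Stand-in for self.llm_model (illustrative only; the real node receives
# its model via node_config["llm_model"]).
fake_llm = RunnableLambda(lambda prompt_value: f"answer to: {prompt_value.to_string()}")

async def main() -> None:
    prompt = PromptTemplate(
        template="Question: {question}", input_variables=["question"]
    )
    chain = prompt | fake_llm

    # Single-chunk path: without `await`, `answer` would be a coroutine
    # object rather than the model output -- the bug this diff fixes.
    answer = await chain.ainvoke({"question": "What does this node do?"})
    print(answer)

    # Multi-chunk path: RunnableParallel exposes the same async interface,
    # so .invoke(...) likewise becomes `await ...ainvoke(...)`.
    async_runner = RunnableParallel(chunk0=chain, chunk1=chain)
    batch_results = await async_runner.ainvoke({"question": "What does this node do?"})
    print(batch_results)

asyncio.run(main())
```

Note that callers of the node must change too: once `execute` is `async def`, the graph runner has to await it (or schedule it on an event loop) rather than call it directly.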