 """
 GenerateAnswerNode Module
 """
-
+import asyncio
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
@@ -107,44 +107,43 @@ def execute(self, state: dict) -> dict:
             template_chunks_prompt = self.additional_info + template_chunks_prompt
             template_merge_prompt = self.additional_info + template_merge_prompt

-        chains_dict = {}
+        if len(doc) == 1:
+            prompt = PromptTemplate(
+                template=template_no_chunks_prompt,
+                input_variables=["question"],
+                partial_variables={"context": doc,
+                                   "format_instructions": format_instructions})
+            chain = prompt | self.llm_model | output_parser
+            answer = chain.invoke({"question": user_prompt})
+
+            state.update({self.output[0]: answer})
+            return state

-        # Use tqdm to add progress bar
+        chains_dict = {}
         for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
-            if len(doc) == 1:
-                prompt = PromptTemplate(
-                    template=template_no_chunks_prompt,
-                    input_variables=["question"],
-                    partial_variables={"context": chunk,
-                                       "format_instructions": format_instructions})
-                chain = prompt | self.llm_model | output_parser
-                answer = chain.invoke({"question": user_prompt})
-                break

             prompt = PromptTemplate(
-                template=template_chunks_prompt,
-                input_variables=["question"],
-                partial_variables={"context": chunk,
-                                   "chunk_id": i + 1,
-                                   "format_instructions": format_instructions})
-            # Dynamically name the chains based on their index
+                template=template_chunks_prompt,
+                input_variables=["question"],
+                partial_variables={"context": chunk,
+                                   "chunk_id": i + 1,
+                                   "format_instructions": format_instructions})
+            # Add chain to dictionary with dynamic name
             chain_name = f"chunk{i+1}"
             chains_dict[chain_name] = prompt | self.llm_model | output_parser

-        if len(chains_dict) > 1:
-            # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
-            map_chain = RunnableParallel(**chains_dict)
-            # Chain
-            answer = map_chain.invoke({"question": user_prompt})
-            # Merge the answers from the chunks
-            merge_prompt = PromptTemplate(
+        async_runner = RunnableParallel(**chains_dict)
+
+        batch_results = async_runner.invoke({"question": user_prompt})
+
+        merge_prompt = PromptTemplate(
             template = template_merge_prompt,
             input_variables=["context", "question"],
             partial_variables={"format_instructions": format_instructions},
         )
-            merge_chain = merge_prompt | self.llm_model | output_parser
-            answer = merge_chain.invoke({"context": answer, "question": user_prompt})

-        # Update the state with the generated answer
+        merge_chain = merge_prompt | self.llm_model | output_parser
+        answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
+
         state.update({self.output[0]: answer})
         return state
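
For reference: the refactor moves the single-document case out of the chunk loop into an early return, and always fans the per-chunk chains out through RunnableParallel before a final merge chain. Below is a minimal, self-contained sketch of that fan-out / merge pattern; the RunnableLambda stand-in and the toy templates are illustrative assumptions, not the node's real self.llm_model or template_* prompts.

from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel

# Stand-in "LLM": echoes part of the rendered prompt instead of calling a model.
fake_llm = RunnableLambda(lambda prompt_value: f"answer based on: {prompt_value.to_string()[:40]}")

chunks = ["first chunk of the page", "second chunk of the page"]

# One dynamically named chain per chunk, mirroring the loop in the diff.
chains_dict = {}
for i, chunk in enumerate(chunks):
    prompt = PromptTemplate(
        template="Chunk {chunk_id}: {context}\nQuestion: {question}",
        input_variables=["question"],
        partial_variables={"context": chunk, "chunk_id": i + 1},
    )
    chains_dict[f"chunk{i+1}"] = prompt | fake_llm

# RunnableParallel runs every chunk chain on the same input and returns a dict
# keyed by the dynamic names ("chunk1", "chunk2", ...).
async_runner = RunnableParallel(**chains_dict)
batch_results = async_runner.invoke({"question": "What is the page about?"})

# The per-chunk answers are then merged by a final chain, as in
# merge_prompt | self.llm_model | output_parser above.
merge_prompt = PromptTemplate(
    template="Partial answers: {context}\nQuestion: {question}",
    input_variables=["context", "question"],
)
merge_chain = merge_prompt | fake_llm
answer = merge_chain.invoke({"context": batch_results, "question": "What is the page about?"})
print(answer)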