Commit d745564

Merge pull request #485 from ScrapeGraphAI/generate_answer_parallel
merge: generate answer parallel
2 parents 66f9421 + 2edad66 commit d745564

File tree

3 files changed: +32 -36 lines changed

pyproject.toml
requirements.txt
scrapegraphai/nodes/generate_answer_node.py

pyproject.toml

Lines changed: 3 additions & 5 deletions

@@ -3,19 +3,15 @@ name = "scrapegraphai"
 
 version = "1.11.0b3"
 
-
-
 description = "A web scraping library based on LangChain which uses LLM and direct graph logic to create scraping pipelines."
 authors = [
     { name = "Marco Vinciguerra", email = "[email protected]" },
     { name = "Marco Perini", email = "[email protected]" },
     { name = "Lorenzo Padoan", email = "[email protected]" }
 ]
+
 dependencies = [
     "langchain>=0.2.10",
-
-    "langchain-fireworks>=0.1.3",
-    "langchain_community>=0.2.9",
     "langchain-google-genai>=1.0.7",
     "langchain-google-vertexai",
     "langchain-openai>=0.1.17",
@@ -37,6 +33,8 @@ dependencies = [
     "google>=3.0.0",
     "undetected-playwright>=0.3.0",
     "semchunk>=1.0.1",
+    "langchain-fireworks>=0.1.3",
+    "langchain-community>=0.2.9"
 ]
 
 license = "MIT"

requirements.txt

Lines changed: 2 additions & 3 deletions

@@ -1,7 +1,5 @@
 langchain>=0.2.10
-langchain_community>=0.2.9
 langchain-google-genai>=1.0.7
-langchain-fireworks>=0.1.3
 langchain-google-vertexai
 langchain-openai>=0.1.17
 langchain-groq>=0.1.3
@@ -22,4 +20,5 @@ playwright>=1.43.0
 google>=3.0.0
 undetected-playwright>=0.3.0
 semchunk>=1.0.1
-
+langchain-fireworks>=0.1.3
+langchain-community>=0.2.9

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 27 additions & 28 deletions

@@ -1,7 +1,7 @@
 """
 GenerateAnswerNode Module
 """
-
+import asyncio
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
@@ -107,44 +107,43 @@ def execute(self, state: dict) -> dict:
             template_chunks_prompt = self.additional_info + template_chunks_prompt
             template_merge_prompt = self.additional_info + template_merge_prompt
 
-        chains_dict = {}
+        if len(doc) == 1:
+            prompt = PromptTemplate(
+                template=template_no_chunks_prompt,
+                input_variables=["question"],
+                partial_variables={"context": doc,
+                                   "format_instructions": format_instructions})
+            chain = prompt | self.llm_model | output_parser
+            answer = chain.invoke({"question": user_prompt})
+
+            state.update({self.output[0]: answer})
+            return state
 
-        # Use tqdm to add progress bar
+        chains_dict = {}
         for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
-            if len(doc) == 1:
-                prompt = PromptTemplate(
-                    template=template_no_chunks_prompt,
-                    input_variables=["question"],
-                    partial_variables={"context": chunk,
-                                       "format_instructions": format_instructions})
-                chain = prompt | self.llm_model | output_parser
-                answer = chain.invoke({"question": user_prompt})
-                break
 
             prompt = PromptTemplate(
-                template=template_chunks_prompt,
-                input_variables=["question"],
-                partial_variables={"context": chunk,
-                                   "chunk_id": i + 1,
-                                   "format_instructions": format_instructions})
-            # Dynamically name the chains based on their index
+                template=template_chunks,
+                input_variables=["question"],
+                partial_variables={"context": chunk,
+                                   "chunk_id": i + 1,
+                                   "format_instructions": format_instructions})
+            # Add chain to dictionary with dynamic name
             chain_name = f"chunk{i+1}"
             chains_dict[chain_name] = prompt | self.llm_model | output_parser
 
-        if len(chains_dict) > 1:
-            # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
-            map_chain = RunnableParallel(**chains_dict)
-            # Chain
-            answer = map_chain.invoke({"question": user_prompt})
-            # Merge the answers from the chunks
-            merge_prompt = PromptTemplate(
+        async_runner = RunnableParallel(**chains_dict)
+
+        batch_results = async_runner.invoke({"question": user_prompt})
+
+        merge_prompt = PromptTemplate(
             template = template_merge_prompt,
             input_variables=["context", "question"],
             partial_variables={"format_instructions": format_instructions},
         )
-            merge_chain = merge_prompt | self.llm_model | output_parser
-            answer = merge_chain.invoke({"context": answer, "question": user_prompt})
 
-        # Update the state with the generated answer
+        merge_chain = merge_prompt | self.llm_model | output_parser
+        answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
+
         state.update({self.output[0]: answer})
         return state
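
The substantive change in this file: the single-chunk case is now answered and returned up front, and the old `len(chains_dict) > 1` guard is gone, so the per-chunk chains are always fanned out through RunnableParallel and their answers merged by one final chain. Below is a minimal, self-contained sketch of that fan-out/merge pattern; it is not the committed code: RunnableLambda stands in for the real `prompt | self.llm_model | output_parser` chains, and the chunk and question strings are invented for illustration.

# Sketch of the fan-out pattern used by the new execute() path.
from langchain_core.runnables import RunnableLambda, RunnableParallel

chunks = ["chunk one text", "chunk two text", "chunk three text"]

# One named runnable per chunk, mirroring the chains_dict loop in the diff.
# The default argument c=chunk freezes each chunk at its loop iteration.
chains_dict = {
    f"chunk{i+1}": RunnableLambda(
        lambda inputs, c=chunk: f"answer about {c!r} to {inputs['question']!r}"
    )
    for i, chunk in enumerate(chunks)
}

# RunnableParallel invokes every named runnable on the same input and
# returns one dict keyed by chain name: {"chunk1": ..., "chunk2": ...}.
async_runner = RunnableParallel(**chains_dict)
batch_results = async_runner.invoke({"question": "What is on the page?"})
print(batch_results)

Because `batch_results` is a single dict of per-chunk answers, the merge chain in the diff can pass it straight in as the `context` variable of `template_merge_prompt`.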
