Generate answer parallel #485

Merged: 6 commits, Jul 25, 2024

8 changes: 3 additions & 5 deletions pyproject.toml
@@ -3,19 +3,15 @@ name = "scrapegraphai"
 
 version = "1.11.0b1"
-
-
 description = "A web scraping library based on LangChain which uses LLM and direct graph logic to create scraping pipelines."
 authors = [
     { name = "Marco Vinciguerra", email = "[email protected]" },
     { name = "Marco Perini", email = "[email protected]" },
     { name = "Lorenzo Padoan", email = "[email protected]" }
 ]
 
 dependencies = [
     "langchain>=0.2.10",
-
-    "langchain-fireworks>=0.1.3",
-    "langchain_community>=0.2.9",
     "langchain-google-genai>=1.0.7",
     "langchain-google-vertexai",
     "langchain-openai>=0.1.17",
@@ -37,6 +33,8 @@ dependencies = [
     "google>=3.0.0",
     "undetected-playwright>=0.3.0",
     "semchunk>=1.0.1",
+    "langchain-fireworks>=0.1.3",
+    "langchain-community>=0.2.9"
 ]
 
 license = "MIT"

5 changes: 2 additions & 3 deletions requirements.txt
@@ -1,7 +1,5 @@
 langchain>=0.2.10
-langchain_community>=0.2.9
 langchain-google-genai>=1.0.7
-langchain-fireworks>=0.1.3
 langchain-google-vertexai
 langchain-openai>=0.1.17
 langchain-groq>=0.1.3
@@ -22,4 +20,5 @@ playwright>=1.43.0
 google>=3.0.0
 undetected-playwright>=0.3.0
 semchunk>=1.0.1
-
+langchain-fireworks>=0.1.3
+langchain-community>=0.2.9

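The same two packages are moved to the end of the dependency list in both pyproject.toml and requirements.txt (note that the underscore spelling langchain_community is normalized to langchain-community along the way), so the two manifests have to be kept in sync by hand. A hypothetical consistency check, not part of this repository, assuming a PEP 621 [project] table and Python 3.11+ for tomllib:

# Hypothetical helper (not in this repo): flag packages that appear in
# only one of the two dependency manifests.
import re
import tomllib

def dep_names(specs):
    # "langchain-fireworks>=0.1.3" -> "langchain-fireworks"; fold '_' and '-'
    # together, PEP 503 style, so langchain_community == langchain-community.
    names = set()
    for spec in specs:
        spec = spec.strip()
        if not spec or spec.startswith("#"):
            continue
        name = re.split(r"[<>=!\[;@ ]", spec, maxsplit=1)[0]
        names.add(name.lower().replace("_", "-"))
    return names

with open("pyproject.toml", "rb") as f:
    pyproject_deps = dep_names(tomllib.load(f)["project"]["dependencies"])

with open("requirements.txt") as f:
    requirements_deps = dep_names(f.read().splitlines())

if pyproject_deps != requirements_deps:
    print("only in pyproject.toml:", pyproject_deps - requirements_deps)
    print("only in requirements.txt:", requirements_deps - pyproject_deps)
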
55 changes: 27 additions & 28 deletions scrapegraphai/nodes/generate_answer_node.py
@@ -1,7 +1,7 @@
 """
 GenerateAnswerNode Module
 """
-
+import asyncio
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
@@ -107,44 +107,43 @@ def execute(self, state: dict) -> dict:
             template_chunks_prompt = self.additional_info + template_chunks_prompt
             template_merge_prompt = self.additional_info + template_merge_prompt
 
-        # Use tqdm to add progress bar
-        chains_dict = {}
+        chains_dict = {}
+        if len(doc) == 1:
+            prompt = PromptTemplate(
+                template=template_no_chunks_prompt,
+                input_variables=["question"],
+                partial_variables={"context": doc,
+                                   "format_instructions": format_instructions})
+            chain = prompt | self.llm_model | output_parser
+            answer = chain.invoke({"question": user_prompt})
+
+            state.update({self.output[0]: answer})
+            return state
+
         for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
-            if len(doc) == 1:
-                prompt = PromptTemplate(
-                    template=template_no_chunks_prompt,
-                    input_variables=["question"],
-                    partial_variables={"context": chunk,
-                                       "format_instructions": format_instructions})
-                chain = prompt | self.llm_model | output_parser
-                answer = chain.invoke({"question": user_prompt})
-                break
-
             prompt = PromptTemplate(
-                template=template_chunks,
+                template=template_chunks_prompt,
                 input_variables=["question"],
                 partial_variables={"context": chunk,
                                    "chunk_id": i + 1,
                                    "format_instructions": format_instructions})
-            # Dynamically name the chains based on their index
+            # Add chain to dictionary with dynamic name
             chain_name = f"chunk{i+1}"
             chains_dict[chain_name] = prompt | self.llm_model | output_parser
 
-        if len(chains_dict) > 1:
-            # Use dictionary unpacking to pass the dynamically named chains to RunnableParallel
-            map_chain = RunnableParallel(**chains_dict)
-            # Chain
-            answer = map_chain.invoke({"question": user_prompt})
-            # Merge the answers from the chunks
-            merge_prompt = PromptTemplate(
-                template = template_merge_prompt,
-                input_variables=["context", "question"],
-                partial_variables={"format_instructions": format_instructions},
-            )
-            merge_chain = merge_prompt | self.llm_model | output_parser
-            answer = merge_chain.invoke({"context": answer, "question": user_prompt})
+        async_runner = RunnableParallel(**chains_dict)
+
+        batch_results = async_runner.invoke({"question": user_prompt})
+
+        merge_prompt = PromptTemplate(
+            template = template_merge_prompt,
+            input_variables=["context", "question"],
+            partial_variables={"format_instructions": format_instructions},
+        )
 
-        # Update the state with the generated answer
+        merge_chain = merge_prompt | self.llm_model | output_parser
+        answer = merge_chain.invoke({"context": batch_results, "question": user_prompt})
+
         state.update({self.output[0]: answer})
         return state
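
The substance of the PR is in this last hunk: the single-document case now returns early instead of being re-checked on every loop iteration, the per-chunk chains are collected into chains_dict, and RunnableParallel fans them out in a single invoke whose dict of results feeds the merge chain. Below is a minimal, self-contained sketch of that fan-out/merge pattern, assuming only langchain-core is installed; chunks, fake_llm, question, and the prompt strings are illustrative stand-ins, not code from this repository.

# Sketch of the fan-out/merge pattern this PR adopts. fake_llm stands in
# for self.llm_model | output_parser so the example runs without API keys.
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableParallel

chunks = ["First half of the page text.", "Second half of the page text."]
question = "What does the page say?"

# Stand-in "model": echoes a truncated view of the prompt it receives.
fake_llm = RunnableLambda(lambda prompt: f"answer from: {prompt.to_string()[:40]}...")

# One chain per chunk, keyed chunk1, chunk2, ... like chains_dict in the node.
chains_dict = {}
for i, chunk in enumerate(chunks):
    prompt = PromptTemplate(
        template="Chunk {chunk_id}: {context}\nQuestion: {question}",
        input_variables=["question"],
        partial_variables={"context": chunk, "chunk_id": i + 1},
    )
    chains_dict[f"chunk{i+1}"] = prompt | fake_llm

# Fan out: every chain receives the same input; the result is a dict keyed
# by chain name, e.g. {"chunk1": "...", "chunk2": "..."}.
async_runner = RunnableParallel(**chains_dict)
batch_results = async_runner.invoke({"question": question})

# Merge: the dict of per-chunk answers becomes the context of one last chain.
merge_prompt = PromptTemplate(
    template="Combine these partial answers:\n{context}\nQuestion: {question}",
    input_variables=["context", "question"],
)
merge_chain = merge_prompt | fake_llm
print(merge_chain.invoke({"context": batch_results, "question": question}))

On invoke(), RunnableParallel runs its branches concurrently (a thread pool on the sync path, asyncio on ainvoke()), which is where the speedup over invoking each chunk chain sequentially comes from.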