Skip to content

feat: Concat node implementation #632

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions examples/google_genai/smart_scraper_multi_concat_gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
Basic example of scraping pipeline using SmartScraper
"""

import os
import json
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperMultiConcatGraph

load_dotenv()

# ************************************************
# Define the configuration for the graph
# ************************************************

gemini_key = os.getenv("GOOGLE_APIKEY")

graph_config = {
"llm": {
"api_key": gemini_key,
"model": "google_genai/gemini-pro",
},
}

# *******************************************************
# Create the SmartScraperMultiGraph instance and run it
# *******************************************************

multiple_search_graph = SmartScraperMultiConcatGraph(
prompt="Who is Marco Perini?",
source= [
"https://perinim.github.io/",
"https://perinim.github.io/cv/"
],
schema=None,
config=graph_config
)

result = multiple_search_graph.run()
print(json.dumps(result, indent=4))
42 changes: 42 additions & 0 deletions examples/local_models/smart_scraper_multi_concat_ollama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
Basic example of scraping pipeline using SmartScraper
"""

import os
import json
from dotenv import load_dotenv
from scrapegraphai.graphs import SmartScraperMultiConcatGraph

load_dotenv()

# ************************************************
# Define the configuration for the graph
# ************************************************

graph_config = {
"llm": {
"model": "ollama/llama3.1",
"temperature": 0,
"format": "json", # Ollama needs the format to be specified explicitly
"base_url": "http://localhost:11434", # set ollama URL arbitrarily
},
"verbose": True,
"headless": False
}

# *******************************************************
# Create the SmartScraperMultiGraph instance and run it
# *******************************************************

multiple_search_graph = SmartScraperMultiConcatGraph(
prompt="Who is Marco Perini?",
source= [
"https://perinim.github.io/",
"https://perinim.github.io/cv/"
],
schema=None,
config=graph_config
)

result = multiple_search_graph.run()
print(json.dumps(result, indent=4))
1 change: 1 addition & 0 deletions scrapegraphai/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@
from .markdown_scraper_multi_graph import MDScraperMultiGraph
from .search_link_graph import SearchLinkGraph
from .screenshot_scraper_graph import ScreenshotScraperGraph
from .smart_scraper_multi_concat_graph import SmartScraperMultiConcatGraph
115 changes: 115 additions & 0 deletions scrapegraphai/graphs/smart_scraper_multi_concat_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
SmartScraperMultiGraph Module
"""

from copy import copy, deepcopy
from typing import List, Optional
from pydantic import BaseModel

from .base_graph import BaseGraph
from .abstract_graph import AbstractGraph
from .smart_scraper_graph import SmartScraperGraph

from ..nodes import (
GraphIteratorNode,
ConcatAnswersNode
)


class SmartScraperMultiConcatGraph(AbstractGraph):
    """
    SmartScraperMultiConcatGraph is a scraping pipeline that scrapes a list of URLs
    and concatenates the per-URL answers to a given prompt into a single result.
    It only requires a user prompt and a list of URLs.

    Args:
        prompt (str): The user prompt to search the internet.
        source (List[str]): The list of URLs to scrape.
        config (dict): Configuration parameters for the graph.
        schema (Optional[BaseModel]): The schema for the graph output.

    Example:
        >>> search_graph = SmartScraperMultiConcatGraph(
        ...     "What is Chioggia famous for?",
        ...     ["https://it.wikipedia.org/wiki/Chioggia"],
        ...     {"llm": {"model": "gpt-3.5-turbo"}}
        ... )
        >>> result = search_graph.run()
    """

    def __init__(self, prompt: str, source: List[str], config: dict, schema: Optional[BaseModel] = None):
        # A config made only of string values holds no nested mutable state,
        # so a shallow copy is enough; otherwise deep-copy defensively so the
        # per-URL scraper instances cannot share mutable config objects.
        flat_config = all(isinstance(value, str) for value in config.values())
        self.copy_config = copy(config) if flat_config else deepcopy(config)
        self.copy_schema = deepcopy(schema)

        super().__init__(prompt, config, source, schema)

    def _create_graph(self) -> BaseGraph:
        """
        Create the node graph: a GraphIteratorNode that runs a SmartScraperGraph
        instance per URL, feeding all results into a ConcatAnswersNode.

        Returns:
            BaseGraph: A graph instance representing the scraping-and-concatenation
            workflow.
        """
        # Template graph that the iterator node clones/executes for each URL;
        # prompt and source are filled in per iteration.
        scraper_template = SmartScraperGraph(
            prompt="",
            source="",
            config=self.copy_config,
            schema=self.copy_schema,
        )

        iterator_node = GraphIteratorNode(
            input="user_prompt & urls",
            output=["results"],
            node_config={"graph_instance": scraper_template},
        )

        concat_node = ConcatAnswersNode(
            input="results",
            output=["answer"],
        )

        return BaseGraph(
            nodes=[iterator_node, concat_node],
            edges=[(iterator_node, concat_node)],
            entry_point=iterator_node,
            graph_name=self.__class__.__name__,
        )

    def run(self) -> str:
        """
        Execute the web scraping and concatenation process.

        Returns:
            str: The answer to the prompt, or "No answer found." if the
            final state contains no answer.
        """
        initial_state = {"user_prompt": self.prompt, "urls": self.source}
        self.final_state, self.execution_info = self.graph.execute(initial_state)

        return self.final_state.get("answer", "No answer found.")
1 change: 1 addition & 0 deletions scrapegraphai/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
from .merge_generated_scripts import MergeGeneratedScriptsNode
from .fetch_screen_node import FetchScreenNode
from .generate_answer_from_image_node import GenerateAnswerFromImageNode
from .concat_answers_node import ConcatAnswersNode
76 changes: 76 additions & 0 deletions scrapegraphai/nodes/concat_answers_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
ConcatAnswersNode Module
"""

from typing import List, Optional
from ..utils.logging import get_logger
from .base_node import BaseNode

class ConcatAnswersNode(BaseNode):
    """
    A node responsible for concatenating the answers from multiple graph instances
    into a single answer.

    Attributes:
        verbose (bool): A flag indicating whether to show print statements during execution.

    Args:
        input (str): Boolean expression defining the input keys needed from the state.
        output (List[str]): List of output keys to be updated in the state.
        node_config (dict): Additional configuration for the node.
        node_name (str): The unique identifier name for the node, defaulting to "ConcatAnswers".
    """

    def __init__(
        self,
        input: str,
        output: List[str],
        node_config: Optional[dict] = None,
        node_name: str = "ConcatAnswers",
    ):
        super().__init__(node_name, "node", input, output, 1, node_config)

        self.verbose = (
            False if node_config is None else node_config.get("verbose", False)
        )

    def _merge_dict(self, items: List) -> dict:
        """
        Wrap each answer under a numbered ``item_N`` key inside a ``products`` dict.

        Args:
            items (List): The individual answers, in iteration order.

        Returns:
            dict: ``{"products": {"item_1": ..., "item_2": ...}}``
        """
        return {"products": {f"item_{i+1}": item for i, item in enumerate(items)}}

    def execute(self, state: dict) -> dict:
        """
        Executes the node's logic to concatenate the answers from multiple graph
        instances into a single answer.

        Args:
            state (dict): The current state of the graph. The input keys will be used
                          to fetch the correct data from the state.

        Returns:
            dict: The updated state with the output key containing the generated answer.

        Raises:
            KeyError: If the input keys are not found in the state, indicating
                      that the necessary information for generating an answer is missing.
            ValueError: If the fetched answers collection is empty.
        """

        self.logger.info(f"--- Executing {self.node_name} Node ---")

        # Interpret input keys based on the provided input expression
        input_keys = self.get_input_keys(state)

        # Fetching data from the state based on the input keys
        input_data = [state[key] for key in input_keys]

        answers = input_data[0]

        if not answers:
            # Fail fast with a clear message instead of an opaque IndexError
            # from answers[0] below.
            raise ValueError("ConcatAnswersNode received an empty list of answers.")

        if len(answers) > 1:
            # Merge the answers into a single keyed dictionary
            answer = self._merge_dict(answers)
        else:
            # A single answer passes through unchanged
            answer = answers[0]

        # Update the state with the generated answer
        state.update({self.output[0]: answer})
        return state