Skip to content

Fixed pydantic errors when using with_structured_output #626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions scrapegraphai/nodes/generate_answer_csv_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from tqdm import tqdm
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..prompts.generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV,
TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV)
from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV

class GenerateAnswerCSVNode(BaseNode):
"""
Expand Down Expand Up @@ -92,9 +94,24 @@ def execute(self, state):

# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="function_calling") # json schema works only on specific models

# default parser to empty lambda function
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()

else:
output_parser = JsonOutputParser()
format_instructions = output_parser.get_format_instructions()

TEMPLATE_NO_CHUKS_CSV_PROMPT = TEMPLATE_NO_CHUKS_CSV
TEMPLATE_CHUKS_CSV_PROMPT = TEMPLATE_CHUKS_CSV
Expand All @@ -105,8 +122,6 @@ def execute(self, state):
TEMPLATE_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_CHUKS_CSV
TEMPLATE_MERGE_CSV_PROMPT = self.additional_info + TEMPLATE_MERGE_CSV

format_instructions = output_parser.get_format_instructions()

chains_dict = {}

if len(doc) == 1:
Expand Down
15 changes: 10 additions & 5 deletions scrapegraphai/nodes/generate_answer_node.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
"""
GenerateAnswerNode Module
"""
from sys import modules
from typing import List, Optional
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_mistralai import ChatMistralAI
from langchain_community.chat_models import ChatOllama
from tqdm import tqdm
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..prompts import TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD

Expand Down Expand Up @@ -91,14 +90,20 @@ def execute(self, state: dict) -> dict:
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="json_schema")
method="function_calling") # json schema works only on specific models

# default parser to empty lambda function
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()

else:
output_parser = JsonOutputParser()

format_instructions = output_parser.get_format_instructions()
format_instructions = output_parser.get_format_instructions()

if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
Expand Down
22 changes: 20 additions & 2 deletions scrapegraphai/nodes/generate_answer_omni_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from tqdm import tqdm
from langchain_community.chat_models import ChatOllama
from .base_node import BaseNode
Expand Down Expand Up @@ -78,9 +81,25 @@ def execute(self, state: dict) -> dict:

# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="function_calling") # json schema works only on specific models

# default parser to empty lambda function
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()

else:
output_parser = JsonOutputParser()
format_instructions = output_parser.get_format_instructions()

TEMPLATE_NO_CHUNKS_OMNI_prompt = TEMPLATE_NO_CHUNKS_OMNI
TEMPLATE_CHUNKS_OMNI_prompt = TEMPLATE_CHUNKS_OMNI
TEMPLATE_MERGE_OMNI_prompt= TEMPLATE_MERGE_OMNI
Expand All @@ -90,7 +109,6 @@ def execute(self, state: dict) -> dict:
TEMPLATE_CHUNKS_OMNI_prompt = self.additional_info + TEMPLATE_CHUNKS_OMNI_prompt
TEMPLATE_MERGE_OMNI_prompt = self.additional_info + TEMPLATE_MERGE_OMNI_prompt

format_instructions = output_parser.get_format_instructions()


chains_dict = {}
Expand Down
23 changes: 20 additions & 3 deletions scrapegraphai/nodes/generate_answer_pdf_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from tqdm import tqdm
from langchain_community.chat_models import ChatOllama
from ..utils.logging import get_logger
Expand Down Expand Up @@ -93,9 +96,25 @@ def execute(self, state):

# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="function_calling") # json schema works only on specific models

# default parser to empty lambda function
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()

else:
output_parser = JsonOutputParser()
format_instructions = output_parser.get_format_instructions()

TEMPLATE_NO_CHUNKS_PDF_prompt = TEMPLATE_NO_CHUNKS_PDF
TEMPLATE_CHUNKS_PDF_prompt = TEMPLATE_CHUNKS_PDF
TEMPLATE_MERGE_PDF_prompt = TEMPLATE_MERGE_PDF
Expand All @@ -105,8 +124,6 @@ def execute(self, state):
TEMPLATE_CHUNKS_PDF_prompt = self.additional_info + TEMPLATE_CHUNKS_PDF_prompt
TEMPLATE_MERGE_PDF_prompt = self.additional_info + TEMPLATE_MERGE_PDF_prompt

format_instructions = output_parser.get_format_instructions()

if len(doc) == 1:
prompt = PromptTemplate(
template=TEMPLATE_NO_CHUNKS_PDF_prompt,
Expand Down
22 changes: 19 additions & 3 deletions scrapegraphai/nodes/merge_answers_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from typing import List, Optional
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI
from ..utils.logging import get_logger
from .base_node import BaseNode
from ..prompts import TEMPLATE_COMBINED
Expand Down Expand Up @@ -68,11 +71,24 @@ def execute(self, state: dict) -> dict:
answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n"

if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="function_calling") # json schema works only on specific models

# default parser to empty lambda function
output_parser = lambda x: x
if is_basemodel_subclass(self.node_config["schema"]):
output_parser = dict
format_instructions = "NA"
else:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
format_instructions = output_parser.get_format_instructions()

else:
output_parser = JsonOutputParser()

format_instructions = output_parser.get_format_instructions()
format_instructions = output_parser.get_format_instructions()

prompt_template = PromptTemplate(
template=TEMPLATE_COMBINED,
Expand Down