
Commit 8e74ac5

fix: correctly parsing output when using structured_output
1 parent 5e99071 commit 8e74ac5

File tree: 5 files changed (+89, -18 lines)


scrapegraphai/nodes/generate_answer_csv_node.py

Lines changed: 20 additions & 5 deletions
@@ -6,11 +6,13 @@
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.runnables import RunnableParallel
+from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_openai import ChatOpenAI
+from langchain_mistralai import ChatMistralAI
 from tqdm import tqdm
 from ..utils.logging import get_logger
 from .base_node import BaseNode
-from ..prompts.generate_answer_node_csv_prompts import (TEMPLATE_CHUKS_CSV,
-                                                        TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV)
+from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
 
 class GenerateAnswerCSVNode(BaseNode):
     """
@@ -92,9 +94,24 @@ def execute(self, state):
 
         # Initialize the output parser
         if self.node_config.get("schema", None) is not None:
-            output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
+
+            if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
+                self.llm_model = self.llm_model.with_structured_output(
+                    schema = self.node_config["schema"],
+                    method="function_calling") # json schema works only on specific models
+
+            # default parser to empty lambda function
+            output_parser = lambda x: x
+            if is_basemodel_subclass(self.node_config["schema"]):
+                output_parser = dict
+                format_instructions = "NA"
+            else:
+                output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
+                format_instructions = output_parser.get_format_instructions()
+
         else:
             output_parser = JsonOutputParser()
+            format_instructions = output_parser.get_format_instructions()
 
         TEMPLATE_NO_CHUKS_CSV_PROMPT = TEMPLATE_NO_CHUKS_CSV
         TEMPLATE_CHUKS_CSV_PROMPT = TEMPLATE_CHUKS_CSV
@@ -105,8 +122,6 @@ def execute(self, state):
             TEMPLATE_CHUKS_CSV_PROMPT = self.additional_info + TEMPLATE_CHUKS_CSV
             TEMPLATE_MERGE_CSV_PROMPT = self.additional_info + TEMPLATE_MERGE_CSV
 
-        format_instructions = output_parser.get_format_instructions()
-
         chains_dict = {}
 
         if len(doc) == 1:
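
The parser-selection logic introduced above is repeated almost verbatim in the other nodes touched by this commit. As a distilled, standalone sketch (the helper name configure_structured_output is invented for illustration; only the branching mirrors the diff):

# Sketch of the pattern this commit applies in every generate/merge node.
# The helper name is invented; the branching mirrors the diffs in this commit.
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.utils.pydantic import is_basemodel_subclass
from langchain_openai import ChatOpenAI
from langchain_mistralai import ChatMistralAI

def configure_structured_output(llm_model, schema):
    """Return (llm_model, output_parser, format_instructions) for a node."""
    if schema is None:
        parser = JsonOutputParser()
        return llm_model, parser, parser.get_format_instructions()

    if isinstance(llm_model, (ChatOpenAI, ChatMistralAI)):
        # Bind the schema through tool/function calling; the json_schema
        # method is only available on a subset of models.
        llm_model = llm_model.with_structured_output(
            schema=schema, method="function_calling")

    if is_basemodel_subclass(schema):
        # The structured-output runnable already returns a Pydantic instance,
        # so the "parser" only needs to turn it into a plain dict.
        return llm_model, dict, "NA"

    parser = JsonOutputParser(pydantic_object=schema)
    return llm_model, parser, parser.get_format_instructions()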

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 10 additions & 5 deletions
@@ -1,16 +1,15 @@
 """
 GenerateAnswerNode Module
 """
-from sys import modules
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.runnables import RunnableParallel
+from langchain_core.utils.pydantic import is_basemodel_subclass
 from langchain_openai import ChatOpenAI, AzureChatOpenAI
 from langchain_mistralai import ChatMistralAI
 from langchain_community.chat_models import ChatOllama
 from tqdm import tqdm
-from ..utils.logging import get_logger
 from .base_node import BaseNode
 from ..prompts import TEMPLATE_CHUNKS, TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE, TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD, TEMPLATE_MERGE_MD
 
@@ -91,14 +90,20 @@ def execute(self, state: dict) -> dict:
             if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
                 self.llm_model = self.llm_model.with_structured_output(
                     schema = self.node_config["schema"],
-                    method="json_schema")
+                    method="function_calling") # json schema works only on specific models
+
+            # default parser to empty lambda function
+            output_parser = lambda x: x
+            if is_basemodel_subclass(self.node_config["schema"]):
+                output_parser = dict
+                format_instructions = "NA"
             else:
                 output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
+                format_instructions = output_parser.get_format_instructions()
 
         else:
             output_parser = JsonOutputParser()
-
-        format_instructions = output_parser.get_format_instructions()
+            format_instructions = output_parser.get_format_instructions()
 
         if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
             template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
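
Why output_parser = dict is sufficient here (background, not part of the diff): when the schema is a Pydantic model and the chat model has been wrapped with with_structured_output(..., method="function_calling"), invoking it returns an instance of that model rather than raw text, and dict() on a Pydantic model yields its top-level fields as a plain dictionary; format_instructions is set to "NA" because the prompt no longer needs JSON formatting hints. A hypothetical example (schema and model name are invented):

# Hypothetical illustration only; the schema and model name are invented.
from pydantic import BaseModel
from langchain_openai import ChatOpenAI

class Product(BaseModel):
    name: str
    price: float

llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(
    schema=Product, method="function_calling")

result = llm.invoke("Extract the product: Acme widget, 9.99 USD")
print(type(result))  # <class '__main__.Product'>
print(dict(result))  # roughly {'name': 'Acme widget', 'price': 9.99}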

scrapegraphai/nodes/generate_answer_omni_node.py

Lines changed: 20 additions & 2 deletions
@@ -5,6 +5,9 @@
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.runnables import RunnableParallel
+from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_openai import ChatOpenAI
+from langchain_mistralai import ChatMistralAI
 from tqdm import tqdm
 from langchain_community.chat_models import ChatOllama
 from .base_node import BaseNode
@@ -78,9 +81,25 @@ def execute(self, state: dict) -> dict:
 
         # Initialize the output parser
         if self.node_config.get("schema", None) is not None:
-            output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
+
+            if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
+                self.llm_model = self.llm_model.with_structured_output(
+                    schema = self.node_config["schema"],
+                    method="function_calling") # json schema works only on specific models
+
+            # default parser to empty lambda function
+            output_parser = lambda x: x
+            if is_basemodel_subclass(self.node_config["schema"]):
+                output_parser = dict
+                format_instructions = "NA"
+            else:
+                output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
+                format_instructions = output_parser.get_format_instructions()
+
         else:
             output_parser = JsonOutputParser()
+            format_instructions = output_parser.get_format_instructions()
+
         TEMPLATE_NO_CHUNKS_OMNI_prompt = TEMPLATE_NO_CHUNKS_OMNI
         TEMPLATE_CHUNKS_OMNI_prompt = TEMPLATE_CHUNKS_OMNI
         TEMPLATE_MERGE_OMNI_prompt= TEMPLATE_MERGE_OMNI
@@ -90,7 +109,6 @@ def execute(self, state: dict) -> dict:
             TEMPLATE_CHUNKS_OMNI_prompt = self.additional_info + TEMPLATE_CHUNKS_OMNI_prompt
             TEMPLATE_MERGE_OMNI_prompt = self.additional_info + TEMPLATE_MERGE_OMNI_prompt
 
-        format_instructions = output_parser.get_format_instructions()
 
 
         chains_dict = {}

scrapegraphai/nodes/generate_answer_pdf_node.py

Lines changed: 20 additions & 3 deletions
@@ -5,6 +5,9 @@
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.runnables import RunnableParallel
+from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_openai import ChatOpenAI
+from langchain_mistralai import ChatMistralAI
 from tqdm import tqdm
 from langchain_community.chat_models import ChatOllama
 from ..utils.logging import get_logger
@@ -93,9 +96,25 @@ def execute(self, state):
 
         # Initialize the output parser
         if self.node_config.get("schema", None) is not None:
-            output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
+
+            if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
+                self.llm_model = self.llm_model.with_structured_output(
+                    schema = self.node_config["schema"],
+                    method="function_calling") # json schema works only on specific models
+
+            # default parser to empty lambda function
+            output_parser = lambda x: x
+            if is_basemodel_subclass(self.node_config["schema"]):
+                output_parser = dict
+                format_instructions = "NA"
+            else:
+                output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
+                format_instructions = output_parser.get_format_instructions()
+
         else:
             output_parser = JsonOutputParser()
+            format_instructions = output_parser.get_format_instructions()
+
         TEMPLATE_NO_CHUNKS_PDF_prompt = TEMPLATE_NO_CHUNKS_PDF
         TEMPLATE_CHUNKS_PDF_prompt = TEMPLATE_CHUNKS_PDF
         TEMPLATE_MERGE_PDF_prompt = TEMPLATE_MERGE_PDF
@@ -105,8 +124,6 @@ def execute(self, state):
             TEMPLATE_CHUNKS_PDF_prompt = self.additional_info + TEMPLATE_CHUNKS_PDF_prompt
             TEMPLATE_MERGE_PDF_prompt = self.additional_info + TEMPLATE_MERGE_PDF_prompt
 
-        format_instructions = output_parser.get_format_instructions()
-
         if len(doc) == 1:
             prompt = PromptTemplate(
                 template=TEMPLATE_NO_CHUNKS_PDF_prompt,

scrapegraphai/nodes/merge_answers_node.py

Lines changed: 19 additions & 3 deletions
@@ -4,6 +4,9 @@
 from typing import List, Optional
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
+from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_openai import ChatOpenAI
+from langchain_mistralai import ChatMistralAI
 from ..utils.logging import get_logger
 from .base_node import BaseNode
 from ..prompts import TEMPLATE_COMBINED
@@ -68,11 +71,24 @@ def execute(self, state: dict) -> dict:
             answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n"
 
         if self.node_config.get("schema", None) is not None:
-            output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
+
+            if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
+                self.llm_model = self.llm_model.with_structured_output(
+                    schema = self.node_config["schema"],
+                    method="function_calling") # json schema works only on specific models
+
+            # default parser to empty lambda function
+            output_parser = lambda x: x
+            if is_basemodel_subclass(self.node_config["schema"]):
+                output_parser = dict
+                format_instructions = "NA"
+            else:
+                output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
+                format_instructions = output_parser.get_format_instructions()
+
         else:
             output_parser = JsonOutputParser()
-
-        format_instructions = output_parser.get_format_instructions()
+            format_instructions = output_parser.get_format_instructions()
 
         prompt_template = PromptTemplate(
             template=TEMPLATE_COMBINED,
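
The commit does not change how the nodes assemble their chains; purely as a schematic (prompt text, schema and model name are invented), both parser variants plug into the same prompt | llm | parser shape, because LCEL coerces a plain callable such as dict into a RunnableLambda:

# Schematic only; prompt text, schema and model name are invented.
from pydantic import BaseModel
from langchain.prompts import PromptTemplate
from langchain_openai import ChatOpenAI

class Answer(BaseModel):
    summary: str

llm = ChatOpenAI(model="gpt-4o-mini").with_structured_output(
    schema=Answer, method="function_calling")
output_parser = dict           # as chosen above for BaseModel schemas
format_instructions = "NA"     # structured output makes JSON hints unnecessary

prompt = PromptTemplate(
    template="Summarise this content: {content}\n{format_instructions}",
    input_variables=["content"],
    partial_variables={"format_instructions": format_instructions},
)

# A plain callable is coerced into a RunnableLambda, so the chain shape is
# identical whichever parser branch was taken.
chain = prompt | llm | output_parser
result = chain.invoke({"content": "ScrapeGraphAI is a web scraping library."})
# result is a plain dict, e.g. {"summary": "..."}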
