Skip to content

Commit 66ea166

Browse files
fix: Added support for nested structure
1 parent 039ba2e commit 66ea166

File tree

6 files changed

+93
-28
lines changed

6 files changed

+93
-28
lines changed

scrapegraphai/nodes/generate_answer_csv_node.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from typing import List, Optional
6+
from pydantic.v1 import BaseModel as BaseModelV1
67
from langchain.prompts import PromptTemplate
78
from langchain_core.output_parsers import JsonOutputParser
89
from langchain_core.runnables import RunnableParallel
@@ -12,6 +13,7 @@
1213
from tqdm import tqdm
1314
from ..utils.logging import get_logger
1415
from .base_node import BaseNode
16+
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
1517
from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
1618

1719
class GenerateAnswerCSVNode(BaseNode):
@@ -97,13 +99,13 @@ def execute(self, state):
9799

98100
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
99101
self.llm_model = self.llm_model.with_structured_output(
100-
schema = self.node_config["schema"],
101-
method="function_calling") # json schema works only on specific models
102-
103-
# default parser to empty lambda function
104-
output_parser = lambda x: x
102+
schema = self.node_config["schema"]) # json schema works only on specific models
103+
104+
output_parser = typed_dict_output_parser
105105
if is_basemodel_subclass(self.node_config["schema"]):
106-
output_parser = dict
106+
output_parser = base_model_v2_output_parser
107+
if issubclass(self.node_config["schema"], BaseModelV1):
108+
output_parser = base_model_v1_output_parser
107109
format_instructions = "NA"
108110
else:
109111
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
GenerateAnswerNode Module
33
"""
44
from typing import List, Optional
5+
from pydantic.v1 import BaseModel as BaseModelV1
56
from langchain.prompts import PromptTemplate
67
from langchain_core.output_parsers import JsonOutputParser
78
from langchain_core.runnables import RunnableParallel
@@ -11,6 +12,7 @@
1112
from langchain_community.chat_models import ChatOllama
1213
from tqdm import tqdm
1314
from .base_node import BaseNode
15+
from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser
1416
from ..prompts import (TEMPLATE_CHUNKS,
1517
TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
1618
TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD,
@@ -93,12 +95,12 @@ def execute(self, state: dict) -> dict:
9395
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
9496
self.llm_model = self.llm_model.with_structured_output(
9597
schema = self.node_config["schema"]) # json schema works only on specific models
96-
97-
# default parser to empty lambda function
98-
def output_parser(x):
99-
return x
98+
99+
output_parser = typed_dict_output_parser
100100
if is_basemodel_subclass(self.node_config["schema"]):
101-
output_parser = dict
101+
output_parser = base_model_v2_output_parser
102+
if issubclass(self.node_config["schema"], BaseModelV1):
103+
output_parser = base_model_v1_output_parser
102104
format_instructions = "NA"
103105
else:
104106
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

scrapegraphai/nodes/generate_answer_omni_node.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
GenerateAnswerNode Module
33
"""
44
from typing import List, Optional
5+
from pydantic.v1 import BaseModel as BaseModelV1
56
from langchain.prompts import PromptTemplate
67
from langchain_core.output_parsers import JsonOutputParser
78
from langchain_core.runnables import RunnableParallel
@@ -11,6 +12,7 @@
1112
from tqdm import tqdm
1213
from langchain_community.chat_models import ChatOllama
1314
from .base_node import BaseNode
15+
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
1416
from ..prompts.generate_answer_node_omni_prompts import (TEMPLATE_NO_CHUNKS_OMNI,
1517
TEMPLATE_CHUNKS_OMNI,
1618
TEMPLATE_MERGE_OMNI)
@@ -86,13 +88,13 @@ def execute(self, state: dict) -> dict:
8688

8789
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
8890
self.llm_model = self.llm_model.with_structured_output(
89-
schema = self.node_config["schema"],
90-
method="function_calling") # json schema works only on specific models
91-
92-
# default parser to empty lambda function
93-
output_parser = lambda x: x
91+
schema = self.node_config["schema"]) # json schema works only on specific models
92+
93+
output_parser = typed_dict_output_parser
9494
if is_basemodel_subclass(self.node_config["schema"]):
95-
output_parser = dict
95+
output_parser = base_model_v2_output_parser
96+
if issubclass(self.node_config["schema"], BaseModelV1):
97+
output_parser = base_model_v1_output_parser
9698
format_instructions = "NA"
9799
else:
98100
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

scrapegraphai/nodes/generate_answer_pdf_node.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Module for generating the answer node
33
"""
44
from typing import List, Optional
5+
from pydantic.v1 import BaseModel as BaseModelV1
56
from langchain.prompts import PromptTemplate
67
from langchain_core.output_parsers import JsonOutputParser
78
from langchain_core.runnables import RunnableParallel
@@ -12,6 +13,7 @@
1213
from langchain_community.chat_models import ChatOllama
1314
from ..utils.logging import get_logger
1415
from .base_node import BaseNode
16+
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
1517
from ..prompts.generate_answer_node_pdf_prompts import (TEMPLATE_CHUNKS_PDF,
1618
TEMPLATE_NO_CHUNKS_PDF,
1719
TEMPLATE_MERGE_PDF)
@@ -98,12 +100,13 @@ def execute(self, state):
98100

99101
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
100102
self.llm_model = self.llm_model.with_structured_output(
101-
schema = self.node_config["schema"],
102-
method="function_calling") # json schema works only on specific models
103-
104-
output_parser = lambda x: x
103+
schema = self.node_config["schema"]) # json schema works only on specific models
104+
105+
output_parser = typed_dict_output_parser
105106
if is_basemodel_subclass(self.node_config["schema"]):
106-
output_parser = dict
107+
output_parser = base_model_v2_output_parser
108+
if issubclass(self.node_config["schema"], BaseModelV1):
109+
output_parser = base_model_v1_output_parser
107110
format_instructions = "NA"
108111
else:
109112
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

scrapegraphai/nodes/merge_answers_node.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
MergeAnswersNode Module
33
"""
44
from typing import List, Optional
5+
from pydantic.v1 import BaseModel as BaseModelV1
56
from langchain.prompts import PromptTemplate
67
from langchain_core.output_parsers import JsonOutputParser
78
from langchain_core.utils.pydantic import is_basemodel_subclass
@@ -10,6 +11,7 @@
1011
from ..utils.logging import get_logger
1112
from .base_node import BaseNode
1213
from ..prompts import TEMPLATE_COMBINED
14+
from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser
1315

1416
class MergeAnswersNode(BaseNode):
1517
"""
@@ -74,12 +76,13 @@ def execute(self, state: dict) -> dict:
7476

7577
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
7678
self.llm_model = self.llm_model.with_structured_output(
77-
schema = self.node_config["schema"],
78-
method="function_calling") # json schema works only on specific models
79-
# default parser to empty lambda function
80-
output_parser = lambda x: x
79+
schema = self.node_config["schema"]) # json schema works only on specific models
80+
81+
output_parser = typed_dict_output_parser
8182
if is_basemodel_subclass(self.node_config["schema"]):
82-
output_parser = dict
83+
output_parser = base_model_v2_output_parser
84+
if issubclass(self.node_config["schema"], BaseModelV1):
85+
output_parser = base_model_v1_output_parser
8386
format_instructions = "NA"
8487
else:
8588
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
@@ -100,7 +103,7 @@ def execute(self, state: dict) -> dict:
100103

101104
merge_chain = prompt_template | self.llm_model | output_parser
102105
answer = merge_chain.invoke({"user_prompt": user_prompt})
103-
answer["sources"] = state.get("urls")
106+
answer["sources"] = state.get("urls", [])
104107

105108
state.update({self.output[0]: answer})
106109
return state
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""
2+
Custom output parser for the LLM model.
3+
"""
4+
from pydantic import BaseModel as BaseModelV2
5+
from pydantic.v1 import BaseModel as BaseModelV1
6+
7+
def base_model_v1_output_parser(x: BaseModelV1) -> dict:
    """
    Parse the output of an LLM when the schema is a pydantic-v1 BaseModel
    and `with_structured_output` is used.

    Converts the model to a plain dict, then recursively converts any
    nested BaseModelV1 instances — including those inside nested dicts
    and lists, which the previous implementation missed.

    Args:
        x (BaseModelV1): The output from the LLM model.

    Returns:
        dict: The parsed output with every nested model flattened to a dict.
    """
    def _convert(value):
        # Flatten a model first so its fields are visible to the
        # container checks below.
        if isinstance(value, BaseModelV1):
            value = value.dict()
        if isinstance(value, dict):
            return {key: _convert(val) for key, val in value.items()}
        if isinstance(value, list):
            return [_convert(item) for item in value]
        return value

    return _convert(x.dict())
29+
30+
31+
def base_model_v2_output_parser(x: BaseModelV2) -> dict:
    """
    Parse the output of an LLM when the schema is a pydantic-v2 BaseModel
    and `with_structured_output` is used.

    Args:
        x (BaseModelV2): The output from the LLM model.

    Returns:
        dict: The model serialized to a plain dictionary.
    """
    # pydantic v2 replaces .dict() with .model_dump()
    dumped = x.model_dump()
    return dumped
42+
43+
def typed_dict_output_parser(x: dict) -> dict:
    """
    Parse the output of an LLM when the schema is a TypedDict and
    `with_structured_output` is used.

    The structured-output machinery already yields a plain dict for
    TypedDict schemas, so this is an identity pass-through kept only so
    every schema kind shares a uniform parser interface.

    Args:
        x (dict): The output from the LLM model.

    Returns:
        dict: The parsed output, unchanged.
    """
    return x

0 commit comments

Comments
 (0)