Skip to content

Commit 28b85a3

Browse files
refactor: Output parser code
1 parent 4f8b55d commit 28b85a3

File tree

7 files changed

+101
-102
lines changed

7 files changed

+101
-102
lines changed

scrapegraphai/nodes/generate_answer_csv_node.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,14 @@
33
"""
44

55
from typing import List, Optional
6-
from pydantic.v1 import BaseModel as BaseModelV1
76
from langchain.prompts import PromptTemplate
87
from langchain_core.output_parsers import JsonOutputParser
98
from langchain_core.runnables import RunnableParallel
10-
from langchain_core.utils.pydantic import is_basemodel_subclass
119
from langchain_openai import ChatOpenAI
1210
from langchain_mistralai import ChatMistralAI
1311
from tqdm import tqdm
14-
from ..utils.logging import get_logger
1512
from .base_node import BaseNode
16-
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
13+
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
1714
from ..prompts import TEMPLATE_CHUKS_CSV, TEMPLATE_NO_CHUKS_CSV, TEMPLATE_MERGE_CSV
1815

1916
class GenerateAnswerCSVNode(BaseNode):
@@ -101,14 +98,10 @@ def execute(self, state):
10198
self.llm_model = self.llm_model.with_structured_output(
10299
schema = self.node_config["schema"]) # json schema works only on specific models
103100

104-
output_parser = typed_dict_output_parser
105-
if is_basemodel_subclass(self.node_config["schema"]):
106-
output_parser = base_model_v2_output_parser
107-
if issubclass(self.node_config["schema"], BaseModelV1):
108-
output_parser = base_model_v1_output_parser
101+
output_parser = get_structured_output_parser(self.node_config["schema"])
109102
format_instructions = "NA"
110103
else:
111-
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
104+
output_parser = get_pydantic_output_parser(self.node_config["schema"])
112105
format_instructions = output_parser.get_format_instructions()
113106

114107
else:

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,15 @@
22
GenerateAnswerNode Module
33
"""
44
from typing import List, Optional
5-
from pydantic.v1 import BaseModel as BaseModelV1
65
from langchain.prompts import PromptTemplate
76
from langchain_core.output_parsers import JsonOutputParser
87
from langchain_core.runnables import RunnableParallel
9-
from langchain_core.utils.pydantic import is_basemodel_subclass
108
from langchain_openai import ChatOpenAI, AzureChatOpenAI
119
from langchain_mistralai import ChatMistralAI
1210
from langchain_community.chat_models import ChatOllama
1311
from tqdm import tqdm
1412
from .base_node import BaseNode
15-
from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser
13+
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
1614
from ..prompts import (TEMPLATE_CHUNKS,
1715
TEMPLATE_NO_CHUNKS, TEMPLATE_MERGE,
1816
TEMPLATE_CHUNKS_MD, TEMPLATE_NO_CHUNKS_MD,
@@ -95,15 +93,11 @@ def execute(self, state: dict) -> dict:
9593
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI)):
9694
self.llm_model = self.llm_model.with_structured_output(
9795
schema = self.node_config["schema"]) # json schema works only on specific models
98-
99-
output_parser = typed_dict_output_parser
100-
if is_basemodel_subclass(self.node_config["schema"]):
101-
output_parser = base_model_v2_output_parser
102-
if issubclass(self.node_config["schema"], BaseModelV1):
103-
output_parser = base_model_v1_output_parser
96+
97+
output_parser = get_structured_output_parser(self.node_config["schema"])
10498
format_instructions = "NA"
10599
else:
106-
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
100+
output_parser = get_pydantic_output_parser(self.node_config["schema"])
107101
format_instructions = output_parser.get_format_instructions()
108102

109103
else:

scrapegraphai/nodes/generate_answer_omni_node.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,15 @@
22
GenerateAnswerNode Module
33
"""
44
from typing import List, Optional
5-
from pydantic.v1 import BaseModel as BaseModelV1
65
from langchain.prompts import PromptTemplate
76
from langchain_core.output_parsers import JsonOutputParser
87
from langchain_core.runnables import RunnableParallel
9-
from langchain_core.utils.pydantic import is_basemodel_subclass
108
from langchain_openai import ChatOpenAI
119
from langchain_mistralai import ChatMistralAI
1210
from tqdm import tqdm
1311
from langchain_community.chat_models import ChatOllama
1412
from .base_node import BaseNode
15-
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
13+
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
1614
from ..prompts.generate_answer_node_omni_prompts import (TEMPLATE_NO_CHUNKS_OMNI,
1715
TEMPLATE_CHUNKS_OMNI,
1816
TEMPLATE_MERGE_OMNI)
@@ -90,14 +88,10 @@ def execute(self, state: dict) -> dict:
9088
self.llm_model = self.llm_model.with_structured_output(
9189
schema = self.node_config["schema"]) # json schema works only on specific models
9290

93-
output_parser = typed_dict_output_parser
94-
if is_basemodel_subclass(self.node_config["schema"]):
95-
output_parser = base_model_v2_output_parser
96-
if issubclass(self.node_config["schema"], BaseModelV1):
97-
output_parser = base_model_v1_output_parser
91+
output_parser = get_structured_output_parser(self.node_config["schema"])
9892
format_instructions = "NA"
9993
else:
100-
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
94+
output_parser = get_pydantic_output_parser(self.node_config["schema"])
10195
format_instructions = output_parser.get_format_instructions()
10296

10397
else:

scrapegraphai/nodes/generate_answer_pdf_node.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,15 @@
22
Module for generating the answer node
33
"""
44
from typing import List, Optional
5-
from pydantic.v1 import BaseModel as BaseModelV1
65
from langchain.prompts import PromptTemplate
76
from langchain_core.output_parsers import JsonOutputParser
87
from langchain_core.runnables import RunnableParallel
9-
from langchain_core.utils.pydantic import is_basemodel_subclass
108
from langchain_openai import ChatOpenAI
119
from langchain_mistralai import ChatMistralAI
1210
from tqdm import tqdm
1311
from langchain_community.chat_models import ChatOllama
14-
from ..utils.logging import get_logger
1512
from .base_node import BaseNode
16-
from ..utils.llm_output_parser import typed_dict_output_parser, base_model_v2_output_parser, base_model_v1_output_parser
13+
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
1714
from ..prompts.generate_answer_node_pdf_prompts import (TEMPLATE_CHUNKS_PDF,
1815
TEMPLATE_NO_CHUNKS_PDF,
1916
TEMPLATE_MERGE_PDF)
@@ -102,14 +99,10 @@ def execute(self, state):
10299
self.llm_model = self.llm_model.with_structured_output(
103100
schema = self.node_config["schema"]) # json schema works only on specific models
104101

105-
output_parser = typed_dict_output_parser
106-
if is_basemodel_subclass(self.node_config["schema"]):
107-
output_parser = base_model_v2_output_parser
108-
if issubclass(self.node_config["schema"], BaseModelV1):
109-
output_parser = base_model_v1_output_parser
102+
output_parser = get_structured_output_parser(self.node_config["schema"])
110103
format_instructions = "NA"
111104
else:
112-
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
105+
output_parser = get_pydantic_output_parser(self.node_config["schema"])
113106
format_instructions = output_parser.get_format_instructions()
114107

115108
else:

scrapegraphai/nodes/merge_answers_node.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,13 @@
22
MergeAnswersNode Module
33
"""
44
from typing import List, Optional
5-
from pydantic.v1 import BaseModel as BaseModelV1
65
from langchain.prompts import PromptTemplate
76
from langchain_core.output_parsers import JsonOutputParser
8-
from langchain_core.utils.pydantic import is_basemodel_subclass
97
from langchain_openai import ChatOpenAI
108
from langchain_mistralai import ChatMistralAI
11-
from ..utils.logging import get_logger
129
from .base_node import BaseNode
1310
from ..prompts import TEMPLATE_COMBINED
14-
from ..utils.llm_output_parser import base_model_v1_output_parser, base_model_v2_output_parser, typed_dict_output_parser
11+
from ..utils.output_parser import get_structured_output_parser, get_pydantic_output_parser
1512

1613
class MergeAnswersNode(BaseNode):
1714
"""
@@ -78,14 +75,10 @@ def execute(self, state: dict) -> dict:
7875
self.llm_model = self.llm_model.with_structured_output(
7976
schema = self.node_config["schema"]) # json schema works only on specific models
8077

81-
output_parser = typed_dict_output_parser
82-
if is_basemodel_subclass(self.node_config["schema"]):
83-
output_parser = base_model_v2_output_parser
84-
if issubclass(self.node_config["schema"], BaseModelV1):
85-
output_parser = base_model_v1_output_parser
78+
output_parser = get_structured_output_parser(self.node_config["schema"])
8679
format_instructions = "NA"
8780
else:
88-
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
81+
output_parser = get_pydantic_output_parser(self.node_config["schema"])
8982
format_instructions = output_parser.get_format_instructions()
9083

9184
else:

scrapegraphai/utils/llm_output_parser.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

scrapegraphai/utils/output_parser.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""
2+
Functions to retrieve the correct output parser and format instructions for the LLM model.
3+
"""
4+
from pydantic import BaseModel as BaseModelV2
5+
from pydantic.v1 import BaseModel as BaseModelV1
6+
from typing import Union, Dict, Any, Type, Callable
7+
from langchain_core.output_parsers import JsonOutputParser
8+
9+
def get_structured_output_parser(schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type]) -> Callable:
    """
    Get the correct output parser callable for a structured-output LLM result.

    Args:
        schema: The target schema — a pydantic v1 model class, a pydantic v2
            model class, a TypedDict, or a plain JSON-schema dict.

    Returns:
        Callable: A function that converts the raw structured output into a dict.
    """
    # issubclass() raises TypeError when its first argument is not a class
    # (e.g. a JSON-schema dict, which the signature allows), so only run the
    # pydantic checks when schema is actually a class.
    if isinstance(schema, type):
        if issubclass(schema, BaseModelV1):
            return _base_model_v1_output_parser

        if issubclass(schema, BaseModelV2):
            return _base_model_v2_output_parser

    return _dict_output_parser
23+
24+
def get_pydantic_output_parser(schema: Union[Dict[str, Any], Type[BaseModelV1 | BaseModelV2], Type]) -> JsonOutputParser:
    """
    Get a JsonOutputParser for the given schema.

    Args:
        schema: The target schema. Only pydantic v2 model classes are
            supported here; v1 models and dict/TypedDict schemas are rejected.

    Returns:
        JsonOutputParser: The output parser object.

    Raises:
        ValueError: If the schema is a pydantic v1 model, or is not a
            pydantic model class at all.
    """
    # Guard with isinstance(schema, type): issubclass() raises TypeError on
    # non-class schemas (e.g. a JSON-schema dict); those must fall through to
    # the explicit ValueError below instead.
    if isinstance(schema, type):
        if issubclass(schema, BaseModelV1):
            raise ValueError("pydantic.v1 and langchain_core.pydantic_v1 are not supported "
                             "with this LLM model. Please use pydantic v2 instead.")

        if issubclass(schema, BaseModelV2):
            return JsonOutputParser(pydantic_object=schema)

    raise ValueError("The schema is not a pydantic subclass. With this LLM model you must use a pydantic schema.")
38+
39+
def _base_model_v1_output_parser(x: BaseModelV1) -> dict:
    """
    Parse the output of an LLM when the schema is BaseModelv1.

    Args:
        x (BaseModelV1): The output from the LLM model.

    Returns:
        dict: The parsed output.
    """
    def _convert_nested(mapping: dict) -> dict:
        # Walk the mapping in place, replacing any pydantic v1 model values
        # with their dict form, descending into each converted value.
        for field_name in mapping.keys():
            value = mapping[field_name]
            if isinstance(value, BaseModelV1):
                mapping[field_name] = value.dict()
                _convert_nested(mapping[field_name])
        return mapping

    return _convert_nested(x.dict())
61+
62+
63+
def _base_model_v2_output_parser(x: BaseModelV2) -> dict:
64+
"""
65+
Parse the output of an LLM when the schema is BaseModelv2.
66+
67+
Args:
68+
x (BaseModelV2): The output from the LLM model.
69+
70+
Returns:
71+
dict: The parsed output.
72+
"""
73+
return x.model_dump()
74+
75+
def _dict_output_parser(x: dict) -> dict:
76+
"""
77+
Parse the output of an LLM when the schema is TypedDict or JsonSchema.
78+
79+
Args:
80+
x (dict): The output from the LLM model.
81+
82+
Returns:
83+
dict: The parsed output.
84+
"""
85+
return x

0 commit comments

Comments
 (0)