
Commit d1f6b9f

Merge pull request #562 from ScrapeGraphAI/support_structured_output_shema_openai
Support structured output shema openai
2 parents 6a08cc8 + 7d2fc67 commit d1f6b9f

13 files changed: +44 -29 lines changed

examples/anthropic/search_graph_schema_haiku.py

Lines changed: 3 additions & 4 deletions
@@ -3,13 +3,12 @@
 """
 
 import os
+from typing import List
 from dotenv import load_dotenv
-load_dotenv()
-
+from pydantic import BaseModel, Field
 from scrapegraphai.graphs import SearchGraph
 
-from pydantic import BaseModel, Field
-from typing import List
+load_dotenv()
 
 # ************************************************
 # Define the output schema for the graph

examples/azure/smart_scraper_schema_azure.py

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,8 @@
 Basic example of scraping pipeline using SmartScraper with schema
 """
 
-import os, json
+import os
+import json
 from typing import List
 from pydantic import BaseModel, Field
 from dotenv import load_dotenv

examples/local_models/smart_scraper_schema_ollama.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ class Projects(BaseModel):
 
 graph_config = {
     "llm": {
-        "model": "ollama/llama3",
+        "model": "ollama/llama3.1",
         "temperature": 0,
         "format": "json", # Ollama needs the format to be specified explicitly
         # "base_url": "http://localhost:11434", # set ollama URL arbitrarily

examples/openai/smart_scraper_openai.py

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,8 @@
 Basic example of scraping pipeline using SmartScraper
 """
 
-import os, json
+import os
+import json
 from scrapegraphai.graphs import SmartScraperGraph
 from scrapegraphai.utils import prettify_exec_info
 from dotenv import load_dotenv

examples/openai/smart_scraper_schema_openai.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ class Projects(BaseModel):
 graph_config = {
     "llm": {
         "api_key":openai_key,
-        "model": "gpt-4o",
+        "model": "gpt-4o-mini",
     },
     "verbose": True,
     "headless": False,

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -14,12 +14,12 @@ authors = [
 ]
 
 dependencies = [
-    "langchain>=0.2.10",
+    "langchain>=0.2.14",
     "langchain-fireworks>=0.1.3",
     "langchain_community>=0.2.9",
     "langchain-google-genai>=1.0.7",
     "langchain-google-vertexai>=1.0.7",
-    "langchain-openai>=0.1.17",
+    "langchain-openai>=0.1.22",
     "langchain-groq>=0.1.3",
     "langchain-aws>=0.1.3",
     "langchain-anthropic>=0.1.11",

requirements-dev.lock

Lines changed: 3 additions & 3 deletions
@@ -255,7 +255,7 @@ jsonschema-specifications==2023.12.1
     # via jsonschema
 kiwisolver==1.4.5
     # via matplotlib
-langchain==0.2.12
+langchain==0.2.14
     # via langchain-community
     # via scrapegraphai
 langchain-anthropic==0.1.22
@@ -264,7 +264,7 @@ langchain-aws==0.1.16
     # via scrapegraphai
 langchain-community==0.2.11
     # via scrapegraphai
-langchain-core==0.2.29
+langchain-core==0.2.33
     # via langchain
     # via langchain-anthropic
     # via langchain-aws
@@ -292,7 +292,7 @@ langchain-mistralai==0.1.12
     # via scrapegraphai
 langchain-nvidia-ai-endpoints==0.2.1
     # via scrapegraphai
-langchain-openai==0.1.21
+langchain-openai==0.1.22
     # via scrapegraphai
 langchain-text-splitters==0.2.2
     # via langchain

requirements.lock

Lines changed: 5 additions & 4 deletions
@@ -178,6 +178,7 @@ jinja2==3.1.4
     # via torch
 jiter==0.5.0
     # via anthropic
+    # via openai
 jmespath==1.0.1
     # via boto3
     # via botocore
@@ -187,7 +188,7 @@ jsonpatch==1.33
     # via langchain-core
 jsonpointer==3.0.0
     # via jsonpatch
-langchain==0.2.11
+langchain==0.2.14
     # via langchain-community
     # via scrapegraphai
 langchain-anthropic==0.1.20
@@ -196,7 +197,7 @@ langchain-aws==0.1.12
     # via scrapegraphai
 langchain-community==0.2.10
     # via scrapegraphai
-langchain-core==0.2.28
+langchain-core==0.2.33
     # via langchain
     # via langchain-anthropic
     # via langchain-aws
@@ -224,7 +225,7 @@ langchain-mistralai==0.1.12
     # via scrapegraphai
 langchain-nvidia-ai-endpoints==0.1.7
     # via scrapegraphai
-langchain-openai==0.1.17
+langchain-openai==0.1.22
     # via scrapegraphai
 langchain-text-splitters==0.2.2
     # via langchain
@@ -264,7 +265,7 @@ numpy==1.26.4
     # via sentence-transformers
     # via shapely
     # via transformers
-openai==1.37.0
+openai==1.41.0
     # via langchain-fireworks
     # via langchain-openai
 orjson==3.10.6

requirements.txt

Lines changed: 2 additions & 2 deletions
@@ -1,9 +1,9 @@
-langchain>=0.2.10
+langchain>=0.2.14
 langchain-fireworks>=0.1.3
 langchain_community>=0.2.9
 langchain-google-genai>=1.0.7
 langchain-google-vertexai>=1.0.7
-langchain-openai>=0.1.17
+langchain-openai>=0.1.22
 langchain-groq>=0.1.3
 langchain-aws>=0.1.3
 langchain-anthropic>=0.1.11

scrapegraphai/helpers/models_tokens.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@
         "gpt-4-32k-0613": 32768,
         "gpt-4o": 128000,
         "gpt-4o-mini":128000,
-        "chatgpt-4o-latest":128000
+        "chatgpt-4o-latest": 128000
     },
     "google_genai": {
         "gemini-pro": 128000,

scrapegraphai/nodes/generate_answer_node.py

Lines changed: 14 additions & 2 deletions
@@ -5,7 +5,12 @@
 from langchain.prompts import PromptTemplate
 from langchain_core.output_parsers import JsonOutputParser
 from langchain_core.runnables import RunnableParallel
-from langchain_openai import ChatOpenAI
+from langchain_openai import ChatOpenAI, AzureChatOpenAI
+from langchain_mistralai import ChatMistralAI
+from langchain_anthropic import ChatAnthropic
+from langchain_groq import ChatGroq
+from langchain_fireworks import ChatFireworks
+from langchain_google_vertexai import ChatVertexAI
 from langchain_community.chat_models import ChatOllama
 from tqdm import tqdm
 from ..utils.logging import get_logger
@@ -88,12 +93,19 @@ def execute(self, state: dict) -> dict:
         # Initialize the output parser
         if self.node_config.get("schema", None) is not None:
             output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
+
+            # Use built-in structured output for providers that allow it
+            if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI, ChatAnthropic, ChatFireworks, ChatGroq, ChatVertexAI)):
+                self.llm_model = self.llm_model.with_structured_output(
+                    schema = self.node_config["schema"],
+                    method="json_schema")
+
         else:
             output_parser = JsonOutputParser()
 
         format_instructions = output_parser.get_format_instructions()
 
-        if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
+        if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
             template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
             template_chunks_prompt = TEMPLATE_CHUNKS_MD
             template_merge_prompt = TEMPLATE_MERGE_MD
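The new branch rebinds the model with LangChain's with_structured_output when a schema is configured, so supported providers return parsed objects directly instead of free text that JsonOutputParser has to clean up. A standalone sketch of that call, assuming an OPENAI_API_KEY in the environment and an illustrative Projects schema (not the node's real state):

from typing import List

from pydantic import BaseModel, Field
from langchain_openai import ChatOpenAI

# Illustrative schema; in GenerateAnswerNode it comes from node_config["schema"]
class Project(BaseModel):
    title: str = Field(description="Title of the project")
    description: str = Field(description="Short description of the project")

class Projects(BaseModel):
    projects: List[Project]

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Same call the node now makes for ChatOpenAI, ChatMistralAI, ChatAnthropic,
# ChatFireworks, ChatGroq and ChatVertexAI instances
structured_llm = llm.with_structured_output(schema=Projects, method="json_schema")

result = structured_llm.invoke(
    "List two open-source scraping projects with a one-line description each."
)
print(result)  # a Projects instance, e.g. Projects(projects=[Project(...), ...])

Passing method="json_schema" relies on OpenAI's structured-output API being exposed by langchain-openai, which lines up with the dependency bumps above to langchain-openai 0.1.22 and openai 1.41.0.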

scrapegraphai/nodes/parse_node.py

Lines changed: 6 additions & 5 deletions
@@ -5,7 +5,6 @@
 from semchunk import chunk
 from langchain_community.document_transformers import Html2TextTransformer
 from langchain_core.documents import Document
-from ..utils.logging import get_logger
 from .base_node import BaseNode
 
 class ParseNode(BaseNode):
@@ -78,16 +77,18 @@ def execute(self, state: dict) -> dict:
         else:
             docs_transformed = docs_transformed[0]
 
+        # Adapt the chunk size, leaving room for the reply, the prompt and the schema
+        chunk_size = self.node_config.get("chunk_size", 4096)
+        chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))
+
         if isinstance(docs_transformed, Document):
-
             chunks = chunk(text=docs_transformed.page_content,
-                           chunk_size=self.node_config.get("chunk_size", 4096)-250,
+                           chunk_size=chunk_size,
                            token_counter=lambda text: len(text.split()),
                            memoize=False)
         else:
-
             chunks = chunk(text=docs_transformed,
-                           chunk_size=self.node_config.get("chunk_size", 4096)-250,
+                           chunk_size=chunk_size,
                            token_counter=lambda text: len(text.split()),
                            memoize=False)
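The replacement computes the usable chunk size once, reserving room both absolutely and proportionally. Worked through with the node's default of 4096 (plain arithmetic, not code from the diff):

# token_counter here just splits on whitespace, so "tokens" means words
chunk_size = 4096
chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))  # min(3596, 3686) -> 3596
# The previous code reserved only 250 tokens (4096 - 250 = 3846); the new rule keeps
# at least 500 tokens free for the prompt, the schema and the model's reply.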

scrapegraphai/utils/token_calculator.py

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 """
-Module for truncatinh in chunks the messages
+Module for truncating in chunks the messages
 """
 from typing import List
 import tiktoken
@@ -27,7 +27,7 @@ def truncate_text_tokens(text: str, model: str, encoding_name: str) -> List[str]
     """
 
     encoding = tiktoken.get_encoding(encoding_name)
-    max_tokens = models_tokens[model] - 500
+    max_tokens = min(models_tokens[model] - 500, int(models_tokens[model] * 0.9))
     encoded_text = encoding.encode(text)
 
     chunks = [encoded_text[i:i + max_tokens]
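The same reserve rule, applied to two of the context windows listed in models_tokens above (plain arithmetic, not taken from the diff); unlike ParseNode, this function counts real tokens via tiktoken rather than whitespace-split words:

# 128K-context models such as gpt-4o or gpt-4o-mini: the 10% reserve dominates
print(min(128000 - 500, int(128000 * 0.9)))  # min(127500, 115200) -> 115200

# gpt-4-32k-0613 (32768 tokens): also capped by the proportional reserve
print(min(32768 - 500, int(32768 * 0.9)))    # min(32268, 29491) -> 29491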
