Skip to content

Support structured output schema openai #562

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions examples/anthropic/search_graph_schema_haiku.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
"""

import os
from typing import List
from dotenv import load_dotenv
load_dotenv()

from pydantic import BaseModel, Field
from scrapegraphai.graphs import SearchGraph

from pydantic import BaseModel, Field
from typing import List
load_dotenv()

# ************************************************
# Define the output schema for the graph
Expand Down
3 changes: 2 additions & 1 deletion examples/azure/smart_scraper_schema_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Basic example of scraping pipeline using SmartScraper with schema
"""

import os, json
import os
import json
from typing import List
from pydantic import BaseModel, Field
from dotenv import load_dotenv
Expand Down
2 changes: 1 addition & 1 deletion examples/local_models/smart_scraper_schema_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Projects(BaseModel):

graph_config = {
"llm": {
"model": "ollama/llama3",
"model": "ollama/llama3.1",
"temperature": 0,
"format": "json", # Ollama needs the format to be specified explicitly
# "base_url": "http://localhost:11434", # set ollama URL arbitrarily
Expand Down
3 changes: 2 additions & 1 deletion examples/openai/smart_scraper_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Basic example of scraping pipeline using SmartScraper
"""

import os, json
import os
import json
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info
from dotenv import load_dotenv
Expand Down
2 changes: 1 addition & 1 deletion examples/openai/smart_scraper_schema_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Projects(BaseModel):
graph_config = {
"llm": {
"api_key":openai_key,
"model": "gpt-4o",
"model": "gpt-4o-mini",
},
"verbose": True,
"headless": False,
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ authors = [
]

dependencies = [
"langchain>=0.2.10",
"langchain>=0.2.14",
"langchain-fireworks>=0.1.3",
"langchain_community>=0.2.9",
"langchain-google-genai>=1.0.7",
"langchain-google-vertexai>=1.0.7",
"langchain-openai>=0.1.17",
"langchain-openai>=0.1.22",
"langchain-groq>=0.1.3",
"langchain-aws>=0.1.3",
"langchain-anthropic>=0.1.11",
Expand Down
6 changes: 3 additions & 3 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ jsonschema-specifications==2023.12.1
# via jsonschema
kiwisolver==1.4.5
# via matplotlib
langchain==0.2.12
langchain==0.2.14
# via langchain-community
# via scrapegraphai
langchain-anthropic==0.1.22
Expand All @@ -264,7 +264,7 @@ langchain-aws==0.1.16
# via scrapegraphai
langchain-community==0.2.11
# via scrapegraphai
langchain-core==0.2.29
langchain-core==0.2.33
# via langchain
# via langchain-anthropic
# via langchain-aws
Expand Down Expand Up @@ -292,7 +292,7 @@ langchain-mistralai==0.1.12
# via scrapegraphai
langchain-nvidia-ai-endpoints==0.2.1
# via scrapegraphai
langchain-openai==0.1.21
langchain-openai==0.1.22
# via scrapegraphai
langchain-text-splitters==0.2.2
# via langchain
Expand Down
9 changes: 5 additions & 4 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ jinja2==3.1.4
# via torch
jiter==0.5.0
# via anthropic
# via openai
jmespath==1.0.1
# via boto3
# via botocore
Expand All @@ -187,7 +188,7 @@ jsonpatch==1.33
# via langchain-core
jsonpointer==3.0.0
# via jsonpatch
langchain==0.2.11
langchain==0.2.14
# via langchain-community
# via scrapegraphai
langchain-anthropic==0.1.20
Expand All @@ -196,7 +197,7 @@ langchain-aws==0.1.12
# via scrapegraphai
langchain-community==0.2.10
# via scrapegraphai
langchain-core==0.2.28
langchain-core==0.2.33
# via langchain
# via langchain-anthropic
# via langchain-aws
Expand Down Expand Up @@ -224,7 +225,7 @@ langchain-mistralai==0.1.12
# via scrapegraphai
langchain-nvidia-ai-endpoints==0.1.7
# via scrapegraphai
langchain-openai==0.1.17
langchain-openai==0.1.22
# via scrapegraphai
langchain-text-splitters==0.2.2
# via langchain
Expand Down Expand Up @@ -264,7 +265,7 @@ numpy==1.26.4
# via sentence-transformers
# via shapely
# via transformers
openai==1.37.0
openai==1.41.0
# via langchain-fireworks
# via langchain-openai
orjson==3.10.6
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
langchain>=0.2.10
langchain>=0.2.14
langchain-fireworks>=0.1.3
langchain_community>=0.2.9
langchain-google-genai>=1.0.7
langchain-google-vertexai>=1.0.7
langchain-openai>=0.1.17
langchain-openai>=0.1.22
langchain-groq>=0.1.3
langchain-aws>=0.1.3
langchain-anthropic>=0.1.11
Expand Down
2 changes: 1 addition & 1 deletion scrapegraphai/helpers/models_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"gpt-4-32k-0613": 32768,
"gpt-4o": 128000,
"gpt-4o-mini":128000,
"chatgpt-4o-latest":128000
"chatgpt-4o-latest": 128000
},
"google_genai": {
"gemini-pro": 128000,
Expand Down
16 changes: 14 additions & 2 deletions scrapegraphai/nodes/generate_answer_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from langchain_openai import ChatOpenAI
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_mistralai import ChatMistralAI
from langchain_anthropic import ChatAnthropic
from langchain_groq import ChatGroq
from langchain_fireworks import ChatFireworks
from langchain_google_vertexai import ChatVertexAI
from langchain_community.chat_models import ChatOllama
from tqdm import tqdm
from ..utils.logging import get_logger
Expand Down Expand Up @@ -88,12 +93,19 @@ def execute(self, state: dict) -> dict:
# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])

# Use built-in structured output for providers that allow it
if isinstance(self.llm_model, (ChatOpenAI, ChatMistralAI, ChatAnthropic, ChatFireworks, ChatGroq, ChatVertexAI)):
self.llm_model = self.llm_model.with_structured_output(
schema = self.node_config["schema"],
method="json_schema")

else:
output_parser = JsonOutputParser()

format_instructions = output_parser.get_format_instructions()

if isinstance(self.llm_model, ChatOpenAI) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
if isinstance(self.llm_model, (ChatOpenAI, AzureChatOpenAI)) and not self.script_creator or self.force and not self.script_creator or self.is_md_scraper:
template_no_chunks_prompt = TEMPLATE_NO_CHUNKS_MD
template_chunks_prompt = TEMPLATE_CHUNKS_MD
template_merge_prompt = TEMPLATE_MERGE_MD
Expand Down
11 changes: 6 additions & 5 deletions scrapegraphai/nodes/parse_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from semchunk import chunk
from langchain_community.document_transformers import Html2TextTransformer
from langchain_core.documents import Document
from ..utils.logging import get_logger
from .base_node import BaseNode

class ParseNode(BaseNode):
Expand Down Expand Up @@ -78,16 +77,18 @@ def execute(self, state: dict) -> dict:
else:
docs_transformed = docs_transformed[0]

# Adapt the chunk size, leaving room for the reply, the prompt and the schema
chunk_size = self.node_config.get("chunk_size", 4096)
chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))

if isinstance(docs_transformed, Document):

chunks = chunk(text=docs_transformed.page_content,
chunk_size=self.node_config.get("chunk_size", 4096)-250,
chunk_size=chunk_size,
token_counter=lambda text: len(text.split()),
memoize=False)
else:

chunks = chunk(text=docs_transformed,
chunk_size=self.node_config.get("chunk_size", 4096)-250,
chunk_size=chunk_size,
token_counter=lambda text: len(text.split()),
memoize=False)

Expand Down
4 changes: 2 additions & 2 deletions scrapegraphai/utils/token_calculator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Module for truncatinh in chunks the messages
Module for truncating in chunks the messages
"""
from typing import List
import tiktoken
Expand Down Expand Up @@ -27,7 +27,7 @@ def truncate_text_tokens(text: str, model: str, encoding_name: str) -> List[str]
"""

encoding = tiktoken.get_encoding(encoding_name)
max_tokens = models_tokens[model] - 500
max_tokens = min(models_tokens[model] - 500, int(models_tokens[model] * 0.9))
encoded_text = encoding.encode(text)

chunks = [encoded_text[i:i + max_tokens]
Expand Down