Skip to content

enable end users to pass model instances for llm and embeddings model #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions examples/azure/smart_scraper_azure_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
Basic example of scraping pipeline using SmartScraper using Azure OpenAI Key
"""

import os
from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI
from langchain_openai import AzureOpenAIEmbeddings
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.utils import prettify_exec_info


## required environment variable in .env
# AZURE_OPENAI_ENDPOINT
# AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
# MODEL_NAME
# AZURE_OPENAI_API_KEY
# OPENAI_API_TYPE
# AZURE_OPENAI_API_VERSION
# AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME
load_dotenv()


# ************************************************
# Initialize the model instances
# ************************************************

llm_model_instance = AzureChatOpenAI(
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"]
)

embedder_model_instance = AzureOpenAIEmbeddings(
azure_deployment=os.environ["AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT_NAME"],
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
)

# ************************************************
# Create the SmartScraperGraph instance and run it
# ************************************************

graph_config = {
"llm": {"model_instance": llm_model_instance},
"embeddings": {"model_instance": embedder_model_instance}
}

smart_scraper_graph = SmartScraperGraph(
prompt="List me all the events, with the following fields: company_name, event_name, event_start_date, event_start_time, event_end_date, event_end_time, location, event_mode, event_category, third_party_redirect, no_of_days,
time_in_hours, hosted_or_attending, refreshments_type, registration_available, registration_link",
# also accepts a string with the already downloaded HTML code
source="https://www.hmhco.com/event",
config=graph_config
)

result = smart_scraper_graph.run()
print(result)

# ************************************************
# Get graph execution info
# ************************************************

graph_exec_info = smart_scraper_graph.get_execution_info()
print(prettify_exec_info(graph_exec_info))
20 changes: 18 additions & 2 deletions scrapegraphai/graphs/abstract_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.prompt = prompt
self.source = source
self.config = config
self.llm_model = self._create_llm(config["llm"])
self.llm_model = self._create_llm(config["llm"], chat=True)
self.embedder_model = self.llm_model if "embeddings" not in config else self._create_llm(
config["embeddings"])

Expand All @@ -32,7 +32,16 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.final_state = None
self.execution_info = None

def _create_llm(self, llm_config: dict):
def _set_model_token(self, llm) -> None:
    """
    Sets ``self.model_token`` (the context-window size) for a directly
    passed model instance.

    Only Azure chat models are recognized here; the token limit is looked
    up in the ``models_tokens["azure"]`` registry by the instance's
    ``model_name``. Non-Azure instances are left untouched.

    Raises:
        KeyError: if the Azure model is not present in the registry.
    """
    # Detect Azure models by class name so we don't need to import the
    # langchain Azure classes here just for an isinstance check.
    if 'Azure' in str(type(llm)):
        try:
            self.model_token = models_tokens["azure"][llm.model_name]
        except KeyError as exc:
            # Chain the original lookup failure and name the offending
            # model so the user knows what to fix in their config.
            raise KeyError(f"Model '{llm.model_name}' not supported") from exc


def _create_llm(self, llm_config: dict, chat=False) -> object:
"""
Creates an instance of the language model (OpenAI or Gemini) based on configuration.
"""
Expand All @@ -42,6 +51,12 @@ def _create_llm(self, llm_config: dict):
}
llm_params = {**llm_defaults, **llm_config}

# If model instance is passed directly instead of the model details
if 'model_instance' in llm_params:
if chat:
self._set_model_token(llm_params['model_instance'])
return llm_params['model_instance']

# Instantiate the language model based on the model name
if "gpt-" in llm_params["model"]:
try:
Expand Down Expand Up @@ -129,3 +144,4 @@ def run(self) -> str:
Abstract method to execute the graph and return the result.
"""
pass

5 changes: 4 additions & 1 deletion scrapegraphai/helpers/models_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
"gpt-4-32k": 32768,
"gpt-4-32k-0613": 32768,
},

"azure": {
"gpt-3.5-turbo": 4096
},
Copy link
Member

@lurenss lurenss May 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be cool to have all the available models on Azure AI; here is the link: Azure AI

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure.

"gemini": {
"gemini-pro": 128000,
},
Expand All @@ -45,3 +47,4 @@
"claude3": 200000
}
}

3 changes: 3 additions & 0 deletions scrapegraphai/nodes/rag_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def execute(self, state):
if isinstance(embedding_model, OpenAI):
embeddings = OpenAIEmbeddings(
api_key=embedding_model.openai_api_key)
elif isinstance(embedding_model, AzureOpenAIEmbeddings):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you have the chance to test with the Azure embedding model? Did you use text-ada? Does it work properly? :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I used text-ada and have attached an example as well.

embeddings = embedding_model
elif isinstance(embedding_model, AzureOpenAI):
embeddings = AzureOpenAIEmbeddings()
elif isinstance(embedding_model, Ollama):
Expand Down Expand Up @@ -133,3 +135,4 @@ def execute(self, state):

state.update({self.output[0]: compressed_docs})
return state