Skip to content

Commit 119514b

Browse files
committed
feat: add vertexai integration
1 parent 79a2f51 commit 119514b

File tree

6 files changed

+41
-9
lines changed

6 files changed

+41
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies = [
1616
"langchain==0.1.15",
1717
"langchain-openai==0.1.6",
1818
"langchain-google-genai==1.0.3",
19+
"langchain-google-vertexai==1.0.6",
1920
"langchain-groq==0.1.3",
2021
"langchain-aws==0.1.3",
2122
"langchain-anthropic==0.1.11",

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
langchain==0.1.14
22
langchain-openai==0.1.1
33
langchain-google-genai==1.0.1
4+
langchain-google-vertexai==1.0.6
45
langchain-anthropic==0.1.11
56
html2text==2020.1.16
67
faiss-cpu==1.8.0

scrapegraphai/graphs/abstract_graph.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from langchain_aws import BedrockEmbeddings
1111
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
1212
from langchain_google_genai import GoogleGenerativeAIEmbeddings
13+
from langchain_google_vertexai import VertexAIEmbeddings
1314
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
1415
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
15-
1616
from ..helpers import models_tokens
1717
from ..models import (
1818
Anthropic,
@@ -23,7 +23,8 @@
2323
HuggingFace,
2424
Ollama,
2525
OpenAI,
26-
OneApi
26+
OneApi,
27+
VertexAI
2728
)
2829
from ..models.ernie import Ernie
2930
from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
@@ -71,7 +72,7 @@ def __init__(self, prompt: str, config: dict,
7172
self.config = config
7273
self.schema = schema
7374
self.llm_model = self._create_llm(config["llm"], chat=True)
74-
self.embedder_model = self._create_default_embedder(llm_config=config["llm"] ) if "embeddings" not in config else self._create_embedder(
75+
self.embedder_model = self._create_default_embedder(llm_config=config["llm"]) if "embeddings" not in config else self._create_embedder(
7576
config["embeddings"])
7677
self.verbose = False if config is None else config.get(
7778
"verbose", False)
@@ -102,7 +103,7 @@ def __init__(self, prompt: str, config: dict,
102103
"embedder_model": self.embedder_model,
103104
"cache_path": self.cache_path,
104105
}
105-
106+
106107
self.set_common_params(common_params, overwrite=True)
107108

108109
# set burr config
@@ -125,7 +126,7 @@ def set_common_params(self, params: dict, overwrite=False):
125126

126127
for node in self.graph.nodes:
127128
node.update_config(params, overwrite)
128-
129+
129130
def _create_llm(self, llm_config: dict, chat=False) -> object:
130131
"""
131132
Create a large language model instance based on the configuration provided.
@@ -170,7 +171,6 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
170171
except KeyError as exc:
171172
raise KeyError("Model not supported") from exc
172173
return AzureOpenAI(llm_params)
173-
174174
elif "gemini" in llm_params["model"]:
175175
try:
176176
self.model_token = models_tokens["gemini"][llm_params["model"]]
@@ -183,6 +183,12 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
183183
except KeyError as exc:
184184
raise KeyError("Model not supported") from exc
185185
return Anthropic(llm_params)
186+
elif llm_params["model"].startswith("vertexai"):
187+
try:
188+
self.model_token = models_tokens["vertexai"][llm_params["model"]]
189+
except KeyError as exc:
190+
raise KeyError("Model not supported") from exc
191+
return VertexAI(llm_params)
186192
elif "ollama" in llm_params["model"]:
187193
llm_params["model"] = llm_params["model"].split("ollama/")[-1]
188194

@@ -275,10 +281,12 @@ def _create_default_embedder(self, llm_config=None) -> object:
275281
google_api_key=llm_config["api_key"], model="models/embedding-001"
276282
)
277283
if isinstance(self.llm_model, OpenAI):
278-
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, base_url=self.llm_model.openai_api_base)
284+
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key,
285+
base_url=self.llm_model.openai_api_base)
279286
elif isinstance(self.llm_model, DeepSeek):
280-
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
281-
287+
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
288+
elif isinstance(self.llm_model, VertexAI):
289+
return VertexAIEmbeddings()
282290
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
283291
return self.llm_model
284292
elif isinstance(self.llm_model, AzureOpenAI):

scrapegraphai/helpers/models_tokens.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,11 @@
7575
"claude2.1": 200000,
7676
"claude3": 200000
7777
},
78+
"vertexai": {
79+
"gemini-1.5-flash": 128000,
80+
"gemini-1.5-pro": 128000,
81+
"gemini-1.0-pro": 128000
82+
},
7883
"bedrock": {
7984
"anthropic.claude-3-haiku-20240307-v1:0": 200000,
8085
"anthropic.claude-3-sonnet-20240229-v1:0": 200000,

scrapegraphai/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from .anthropic import Anthropic
1515
from .deepseek import DeepSeek
1616
from .oneapi import OneApi
17+
from .vertex import VertexAI

scrapegraphai/models/vertex.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
VertexAI Module
3+
"""
4+
from langchain_google_vertexai import ChatVertexAI
5+
6+
class VertexAI(ChatVertexAI):
7+
"""
8+
A wrapper for the ChatVertexAI class that provides default configuration
9+
and could be extended with additional methods if needed.
10+
11+
Args:
12+
llm_config (dict): Configuration parameters for the language model.
13+
"""
14+
15+
def __init__(self, llm_config: dict):
16+
super().__init__(**llm_config)

0 commit comments

Comments
 (0)