10
10
from langchain_aws import BedrockEmbeddings
11
11
from langchain_community .embeddings import HuggingFaceHubEmbeddings , OllamaEmbeddings
12
12
from langchain_google_genai import GoogleGenerativeAIEmbeddings
13
+ from langchain_google_vertexai import VertexAIEmbeddings
13
14
from langchain_google_genai .embeddings import GoogleGenerativeAIEmbeddings
14
15
from langchain_openai import AzureOpenAIEmbeddings , OpenAIEmbeddings
15
-
16
16
from ..helpers import models_tokens
17
17
from ..models import (
18
18
Anthropic ,
23
23
HuggingFace ,
24
24
Ollama ,
25
25
OpenAI ,
26
- OneApi
26
+ OneApi ,
27
+ VertexAI
27
28
)
28
29
from ..models .ernie import Ernie
29
30
from ..utils .logging import set_verbosity_debug , set_verbosity_warning , set_verbosity_info
@@ -71,7 +72,7 @@ def __init__(self, prompt: str, config: dict,
71
72
self .config = config
72
73
self .schema = schema
73
74
self .llm_model = self ._create_llm (config ["llm" ], chat = True )
74
- self .embedder_model = self ._create_default_embedder (llm_config = config ["llm" ] ) if "embeddings" not in config else self ._create_embedder (
75
+ self .embedder_model = self ._create_default_embedder (llm_config = config ["llm" ]) if "embeddings" not in config else self ._create_embedder (
75
76
config ["embeddings" ])
76
77
self .verbose = False if config is None else config .get (
77
78
"verbose" , False )
@@ -102,7 +103,7 @@ def __init__(self, prompt: str, config: dict,
102
103
"embedder_model" : self .embedder_model ,
103
104
"cache_path" : self .cache_path ,
104
105
}
105
-
106
+
106
107
self .set_common_params (common_params , overwrite = True )
107
108
108
109
# set burr config
@@ -125,7 +126,7 @@ def set_common_params(self, params: dict, overwrite=False):
125
126
126
127
for node in self .graph .nodes :
127
128
node .update_config (params , overwrite )
128
-
129
+
129
130
def _create_llm (self , llm_config : dict , chat = False ) -> object :
130
131
"""
131
132
Create a large language model instance based on the configuration provided.
@@ -170,7 +171,6 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
170
171
except KeyError as exc :
171
172
raise KeyError ("Model not supported" ) from exc
172
173
return AzureOpenAI (llm_params )
173
-
174
174
elif "gemini" in llm_params ["model" ]:
175
175
try :
176
176
self .model_token = models_tokens ["gemini" ][llm_params ["model" ]]
@@ -183,6 +183,12 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
183
183
except KeyError as exc :
184
184
raise KeyError ("Model not supported" ) from exc
185
185
return Anthropic (llm_params )
186
+ elif llm_params ["model" ].startswith ("vertexai" ):
187
+ try :
188
+ self .model_token = models_tokens ["vertexai" ][llm_params ["model" ]]
189
+ except KeyError as exc :
190
+ raise KeyError ("Model not supported" ) from exc
191
+ return VertexAI (llm_params )
186
192
elif "ollama" in llm_params ["model" ]:
187
193
llm_params ["model" ] = llm_params ["model" ].split ("ollama/" )[- 1 ]
188
194
@@ -275,10 +281,12 @@ def _create_default_embedder(self, llm_config=None) -> object:
275
281
google_api_key = llm_config ["api_key" ], model = "models/embedding-001"
276
282
)
277
283
if isinstance (self .llm_model , OpenAI ):
278
- return OpenAIEmbeddings (api_key = self .llm_model .openai_api_key , base_url = self .llm_model .openai_api_base )
284
+ return OpenAIEmbeddings (api_key = self .llm_model .openai_api_key ,
285
+ base_url = self .llm_model .openai_api_base )
279
286
elif isinstance (self .llm_model , DeepSeek ):
280
- return OpenAIEmbeddings (api_key = self .llm_model .openai_api_key )
281
-
287
+ return OpenAIEmbeddings (api_key = self .llm_model .openai_api_key )
288
+ elif isinstance (self .llm_model , VertexAI ):
289
+ return VertexAIEmbeddings ()
282
290
elif isinstance (self .llm_model , AzureOpenAIEmbeddings ):
283
291
return self .llm_model
284
292
elif isinstance (self .llm_model , AzureOpenAI ):
0 commit comments