@@ -46,7 +46,7 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
46
46
self .source = source
47
47
self .config = config
48
48
self .llm_model = self ._create_llm (config ["llm" ], chat = True )
49
- self .embedder_model = self ._create_default_embedder (
49
+ self .embedder_model = self ._create_default_embedder (llm_config = config [ "llm" ]
50
50
) if "embeddings" not in config else self ._create_embedder (
51
51
config ["embeddings" ])
52
52
@@ -91,6 +91,13 @@ def _set_model_token(self, llm):
91
91
self .model_token = models_tokens ['mistral' ][llm .repo_id ]
92
92
except KeyError :
93
93
raise KeyError ("Model not supported" )
94
+
95
+ elif 'Google' in str (type (llm )):
96
+ try :
97
+ if 'gemini' in llm .model :
98
+ self .model_token = models_tokens ['gemini' ][llm .model ]
99
+ except KeyError :
100
+ raise KeyError ("Model not supported" )
94
101
95
102
def _create_llm (self , llm_config : dict , chat = False ) -> object :
96
103
"""
@@ -197,7 +204,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
197
204
raise ValueError (
198
205
"Model provided by the configuration not supported" )
199
206
200
- def _create_default_embedder (self ) -> object :
207
+ def _create_default_embedder (self , llm_config = None ) -> object :
201
208
"""
202
209
Create an embedding model instance based on the chosen llm model.
203
210
@@ -207,6 +214,8 @@ def _create_default_embedder(self) -> object:
207
214
Raises:
208
215
ValueError: If the model is not supported.
209
216
"""
217
+ if isinstance (self .llm_model , Gemini ):
218
+ return GoogleGenerativeAIEmbeddings (google_api_key = llm_config ['api_key' ], model = "models/embedding-001" )
210
219
if isinstance (self .llm_model , OpenAI ):
211
220
return OpenAIEmbeddings (api_key = self .llm_model .openai_api_key )
212
221
elif isinstance (self .llm_model , AzureOpenAIEmbeddings ):
@@ -241,7 +250,6 @@ def _create_embedder(self, embedder_config: dict) -> object:
241
250
Raises:
242
251
KeyError: If the model is not supported.
243
252
"""
244
-
245
253
if 'model_instance' in embedder_config :
246
254
return embedder_config ['model_instance' ]
247
255
# Instantiate the embedding model based on the model name
0 commit comments