@@ -47,8 +47,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.source = source
         self.config = config
         self.llm_model = self._create_llm(config["llm"], chat=True)
-        self.embedder_model = self._create_default_embedder(
-            ) if "embeddings" not in config else self._create_embedder(
+        self.embedder_model = self._create_default_embedder(
+            ) if "embeddings" not in config else self._create_embedder(
             config["embeddings"])

         # Set common configuration parameters
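The functional heart of this hunk is the embedder selection (the -/+ pairs differ only in whitespace): a default embedder is derived from the LLM unless the config carries an "embeddings" key. A minimal sketch of the two config shapes that drive the branch; model names and the API key are placeholders, not values from this diff:

# No "embeddings" key: _create_default_embedder() derives one from the LLM.
config_default = {
    "llm": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},  # placeholder key
}

# Explicit "embeddings" key: the sub-dict is handed to _create_embedder().
config_explicit = {
    "llm": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},
    "embeddings": {"model": "openai", "api_key": "sk-..."},
}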
@@ -61,21 +61,23 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.final_state = None
         self.execution_info = None

+
     def _set_model_token(self, llm):

         if 'Azure' in str(type(llm)):
             try:
                 self.model_token = models_tokens["azure"][llm.model_name]
             except KeyError:
                 raise KeyError("Model not supported")
-
+
         elif 'HuggingFaceEndpoint' in str(type(llm)):
             if 'mistral' in llm.repo_id:
                 try:
                     self.model_token = models_tokens['mistral'][llm.repo_id]
                 except KeyError:
                     raise KeyError("Model not supported")

+
     def _create_llm(self, llm_config: dict, chat=False) -> object:
         """
         Create a large language model instance based on the configuration provided.
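`_set_model_token` resolves the context-window size by sniffing the wrapper's class name and indexing into the `models_tokens` table. A self-contained sketch of the same dispatch; the table entries and token counts below are stand-ins, not values from this diff:

# Stand-in for the library's models_tokens table (entries are assumptions).
models_tokens = {
    "azure": {"gpt-3.5-turbo": 4096},
    "mistral": {"mistralai/Mistral-7B-Instruct-v0.2": 8192},
}

def resolve_model_token(llm):
    """Mirror of the diff's dispatch: class-name sniffing plus dict lookups."""
    if 'Azure' in str(type(llm)):
        try:
            return models_tokens["azure"][llm.model_name]
        except KeyError:
            raise KeyError("Model not supported")
    elif 'HuggingFaceEndpoint' in str(type(llm)) and 'mistral' in llm.repo_id:
        try:
            return models_tokens['mistral'][llm.repo_id]
        except KeyError:
            raise KeyError("Model not supported")
    return None  # the original falls through silently for other wrappers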
@@ -101,7 +103,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
         if chat:
             self._set_model_token(llm_params['model_instance'])
         return llm_params['model_instance']
-
+
         # Instantiate the language model based on the model name
         if "gpt-" in llm_params["model"]:
             try:
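The `model_instance` short-circuit lets a caller hand in a pre-built LangChain model; with `chat=True` the token limit is recorded first. A hedged usage sketch (the `ChatOpenAI` import path varies across LangChain versions, and the key is a placeholder):

from langchain_openai import ChatOpenAI  # import path is version-dependent

llm_config = {
    "model_instance": ChatOpenAI(model="gpt-3.5-turbo", api_key="sk-..."),
}
# _create_llm(llm_config, chat=True) calls _set_model_token() on the instance
# and returns it unchanged, skipping the name-based dispatch below.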
@@ -178,7 +180,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
         else:
             raise ValueError(
                 "Model provided by the configuration not supported")
-
+
     def _create_default_embedder(self) -> object:
         """
         Create an embedding model instance based on the chosen llm model.
@@ -209,7 +211,7 @@ def _create_default_embedder(self) -> object:
             return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
         else:
             raise ValueError("Embedding Model missing or not supported")
-
+
     def _create_embedder(self, embedder_config: dict) -> object:
         """
         Create an embedding model instance based on the configuration provided.
@@ -226,7 +228,7 @@ def _create_embedder(self, embedder_config: dict) -> object:

         if 'model_instance' in embedder_config:
             return embedder_config['model_instance']
-
+
         # Instantiate the embedding model based on the model name
         if "openai" in embedder_config["model"]:
             return OpenAIEmbeddings(api_key=embedder_config["api_key"])
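`_create_embedder` mirrors the LLM path: an injected `model_instance` is returned untouched, otherwise a substring of the "model" field picks the provider. Two illustrative configs for the branches visible here (keys inferred from the constructor calls shown; the API key is a placeholder):

from langchain_openai import OpenAIEmbeddings  # import path is version-dependent

by_name = {"model": "openai", "api_key": "sk-..."}
by_instance = {"model_instance": OpenAIEmbeddings(api_key="sk-...")}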
@@ -241,14 +243,14 @@ def _create_embedder(self, embedder_config: dict) -> object:
             except KeyError:
                 raise KeyError("Model not supported")
             return OllamaEmbeddings(**embedder_config)
-
+
         elif "hugging_face" in embedder_config["model"]:
             try:
                 models_tokens["hugging_face"][embedder_config["model"]]
             except KeyError:
                 raise KeyError("Model not supported")
             return HuggingFaceHubEmbeddings(model=embedder_config["model"])
-
+
         elif "bedrock" in embedder_config["model"]:
             embedder_config["model"] = embedder_config["model"].split("/")[-1]
             try:
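The bedrock branch strips the provider prefix before the token-table lookup, as the `split("/")[-1]` shows. A worked example with an illustrative Bedrock model id:

cfg = {"model": "bedrock/amazon.titan-embed-text-v1"}  # id is illustrative
cfg["model"] = cfg["model"].split("/")[-1]  # -> "amazon.titan-embed-text-v1"
# The stripped id is then checked against models_tokens["bedrock"] and passed
# as model_id to BedrockEmbeddings(client=None, ...).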
@@ -258,7 +260,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
             return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
         else:
             raise ValueError(
-                "Model provided by the configuration not supported")
+                "Model provided by the configuration not supported")

     def get_state(self, key=None) -> dict:
         """
@@ -282,7 +284,7 @@ def get_execution_info(self):
         Returns:
             dict: The execution information of the graph.
         """
-
+
         return self.execution_info

     @abstractmethod
@@ -298,3 +300,4 @@ def run(self) -> str:
         Abstract method to execute the graph and return the result.
         """
         pass
+
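`run()` remains abstract, so every concrete graph subclasses `AbstractGraph` and implements it. A hypothetical subclass to show the contract; the `graph.execute()` call and the state keys are assumptions, not part of this diff:

class ExampleGraph(AbstractGraph):
    def run(self) -> str:
        # Hypothetical: execute the compiled graph and surface the answer.
        self.final_state, self.execution_info = self.graph.execute(
            {"user_prompt": self.prompt})
        return self.final_state.get("answer", "No answer found.")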