 from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
 
 from ..helpers import models_tokens
-from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
+from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude
 
 
 class AbstractGraph(ABC):
@@ -22,7 +22,8 @@ class AbstractGraph(ABC):
         source (str): The source of the graph.
         config (dict): Configuration parameters for the graph.
         llm_model: An instance of a language model client, configured for generating answers.
-        embedder_model: An instance of an embedding model client, configured for generating embeddings.
+        embedder_model: An instance of an embedding model client,
+        configured for generating embeddings.
         verbose (bool): A flag indicating whether to show print statements during execution.
         headless (bool): A flag indicating whether to run the graph in headless mode.
 
@@ -47,8 +48,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.source = source
         self.config = config
         self.llm_model = self._create_llm(config["llm"], chat=True)
-        self.embedder_model = self._create_default_embedder(
-        ) if "embeddings" not in config else self._create_embedder(
+        self.embedder_model = self._create_default_embedder(
+        ) if "embeddings" not in config else self._create_embedder(
             config["embeddings"])
 
         # Set common configuration parameters
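
Note: the embedder selection above keys off whether the config dict contains an "embeddings" entry. A minimal sketch of the two config shapes; the model names and keys below are placeholder assumptions, not values from this diff:

# Hypothetical configs illustrating the branch above.
config_default = {
    "llm": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},
    # no "embeddings" key -> _create_default_embedder() derives one from the LLM
}
config_explicit = {
    "llm": {"model": "gpt-3.5-turbo", "api_key": "sk-..."},
    # "embeddings" present -> _create_embedder(config["embeddings"]) is used
    "embeddings": {"model": "openai", "api_key": "sk-..."},
}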
@@ -61,15 +62,13 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
         self.final_state = None
         self.execution_info = None
 
-
     def _set_model_token(self, llm):
 
         if 'Azure' in str(type(llm)):
             try:
                 self.model_token = models_tokens["azure"][llm.model_name]
-            except KeyError:
-                raise KeyError("Model not supported")
-
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
 
     def _create_llm(self, llm_config: dict, chat=False) -> object:
         """
@@ -96,31 +95,36 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
         if chat:
             self._set_model_token(llm_params['model_instance'])
             return llm_params['model_instance']
-
+
         # Instantiate the language model based on the model name
         if "gpt-" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["openai"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return OpenAI(llm_params)
 
         elif "azure" in llm_params["model"]:
             # take the model after the last dash
             llm_params["model"] = llm_params["model"].split("/")[-1]
             try:
                 self.model_token = models_tokens["azure"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return AzureOpenAI(llm_params)
 
         elif "gemini" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["gemini"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return Gemini(llm_params)
-
+        elif "claude" in llm_params["model"]:
+            try:
+                self.model_token = models_tokens["claude"][llm_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return Claude(llm_params)
         elif "ollama" in llm_params["model"]:
             llm_params["model"] = llm_params["model"].split("/")[-1]
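
Note: the new `claude` branch mirrors the existing provider branches: look up the context window in models_tokens["claude"], then return the Claude wrapper. A hypothetical config that would exercise it; the model name and key are assumptions, and valid names are whatever models_tokens["claude"] actually contains:

# Hypothetical usage; "claude-3-opus-20240229" is an assumed model name.
config = {
    "llm": {
        "model": "claude-3-opus-20240229",
        "api_key": "...",
    },
}
# _create_llm(config["llm"], chat=True) would set self.model_token from
# models_tokens["claude"][model] and return Claude(llm_params).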
@@ -131,8 +135,8 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
             elif llm_params["model"] in models_tokens["ollama"]:
                 try:
                     self.model_token = models_tokens["ollama"][llm_params["model"]]
-                except KeyError:
-                    raise KeyError("Model not supported")
+                except KeyError as exc:
+                    raise KeyError("Model not supported") from exc
             else:
                 self.model_token = 8192
         except AttributeError:
@@ -142,25 +146,25 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
         elif "hugging_face" in llm_params["model"]:
             try:
                 self.model_token = models_tokens["hugging_face"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return HuggingFace(llm_params)
         elif "groq" in llm_params["model"]:
             llm_params["model"] = llm_params["model"].split("/")[-1]
 
             try:
                 self.model_token = models_tokens["groq"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return Groq(llm_params)
         elif "bedrock" in llm_params["model"]:
             llm_params["model"] = llm_params["model"].split("/")[-1]
             model_id = llm_params["model"]
 
             try:
                 self.model_token = models_tokens["bedrock"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return Bedrock({
                 "model_id": model_id,
                 "model_kwargs": {
@@ -170,7 +174,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
         else:
             raise ValueError(
                 "Model provided by the configuration not supported")
-
+
     def _create_default_embedder(self) -> object:
         """
         Create an embedding model instance based on the chosen llm model.
@@ -202,7 +206,7 @@ def _create_default_embedder(self) -> object:
             return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
         else:
             raise ValueError("Embedding Model missing or not supported")
-
+
     def _create_embedder(self, embedder_config: dict) -> object:
         """
         Create an embedding model instance based on the configuration provided.
@@ -216,7 +220,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
         Raises:
             KeyError: If the model is not supported.
         """
-
+
         # Instantiate the embedding model based on the model name
         if "openai" in embedder_config["model"]:
             return OpenAIEmbeddings(api_key=embedder_config["api_key"])
@@ -228,27 +232,27 @@ def _create_embedder(self, embedder_config: dict) -> object:
             embedder_config["model"] = embedder_config["model"].split("/")[-1]
             try:
                 models_tokens["ollama"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return OllamaEmbeddings(**embedder_config)
-
+
         elif "hugging_face" in embedder_config["model"]:
             try:
                 models_tokens["hugging_face"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return HuggingFaceHubEmbeddings(model=embedder_config["model"])
-
+
         elif "bedrock" in embedder_config["model"]:
             embedder_config["model"] = embedder_config["model"].split("/")[-1]
             try:
                 models_tokens["bedrock"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
             return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
         else:
             raise ValueError(
-                "Model provided by the configuration not supported")
+                "Model provided by the configuration not supported")
 
     def get_state(self, key=None) -> dict:
         """""
@@ -272,7 +276,7 @@ def get_execution_info(self):
         Returns:
             dict: The execution information of the graph.
         """
-
+
         return self.execution_info
 
     @abstractmethod
@@ -288,4 +292,3 @@ def run(self) -> str:
         Abstract method to execute the graph and return the result.
         """
         pass
-