from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings

from ..helpers import models_tokens
-from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI
+from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Claude


class AbstractGraph(ABC):
@@ -22,7 +22,8 @@ class AbstractGraph(ABC):
        source (str): The source of the graph.
        config (dict): Configuration parameters for the graph.
        llm_model: An instance of a language model client, configured for generating answers.
-        embedder_model: An instance of an embedding model client, configured for generating embeddings.
+        embedder_model: An instance of an embedding model client,
+        configured for generating embeddings.
        verbose (bool): A flag indicating whether to show print statements during execution.
        headless (bool): A flag indicating whether to run the graph in headless mode.
@@ -47,8 +48,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
        self.source = source
        self.config = config
        self.llm_model = self._create_llm(config["llm"], chat=True)
-        self.embedder_model = self._create_default_embedder(
-        ) if "embeddings" not in config else self._create_embedder(
+        self.embedder_model = self._create_default_embedder(
+        ) if "embeddings" not in config else self._create_embedder(
            config["embeddings"])

        # Set common configuration parameters
@@ -61,23 +62,21 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
        self.final_state = None
        self.execution_info = None

-
    def _set_model_token(self, llm):

        if 'Azure' in str(type(llm)):
            try:
                self.model_token = models_tokens["azure"][llm.model_name]
            except KeyError:
                raise KeyError("Model not supported")
-
+
        elif 'HuggingFaceEndpoint' in str(type(llm)):
            if 'mistral' in llm.repo_id:
                try:
                    self.model_token = models_tokens['mistral'][llm.repo_id]
                except KeyError:
                    raise KeyError("Model not supported")

-
    def _create_llm(self, llm_config: dict, chat=False) -> object:
        """
        Create a large language model instance based on the configuration provided.
@@ -103,31 +102,36 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
        if chat:
            self._set_model_token(llm_params['model_instance'])
            return llm_params['model_instance']
-
+
        # Instantiate the language model based on the model name
        if "gpt-" in llm_params["model"]:
            try:
                self.model_token = models_tokens["openai"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return OpenAI(llm_params)

        elif "azure" in llm_params["model"]:
            # take the model after the last dash
            llm_params["model"] = llm_params["model"].split("/")[-1]
            try:
                self.model_token = models_tokens["azure"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return AzureOpenAI(llm_params)

        elif "gemini" in llm_params["model"]:
            try:
                self.model_token = models_tokens["gemini"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return Gemini(llm_params)
-
+        elif "claude" in llm_params["model"]:
+            try:
+                self.model_token = models_tokens["claude"][llm_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return Claude(llm_params)
        elif "ollama" in llm_params["model"]:
            llm_params["model"] = llm_params["model"].split("/")[-1]

@@ -138,8 +142,8 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                elif llm_params["model"] in models_tokens["ollama"]:
                    try:
                        self.model_token = models_tokens["ollama"][llm_params["model"]]
-                    except KeyError:
-                        raise KeyError("Model not supported")
+                    except KeyError as exc:
+                        raise KeyError("Model not supported") from exc
                else:
                    self.model_token = 8192
            except AttributeError:
@@ -149,25 +153,25 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
        elif "hugging_face" in llm_params["model"]:
            try:
                self.model_token = models_tokens["hugging_face"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return HuggingFace(llm_params)
        elif "groq" in llm_params["model"]:
            llm_params["model"] = llm_params["model"].split("/")[-1]

            try:
                self.model_token = models_tokens["groq"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return Groq(llm_params)
        elif "bedrock" in llm_params["model"]:
            llm_params["model"] = llm_params["model"].split("/")[-1]
            model_id = llm_params["model"]

            try:
                self.model_token = models_tokens["bedrock"][llm_params["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return Bedrock({
                "model_id": model_id,
                "model_kwargs": {
@@ -177,7 +181,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
        else:
            raise ValueError(
                "Model provided by the configuration not supported")
-
+
    def _create_default_embedder(self) -> object:
        """
        Create an embedding model instance based on the chosen llm model.
@@ -208,7 +212,7 @@ def _create_default_embedder(self) -> object:
            return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
        else:
            raise ValueError("Embedding Model missing or not supported")
-
+
    def _create_embedder(self, embedder_config: dict) -> object:
        """
        Create an embedding model instance based on the configuration provided.
@@ -237,27 +241,27 @@ def _create_embedder(self, embedder_config: dict) -> object:
            embedder_config["model"] = embedder_config["model"].split("/")[-1]
            try:
                models_tokens["ollama"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return OllamaEmbeddings(**embedder_config)
-
+
        elif "hugging_face" in embedder_config["model"]:
            try:
                models_tokens["hugging_face"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return HuggingFaceHubEmbeddings(model=embedder_config["model"])
-
+
        elif "bedrock" in embedder_config["model"]:
            embedder_config["model"] = embedder_config["model"].split("/")[-1]
            try:
                models_tokens["bedrock"][embedder_config["model"]]
-            except KeyError:
-                raise KeyError("Model not supported")
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
            return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
        else:
            raise ValueError(
-                "Model provided by the configuration not supported")
+                "Model provided by the configuration not supported")

    def get_state(self, key=None) -> dict:
        """
@@ -281,7 +285,7 @@ def get_execution_info(self):
        Returns:
            dict: The execution information of the graph.
        """
-
+
        return self.execution_info

    @abstractmethod
@@ -297,4 +301,3 @@ def run(self) -> str:
        Abstract method to execute the graph and return the result.
        """
        pass
-
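
For context, here is a minimal usage sketch of the Claude support this commit adds. The SmartScraperGraph entry point, the "claude-2" model key, and the exact config fields are assumptions drawn from the surrounding repository, not guarantees from this diff:

# Hypothetical usage sketch; SmartScraperGraph and the config keys are assumed.
from scrapegraphai.graphs import SmartScraperGraph

graph_config = {
    "llm": {
        "model": "claude-2",  # matched by the new 'elif "claude" in llm_params["model"]' branch
        "api_key": "YOUR_ANTHROPIC_API_KEY",
    },
}

smart_scraper = SmartScraperGraph(
    prompt="List all the article titles on the page",
    source="https://example.com",
    config=graph_config,
)
print(smart_scraper.run())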
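
The recurring change from a bare raise to "raise KeyError(...) from exc" applies Python's explicit exception chaining (PEP 3134): the original lookup failure is recorded as the __cause__ of the re-raised error, and the change also silences linters' raise-missing-from warnings. A standalone sketch of the pattern, with illustrative token values:

# Illustrative values only; the real tables live in ..helpers.models_tokens.
models_tokens = {"claude": {"claude-2": 100000}}

try:
    models_tokens["claude"]["claude-9000"]  # unknown model triggers KeyError
except KeyError as exc:
    # "from exc" keeps the failed lookup in the traceback as the direct
    # cause of the friendlier error raised here.
    raise KeyError("Model not supported") from exc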