10
10
from langchain_openai import AzureOpenAIEmbeddings , OpenAIEmbeddings
11
11
12
12
from ..helpers import models_tokens
13
- from ..models import AzureOpenAI , Bedrock , Gemini , Groq , HuggingFace , Ollama , OpenAI
13
+ from ..models import AzureOpenAI , Bedrock , Gemini , Groq , HuggingFace , Ollama , OpenAI , Anthropic
14
14
15
15
16
16
class AbstractGraph (ABC ):
@@ -47,8 +47,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
47
47
self .source = source
48
48
self .config = config
49
49
self .llm_model = self ._create_llm (config ["llm" ], chat = True )
50
- self .embedder_model = self ._create_default_embedder (
51
- ) if "embeddings" not in config else self ._create_embedder (
50
+ self .embedder_model = self ._create_default_embedder (
51
+ ) if "embeddings" not in config else self ._create_embedder (
52
52
config ["embeddings" ])
53
53
54
54
# Set common configuration parameters
@@ -61,23 +61,21 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
61
61
self .final_state = None
62
62
self .execution_info = None
63
63
64
-
65
64
def _set_model_token(self, llm):
    """Set ``self.model_token`` from the known per-model token tables.

    The provider is detected by string-matching the concrete class name of
    *llm* (avoids importing every provider class just for an isinstance
    check). Models from providers not listed here are left untouched.

    Args:
        llm: The instantiated language-model object. For Azure models its
            ``model_name`` attribute is used as the lookup key; for
            HuggingFace endpoints its ``repo_id`` is used.

    Raises:
        KeyError: If the model is recognised as Azure or a Mistral
            HuggingFace endpoint but is missing from ``models_tokens``.
    """
    if 'Azure' in str(type(llm)):
        try:
            self.model_token = models_tokens["azure"][llm.model_name]
        except KeyError as exc:
            # Chain the original lookup failure so debugging shows
            # which key was actually missing.
            raise KeyError("Model not supported") from exc

    elif 'HuggingFaceEndpoint' in str(type(llm)):
        # Only Mistral repos have known context sizes in the table;
        # other HF endpoints are silently skipped (original behavior).
        if 'mistral' in llm.repo_id:
            try:
                self.model_token = models_tokens['mistral'][llm.repo_id]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
79
78
80
-
81
79
def _create_llm (self , llm_config : dict , chat = False ) -> object :
82
80
"""
83
81
Create a large language model instance based on the configuration provided.
@@ -103,7 +101,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
103
101
if chat :
104
102
self ._set_model_token (llm_params ['model_instance' ])
105
103
return llm_params ['model_instance' ]
106
-
104
+
107
105
# Instantiate the language model based on the model name
108
106
if "gpt-" in llm_params ["model" ]:
109
107
try :
@@ -174,10 +172,13 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
174
172
"temperature" : llm_params ["temperature" ],
175
173
}
176
174
})
175
+ elif "claude-3-" in llm_params ["model" ]:
176
+ self .model_token = models_tokens ["claude" ]["claude3" ]
177
+ return Anthropic (llm_params )
177
178
else :
178
179
raise ValueError (
179
180
"Model provided by the configuration not supported" )
180
-
181
+
181
182
def _create_default_embedder (self ) -> object :
182
183
"""
183
184
Create an embedding model instance based on the chosen llm model.
@@ -208,7 +209,7 @@ def _create_default_embedder(self) -> object:
208
209
return BedrockEmbeddings (client = None , model_id = self .llm_model .model_id )
209
210
else :
210
211
raise ValueError ("Embedding Model missing or not supported" )
211
-
212
+
212
213
def _create_embedder (self , embedder_config : dict ) -> object :
213
214
"""
214
215
Create an embedding model instance based on the configuration provided.
@@ -225,7 +226,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
225
226
226
227
if 'model_instance' in embedder_config :
227
228
return embedder_config ['model_instance' ]
228
-
229
+
229
230
# Instantiate the embedding model based on the model name
230
231
if "openai" in embedder_config ["model" ]:
231
232
return OpenAIEmbeddings (api_key = embedder_config ["api_key" ])
@@ -240,14 +241,14 @@ def _create_embedder(self, embedder_config: dict) -> object:
240
241
except KeyError :
241
242
raise KeyError ("Model not supported" )
242
243
return OllamaEmbeddings (** embedder_config )
243
-
244
+
244
245
elif "hugging_face" in embedder_config ["model" ]:
245
246
try :
246
247
models_tokens ["hugging_face" ][embedder_config ["model" ]]
247
248
except KeyError :
248
249
raise KeyError ("Model not supported" )
249
250
return HuggingFaceHubEmbeddings (model = embedder_config ["model" ])
250
-
251
+
251
252
elif "bedrock" in embedder_config ["model" ]:
252
253
embedder_config ["model" ] = embedder_config ["model" ].split ("/" )[- 1 ]
253
254
try :
@@ -257,7 +258,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
257
258
return BedrockEmbeddings (client = None , model_id = embedder_config ["model" ])
258
259
else :
259
260
raise ValueError (
260
- "Model provided by the configuration not supported" )
261
+ "Model provided by the configuration not supported" )
261
262
262
263
def get_state (self , key = None ) -> dict :
263
264
"""""
@@ -281,7 +282,7 @@ def get_execution_info(self):
281
282
Returns:
282
283
dict: The execution information of the graph.
283
284
"""
284
-
285
+
285
286
return self .execution_info
286
287
287
288
@abstractmethod
@@ -297,4 +298,3 @@ def run(self) -> str:
297
298
Abstract method to execute the graph and return the result.
298
299
"""
299
300
pass
300
-
0 commit comments