@@ -1,16 +1,29 @@
"""
AbstractGraph Module
"""
+
from abc import ABC, abstractmethod
from typing import Optional
+
from langchain_aws import BedrockEmbeddings
- from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
- from ..helpers import models_tokens
- from ..utils.logging import set_verbosity
- from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
+ from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+
+ from ..helpers import models_tokens
+ from ..models import (
+     Anthropic,
+     AzureOpenAI,
+     Bedrock,
+     Gemini,
+     Groq,
+     HuggingFace,
+     Ollama,
+     OpenAI,
+ )
+ from ..utils.logging import set_verbosity_debug, set_verbosity_warning
+

class AbstractGraph(ABC):
    """
@@ -46,29 +59,35 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
        self.source = source
        self.config = config
        self.llm_model = self._create_llm(config["llm"], chat=True)
-         self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
-             ) if "embeddings" not in config else self._create_embedder(
-             config["embeddings"])
+         self.embedder_model = (
+             self._create_default_embedder(llm_config=config["llm"])
+             if "embeddings" not in config
+             else self._create_embedder(config["embeddings"])
+         )

        # Create the graph
        self.graph = self._create_graph()
        self.final_state = None
        self.execution_info = None

        # Set common configuration parameters
-
-         verbose = False if config is None else config.get(
-             "verbose", False)
-         set_verbosity(config.get("verbose", "info"))
-         self.headless = True if config is None else config.get(
-             "headless", True)
+
+         verbose = bool(config and config.get("verbose"))
+
+         if verbose:
+             set_verbosity_debug()
+         else:
+             set_verbosity_warning()
+
+         self.headless = True if config is None else config.get("headless", True)
        self.loader_kwargs = config.get("loader_kwargs", {})

-         common_params = {"headless": self.headless,
-
-                          "loader_kwargs": self.loader_kwargs,
-                          "llm_model": self.llm_model,
-                          "embedder_model": self.embedder_model}
+         common_params = {
+             "headless": self.headless,
+             "loader_kwargs": self.loader_kwargs,
+             "llm_model": self.llm_model,
+             "embedder_model": self.embedder_model,
+         }

        self.set_common_params(common_params, overwrite=False)

    def set_common_params(self, params: dict, overwrite=False):
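Reviewer note on this hunk: the old code computed `verbose` and then ignored it, passing `config.get("verbose", "info")` to `set_verbosity()`, which takes a level name rather than a boolean. The rewrite makes the toggle explicit. A minimal sketch of the resulting behavior, with a hypothetical config (model and key values are illustrative only):

graph_config = {
    "llm": {"model": "gpt-3.5-turbo", "api_key": "..."},  # assumed values
    "verbose": True,
}

# Inside __init__:
# verbose = bool(graph_config and graph_config.get("verbose"))  -> True
# -> set_verbosity_debug() runs; a missing or falsy "verbose" key
#    falls through to set_verbosity_warning() instead.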
@@ -81,25 +100,25 @@ def set_common_params(self, params: dict, overwrite=False):

        for node in self.graph.nodes:
            node.update_config(params, overwrite)
-
+
    def _set_model_token(self, llm):

-         if 'Azure' in str(type(llm)):
+         if "Azure" in str(type(llm)):
            try:
                self.model_token = models_tokens["azure"][llm.model_name]
            except KeyError:
                raise KeyError("Model not supported")

-         elif 'HuggingFaceEndpoint' in str(type(llm)):
-             if 'mistral' in llm.repo_id:
+         elif "HuggingFaceEndpoint" in str(type(llm)):
+             if "mistral" in llm.repo_id:
                try:
-                     self.model_token = models_tokens['mistral'][llm.repo_id]
+                     self.model_token = models_tokens["mistral"][llm.repo_id]
                except KeyError:
                    raise KeyError("Model not supported")
-         elif 'Google' in str(type(llm)):
+         elif "Google" in str(type(llm)):
            try:
-                 if 'gemini' in llm.model:
-                     self.model_token = models_tokens['gemini'][llm.model]
+                 if "gemini" in llm.model:
+                     self.model_token = models_tokens["gemini"][llm.model]
            except KeyError:
                raise KeyError("Model not supported")

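The string changes in `_set_model_token` are quote-style only; the dispatch keys on the class name embedded in `str(type(llm))`. A self-contained illustration with a stand-in class (not the real LangChain type):

class AzureChatOpenAI:  # stand-in used only to show the substring match
    model_name = "gpt-35-turbo"  # hypothetical deployment name

llm = AzureChatOpenAI()
assert "Azure" in str(type(llm))  # e.g. "<class '__main__.AzureChatOpenAI'>"
# _set_model_token would then read models_tokens["azure"][llm.model_name]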
@@ -117,17 +136,14 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
            KeyError: If the model is not supported.
        """

-         llm_defaults = {
-             "temperature": 0,
-             "streaming": False
-         }
+         llm_defaults = {"temperature": 0, "streaming": False}
        llm_params = {**llm_defaults, **llm_config}

        # If model instance is passed directly instead of the model details
-         if 'model_instance' in llm_params:
+         if "model_instance" in llm_params:
            if chat:
-                 self._set_model_token(llm_params['model_instance'])
-             return llm_params['model_instance']
+                 self._set_model_token(llm_params["model_instance"])
+             return llm_params["model_instance"]

        # Instantiate the language model based on the model name
        if "gpt-" in llm_params["model"]:
@@ -193,18 +209,20 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
        elif "bedrock" in llm_params["model"]:
            llm_params["model"] = llm_params["model"].split("/")[-1]
            model_id = llm_params["model"]
-             client = llm_params.get('client', None)
+             client = llm_params.get("client", None)
            try:
                self.model_token = models_tokens["bedrock"][llm_params["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
-             return Bedrock({
-                 "client": client,
-                 "model_id": model_id,
-                 "model_kwargs": {
-                     "temperature": llm_params["temperature"],
+             return Bedrock(
+                 {
+                     "client": client,
+                     "model_id": model_id,
+                     "model_kwargs": {
+                         "temperature": llm_params["temperature"],
+                     },
                }
-             })
+             )

        elif "claude-3-" in llm_params["model"]:
            self.model_token = models_tokens["claude"]["claude3"]
            return Anthropic(llm_params)
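The Bedrock branch is reformatting only, but the `split("/")` step deserves a note: the provider prefix is stripped and the trailing segment becomes the `model_id`. A small illustration (model string assumed):

model = "bedrock/anthropic.claude-v2"
print(model.split("/")[-1])  # anthropic.claude-v2, passed to Bedrock as model_id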
@@ -215,8 +233,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
                raise KeyError("Model not supported") from exc
            return DeepSeek(llm_params)
        else:
-             raise ValueError(
-                 "Model provided by the configuration not supported")
+             raise ValueError("Model provided by the configuration not supported")

    def _create_default_embedder(self, llm_config=None) -> object:
        """
@@ -229,8 +246,9 @@ def _create_default_embedder(self, llm_config=None) -> object:
            ValueError: If the model is not supported.
        """
        if isinstance(self.llm_model, Gemini):
-             return GoogleGenerativeAIEmbeddings(google_api_key=llm_config['api_key'],
-                                                 model="models/embedding-001")
+             return GoogleGenerativeAIEmbeddings(
+                 google_api_key=llm_config["api_key"], model="models/embedding-001"
+             )

        if isinstance(self.llm_model, OpenAI):
            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
        elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
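Context for this hunk: `_create_default_embedder` dispatches on the LLM's type, so for a Gemini model the `api_key` from the llm config doubles as the Google key for the default embedder. A hedged sketch with assumed values:

graph_config = {"llm": {"model": "gemini-pro", "api_key": "..."}}  # hypothetical
# With no "embeddings" key, __init__ calls
# _create_default_embedder(llm_config=graph_config["llm"]), which returns
# GoogleGenerativeAIEmbeddings(google_api_key="...", model="models/embedding-001").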
@@ -265,8 +283,8 @@ def _create_embedder(self, embedder_config: dict) -> object:
        Raises:
            KeyError: If the model is not supported.
        """
-         if 'model_instance' in embedder_config:
-             return embedder_config['model_instance']
+         if "model_instance" in embedder_config:
+             return embedder_config["model_instance"]
        # Instantiate the embedding model based on the model name
        if "openai" in embedder_config["model"]:
            return OpenAIEmbeddings(api_key=embedder_config["api_key"])
@@ -283,28 +301,27 @@ def _create_embedder(self, embedder_config: dict) -> object:
            try:
                models_tokens["hugging_face"][embedder_config["model"]]
            except KeyError as exc:
-                 raise KeyError("Model not supported")from exc
+                 raise KeyError("Model not supported") from exc
            return HuggingFaceHubEmbeddings(model=embedder_config["model"])
        elif "gemini" in embedder_config["model"]:
            try:
                models_tokens["gemini"][embedder_config["model"]]
            except KeyError as exc:
-                 raise KeyError("Model not supported")from exc
+                 raise KeyError("Model not supported") from exc
            return GoogleGenerativeAIEmbeddings(model=embedder_config["model"])
        elif "bedrock" in embedder_config["model"]:
            embedder_config["model"] = embedder_config["model"].split("/")[-1]
-             client = embedder_config.get('client', None)
+             client = embedder_config.get("client", None)
            try:
                models_tokens["bedrock"][embedder_config["model"]]
            except KeyError as exc:
                raise KeyError("Model not supported") from exc
            return BedrockEmbeddings(client=client, model_id=embedder_config["model"])
        else:
-             raise ValueError(
-                 "Model provided by the configuration not supported")
+             raise ValueError("Model provided by the configuration not supported")

    def get_state(self, key=None) -> dict:
-         """""
+         """ ""
        Get the final state of the graph.

        Args: