Skip to content

Commit 2ac9e16

Browse files
committed
Fixed accidental reformatting.
1 parent e264e92 commit 2ac9e16

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
4747
self.source = source
4848
self.config = config
4949
self.llm_model = self._create_llm(config["llm"], chat=True)
50-
self.embedder_model = self._create_default_embedder(
51-
) if "embeddings" not in config else self._create_embedder(
50+
self.embedder_model = self._create_default_embedder(
51+
) if "embeddings" not in config else self._create_embedder(
5252
config["embeddings"])
5353

5454
# Set common configuration parameters
@@ -61,21 +61,23 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
6161
self.final_state = None
6262
self.execution_info = None
6363

64+
6465
def _set_model_token(self, llm):
6566

6667
if 'Azure' in str(type(llm)):
6768
try:
6869
self.model_token = models_tokens["azure"][llm.model_name]
6970
except KeyError:
7071
raise KeyError("Model not supported")
71-
72+
7273
elif 'HuggingFaceEndpoint' in str(type(llm)):
7374
if 'mistral' in llm.repo_id:
7475
try:
7576
self.model_token = models_tokens['mistral'][llm.repo_id]
7677
except KeyError:
7778
raise KeyError("Model not supported")
7879

80+
7981
def _create_llm(self, llm_config: dict, chat=False) -> object:
8082
"""
8183
Create a large language model instance based on the configuration provided.
@@ -101,7 +103,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
101103
if chat:
102104
self._set_model_token(llm_params['model_instance'])
103105
return llm_params['model_instance']
104-
106+
105107
# Instantiate the language model based on the model name
106108
if "gpt-" in llm_params["model"]:
107109
try:
@@ -178,7 +180,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
178180
else:
179181
raise ValueError(
180182
"Model provided by the configuration not supported")
181-
183+
182184
def _create_default_embedder(self) -> object:
183185
"""
184186
Create an embedding model instance based on the chosen llm model.
@@ -209,7 +211,7 @@ def _create_default_embedder(self) -> object:
209211
return BedrockEmbeddings(client=None, model_id=self.llm_model.model_id)
210212
else:
211213
raise ValueError("Embedding Model missing or not supported")
212-
214+
213215
def _create_embedder(self, embedder_config: dict) -> object:
214216
"""
215217
Create an embedding model instance based on the configuration provided.
@@ -226,7 +228,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
226228

227229
if 'model_instance' in embedder_config:
228230
return embedder_config['model_instance']
229-
231+
230232
# Instantiate the embedding model based on the model name
231233
if "openai" in embedder_config["model"]:
232234
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
@@ -241,14 +243,14 @@ def _create_embedder(self, embedder_config: dict) -> object:
241243
except KeyError:
242244
raise KeyError("Model not supported")
243245
return OllamaEmbeddings(**embedder_config)
244-
246+
245247
elif "hugging_face" in embedder_config["model"]:
246248
try:
247249
models_tokens["hugging_face"][embedder_config["model"]]
248250
except KeyError:
249251
raise KeyError("Model not supported")
250252
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
251-
253+
252254
elif "bedrock" in embedder_config["model"]:
253255
embedder_config["model"] = embedder_config["model"].split("/")[-1]
254256
try:
@@ -258,7 +260,7 @@ def _create_embedder(self, embedder_config: dict) -> object:
258260
return BedrockEmbeddings(client=None, model_id=embedder_config["model"])
259261
else:
260262
raise ValueError(
261-
"Model provided by the configuration not supported")
263+
"Model provided by the configuration not supported")
262264

263265
def get_state(self, key=None) -> dict:
264266
"""""
@@ -282,7 +284,7 @@ def get_execution_info(self):
282284
Returns:
283285
dict: The execution information of the graph.
284286
"""
285-
287+
286288
return self.execution_info
287289

288290
@abstractmethod
@@ -298,3 +300,4 @@ def run(self) -> str:
298300
Abstract method to execute the graph and return the result.
299301
"""
300302
pass
303+

0 commit comments

Comments
 (0)