Skip to content

Commit e6bedb6

Browse files
f-aguzzi authored and VinciGit00 committed
fix(AbstractGraph): pass kwargs to Ernie and Nvidia models
Co-Authored-By: Marco Vinciguerra <[email protected]>
1 parent c3f1520 commit e6bedb6

File tree

3 files changed

+2
-70
lines changed

3 files changed

+2
-70
lines changed

requirements-dev.lock

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
# features: []
77
# all-features: false
88
# with-sources: false
9-
# generate-hashes: false
10-
# universal: false
119

1210
-e file:.
1311
aiofiles==24.1.0
@@ -112,7 +110,6 @@ filelock==3.15.4
112110
# via huggingface-hub
113111
# via torch
114112
# via transformers
115-
# via triton
116113
fireworks-ai==0.14.0
117114
# via langchain-fireworks
118115
fonttools==4.53.1
@@ -362,34 +359,6 @@ numpy==1.26.4
362359
# via shapely
363360
# via streamlit
364361
# via transformers
365-
nvidia-cublas-cu12==12.1.3.1
366-
# via nvidia-cudnn-cu12
367-
# via nvidia-cusolver-cu12
368-
# via torch
369-
nvidia-cuda-cupti-cu12==12.1.105
370-
# via torch
371-
nvidia-cuda-nvrtc-cu12==12.1.105
372-
# via torch
373-
nvidia-cuda-runtime-cu12==12.1.105
374-
# via torch
375-
nvidia-cudnn-cu12==8.9.2.26
376-
# via torch
377-
nvidia-cufft-cu12==11.0.2.54
378-
# via torch
379-
nvidia-curand-cu12==10.3.2.106
380-
# via torch
381-
nvidia-cusolver-cu12==11.4.5.107
382-
# via torch
383-
nvidia-cusparse-cu12==12.1.0.106
384-
# via nvidia-cusolver-cu12
385-
# via torch
386-
nvidia-nccl-cu12==2.19.3
387-
# via torch
388-
nvidia-nvjitlink-cu12==12.6.20
389-
# via nvidia-cusolver-cu12
390-
# via nvidia-cusparse-cu12
391-
nvidia-nvtx-cu12==12.1.105
392-
# via torch
393362
openai==1.37.0
394363
# via burr
395364
# via langchain-fireworks
@@ -631,8 +600,6 @@ tqdm==4.66.4
631600
transformers==4.43.3
632601
# via langchain-huggingface
633602
# via sentence-transformers
634-
triton==2.2.0
635-
# via torch
636603
typer==0.12.3
637604
# via fastapi-cli
638605
typing-extensions==4.12.2
@@ -676,8 +643,6 @@ uvicorn==0.30.3
676643
# via fastapi
677644
uvloop==0.19.0
678645
# via uvicorn
679-
watchdog==4.0.1
680-
# via streamlit
681646
watchfiles==0.22.0
682647
# via uvicorn
683648
websockets==12.0

requirements.lock

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
# features: []
77
# all-features: false
88
# with-sources: false
9-
# generate-hashes: false
10-
# universal: false
119

1210
-e file:.
1311
aiohttp==3.9.5
@@ -69,7 +67,6 @@ filelock==3.15.4
6967
# via huggingface-hub
7068
# via torch
7169
# via transformers
72-
# via triton
7370
fireworks-ai==0.14.0
7471
# via langchain-fireworks
7572
free-proxy==1.1.1
@@ -267,34 +264,6 @@ numpy==1.26.4
267264
# via sentence-transformers
268265
# via shapely
269266
# via transformers
270-
nvidia-cublas-cu12==12.1.3.1
271-
# via nvidia-cudnn-cu12
272-
# via nvidia-cusolver-cu12
273-
# via torch
274-
nvidia-cuda-cupti-cu12==12.1.105
275-
# via torch
276-
nvidia-cuda-nvrtc-cu12==12.1.105
277-
# via torch
278-
nvidia-cuda-runtime-cu12==12.1.105
279-
# via torch
280-
nvidia-cudnn-cu12==8.9.2.26
281-
# via torch
282-
nvidia-cufft-cu12==11.0.2.54
283-
# via torch
284-
nvidia-curand-cu12==10.3.2.106
285-
# via torch
286-
nvidia-cusolver-cu12==11.4.5.107
287-
# via torch
288-
nvidia-cusparse-cu12==12.1.0.106
289-
# via nvidia-cusolver-cu12
290-
# via torch
291-
nvidia-nccl-cu12==2.19.3
292-
# via torch
293-
nvidia-nvjitlink-cu12==12.6.20
294-
# via nvidia-cusolver-cu12
295-
# via nvidia-cusparse-cu12
296-
nvidia-nvtx-cu12==12.1.105
297-
# via torch
298267
openai==1.37.0
299268
# via langchain-fireworks
300269
# via langchain-openai
@@ -446,8 +415,6 @@ tqdm==4.66.4
446415
transformers==4.43.3
447416
# via langchain-huggingface
448417
# via sentence-transformers
449-
triton==2.2.0
450-
# via torch
451418
typing-extensions==4.12.2
452419
# via anthropic
453420
# via anyio

scrapegraphai/graphs/abstract_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def handle_model(model_name, provider, token_key, default_token=8192):
211211
except KeyError:
212212
print("model not found, using default token size (8192)")
213213
self.model_token = 8192
214-
return ErnieBotChat(llm_params)
214+
return ErnieBotChat(**llm_params)
215215

216216
if "oneapi" in llm_params["model"]:
217217
# take the model after the last dash
@@ -228,7 +228,7 @@ def handle_model(model_name, provider, token_key, default_token=8192):
228228
llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
229229
except KeyError as exc:
230230
raise KeyError("Model not supported") from exc
231-
return ChatNVIDIA(**llm_config)
231+
return ChatNVIDIA(**llm_params)
232232

233233
# Raise an error if the model did not match any of the previous cases
234234
raise ValueError("Model provided by the configuration not supported")

0 commit comments

Comments (0)