
Commit df0e310

feat: add fireworks integration
1 parent 79a2f51 commit df0e310

10 files changed: 149 additions, 8 deletions


examples/fireworks/.env.example

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+FIREWORKS_APIKEY="your fireworks api key"
Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+"""
+Basic example of scraping pipeline using SmartScraper
+"""
+
+import os, json
+from dotenv import load_dotenv
+from scrapegraphai.graphs import SmartScraperGraph
+from scrapegraphai.utils import prettify_exec_info
+
+load_dotenv()
+
+
+# ************************************************
+# Define the configuration for the graph
+# ************************************************
+
+fireworks_api_key = os.getenv("FIREWORKS_APIKEY")
+
+graph_config = {
+    "llm": {
+        "api_key": fireworks_api_key,
+        "model": "fireworks/accounts/fireworks/models/mixtral-8x7b-instruct"
+    },
+    "embeddings": {
+        "model": "ollama/nomic-embed-text",
+        "temperature": 0,
+        # "base_url": "http://localhost:11434",  # set the Ollama URL explicitly if needed
+    },
+    "verbose": True,
+    "headless": False,
+}
+
+# ************************************************
+# Create the SmartScraperGraph instance and run it
+# ************************************************
+
+smart_scraper_graph = SmartScraperGraph(
+    prompt="List me all the projects with their description",
+    # also accepts a string with the already downloaded HTML code
+    source="https://perinim.github.io/projects/",
+    config=graph_config,
+)
+
+result = smart_scraper_graph.run()
+print(json.dumps(result, indent=4))
+
+# ************************************************
+# Get graph execution info
+# ************************************************
+
+graph_exec_info = smart_scraper_graph.get_execution_info()
+print(prettify_exec_info(graph_exec_info))
pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -33,6 +33,7 @@ dependencies = [
     "google==3.0.0",
     "undetected-playwright==0.3.0",
     "semchunk==1.0.1",
+    "langchain-fireworks==0.1.3"
 ]

 license = "MIT"

requirements-dev.lock

Lines changed: 13 additions & 0 deletions

@@ -13,6 +13,7 @@ aiofiles==23.2.1
 aiohttp==3.9.5
     # via langchain
     # via langchain-community
+    # via langchain-fireworks
 aiosignal==1.3.1
     # via aiohttp
 alabaster==0.7.16
@@ -93,6 +94,8 @@ fastapi-pagination==0.12.24
     # via burr
 filelock==3.14.0
     # via huggingface-hub
+fireworks-ai==0.14.0
+    # via langchain-fireworks
 fonttools==4.52.1
     # via matplotlib
 free-proxy==1.1.1
@@ -158,8 +161,11 @@ httptools==0.6.1
 httpx==0.27.0
     # via anthropic
     # via fastapi
+    # via fireworks-ai
     # via groq
     # via openai
+httpx-sse==0.4.0
+    # via fireworks-ai
 huggingface-hub==0.23.1
     # via tokenizers
 idna==3.7
@@ -207,10 +213,13 @@ langchain-core==0.1.52
     # via langchain-anthropic
     # via langchain-aws
     # via langchain-community
+    # via langchain-fireworks
     # via langchain-google-genai
     # via langchain-groq
     # via langchain-openai
     # via langchain-text-splitters
+langchain-fireworks==0.1.3
+    # via scrapegraphai
 langchain-google-genai==1.0.3
     # via scrapegraphai
 langchain-groq==0.1.3
@@ -259,6 +268,7 @@ numpy==1.26.4
     # via streamlit
 openai==1.30.3
     # via burr
+    # via langchain-fireworks
     # via langchain-openai
 orjson==3.10.3
     # via fastapi
@@ -278,6 +288,7 @@ pandas==2.2.2
     # via sf-hamilton
     # via streamlit
 pillow==10.3.0
+    # via fireworks-ai
     # via matplotlib
     # via streamlit
 playwright==1.43.0
@@ -308,6 +319,7 @@ pydantic==2.7.1
     # via burr
     # via fastapi
     # via fastapi-pagination
+    # via fireworks-ai
     # via google-generativeai
     # via groq
     # via langchain
@@ -359,6 +371,7 @@ requests==2.32.2
     # via huggingface-hub
     # via langchain
     # via langchain-community
+    # via langchain-fireworks
     # via langsmith
     # via sphinx
     # via streamlit

requirements.lock

Lines changed: 14 additions & 0 deletions

@@ -11,6 +11,7 @@
 aiohttp==3.9.5
     # via langchain
     # via langchain-community
+    # via langchain-fireworks
 aiosignal==1.3.1
     # via aiohttp
 annotated-types==0.7.0
@@ -53,6 +54,8 @@ faiss-cpu==1.8.0
     # via scrapegraphai
 filelock==3.14.0
     # via huggingface-hub
+fireworks-ai==0.14.0
+    # via langchain-fireworks
 free-proxy==1.1.1
     # via scrapegraphai
 frozenlist==1.4.1
@@ -105,8 +108,11 @@ httplib2==0.22.0
     # via google-auth-httplib2
 httpx==0.27.0
     # via anthropic
+    # via fireworks-ai
     # via groq
     # via openai
+httpx-sse==0.4.0
+    # via fireworks-ai
 huggingface-hub==0.23.1
     # via tokenizers
 idna==3.7
@@ -137,10 +143,13 @@ langchain-core==0.1.52
     # via langchain-anthropic
     # via langchain-aws
     # via langchain-community
+    # via langchain-fireworks
     # via langchain-google-genai
     # via langchain-groq
     # via langchain-openai
     # via langchain-text-splitters
+langchain-fireworks==0.1.3
+    # via scrapegraphai
 langchain-google-genai==1.0.3
     # via scrapegraphai
 langchain-groq==0.1.3
@@ -171,6 +180,7 @@ numpy==1.26.4
     # via langchain-community
     # via pandas
 openai==1.30.3
+    # via langchain-fireworks
     # via langchain-openai
 orjson==3.10.3
     # via langsmith
@@ -180,6 +190,8 @@ packaging==23.2
     # via marshmallow
 pandas==2.2.2
     # via scrapegraphai
+pillow==10.3.0
+    # via fireworks-ai
 playwright==1.43.0
     # via scrapegraphai
     # via undetected-playwright
@@ -200,6 +212,7 @@ pyasn1-modules==0.4.0
     # via google-auth
 pydantic==2.7.1
     # via anthropic
+    # via fireworks-ai
     # via google-generativeai
     # via groq
     # via langchain
@@ -232,6 +245,7 @@ requests==2.32.2
     # via huggingface-hub
     # via langchain
     # via langchain-community
+    # via langchain-fireworks
     # via langsmith
     # via tiktoken
 rsa==4.9

requirements.txt

Lines changed: 2 additions & 1 deletion

@@ -17,4 +17,5 @@ langchain-groq==0.1.3
 playwright==1.43.0
 langchain-aws==0.1.2
 undetected-playwright==0.3.0
-semchunk==1.0.1
+semchunk==1.0.1
+langchain-fireworks==0.1.3

scrapegraphai/graphs/abstract_graph.py

Lines changed: 26 additions & 6 deletions

@@ -11,6 +11,7 @@
 from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
 from langchain_google_genai import GoogleGenerativeAIEmbeddings
 from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
+from langchain_fireworks import FireworksEmbeddings
 from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings

 from ..helpers import models_tokens
@@ -23,7 +24,8 @@
     HuggingFace,
     Ollama,
     OpenAI,
-    OneApi
+    OneApi,
+    Fireworks
 )
 from ..models.ernie import Ernie
 from ..utils.logging import set_verbosity_debug, set_verbosity_warning, set_verbosity_info
@@ -102,7 +104,7 @@ def __init__(self, prompt: str, config: dict,
             "embedder_model": self.embedder_model,
             "cache_path": self.cache_path,
         }
-
+
         self.set_common_params(common_params, overwrite=True)

         # set burr config
@@ -125,7 +127,7 @@ def set_common_params(self, params: dict, overwrite=False):

         for node in self.graph.nodes:
             node.update_config(params, overwrite)
-
+
     def _create_llm(self, llm_config: dict, chat=False) -> object:
         """
         Create a large language model instance based on the configuration provided.
@@ -160,8 +162,15 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
             try:
                 self.model_token = models_tokens["oneapi"][llm_params["model"]]
             except KeyError as exc:
-                raise KeyError("Model Model not supported") from exc
+                raise KeyError("Model not supported") from exc
             return OneApi(llm_params)
+        elif "fireworks" in llm_params["model"]:
+            try:
+                self.model_token = models_tokens["fireworks"][llm_params["model"].split("/")[-1]]
+                llm_params["model"] = "/".join(llm_params["model"].split("/")[1:])
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return Fireworks(llm_params)
         elif "azure" in llm_params["model"]:
             # take the model after the last dash
             llm_params["model"] = llm_params["model"].split("/")[-1]
@@ -172,12 +181,14 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
             return AzureOpenAI(llm_params)

         elif "gemini" in llm_params["model"]:
+            llm_params["model"] = llm_params["model"].split("/")[-1]
             try:
                 self.model_token = models_tokens["gemini"][llm_params["model"]]
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
             return Gemini(llm_params)
         elif llm_params["model"].startswith("claude"):
+            llm_params["model"] = llm_params["model"].split("/")[-1]
             try:
                 self.model_token = models_tokens["claude"][llm_params["model"]]
             except KeyError as exc:
@@ -203,6 +214,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:

             return Ollama(llm_params)
         elif "hugging_face" in llm_params["model"]:
+            llm_params["model"] = llm_params["model"].split("/")[-1]
             try:
                 self.model_token = models_tokens["hugging_face"][llm_params["model"]]
             except KeyError:
@@ -277,12 +289,13 @@ def _create_default_embedder(self, llm_config=None) -> object:
         if isinstance(self.llm_model, OpenAI):
             return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key, base_url=self.llm_model.openai_api_base)
         elif isinstance(self.llm_model, DeepSeek):
-            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
-
+            return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
         elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
             return self.llm_model
         elif isinstance(self.llm_model, AzureOpenAI):
             return AzureOpenAIEmbeddings()
+        elif isinstance(self.llm_model, Fireworks):
+            return FireworksEmbeddings(model=self.llm_model.model_name)
         elif isinstance(self.llm_model, Ollama):
             # unwrap the kwargs from the model which is a dict
             params = self.llm_model._lc_kwargs
@@ -333,6 +346,13 @@ def _create_embedder(self, embedder_config: dict) -> object:
             except KeyError as exc:
                 raise KeyError("Model not supported") from exc
             return HuggingFaceHubEmbeddings(model=embedder_params["model"])
+        elif "fireworks" in embedder_params["model"]:
+            embedder_params["model"] = "/".join(embedder_params["model"].split("/")[1:])
+            try:
+                models_tokens["fireworks"][embedder_params["model"]]
+            except KeyError as exc:
+                raise KeyError("Model not supported") from exc
+            return FireworksEmbeddings(model=embedder_params["model"])
         elif "gemini" in embedder_params["model"]:
             try:
                 models_tokens["gemini"][embedder_params["model"]]

scrapegraphai/helpers/models_tokens.py

Lines changed: 6 additions & 1 deletion

@@ -143,5 +143,10 @@
         "ernie-bot-2-base-en": 4096,
         "ernie-bot-2-base-en-zh": 4096,
         "ernie-bot-2-base-zh-en": 4096
-    }
+    },
+    "fireworks": {
+        "llama-v2-7b": 4096,
+        "mixtral-8x7b-instruct": 4096,
+        "nomic-ai/nomic-embed-text-v1.5": 8192
+    },
 }
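
Note the asymmetry in how the code above consults this table: the LLM branch of `_create_llm` keys on the last path segment only, while `_create_embedder` keeps everything after the `fireworks/` prefix (including `nomic-ai/`). A minimal standalone sketch of the two lookups, with the dictionary copied from the diff above:

# Minimal sketch: the two lookups performed against the new "fireworks" entry.
models_tokens = {
    "fireworks": {
        "llama-v2-7b": 4096,
        "mixtral-8x7b-instruct": 4096,
        "nomic-ai/nomic-embed-text-v1.5": 8192,
    },
}

# LLM path: only the last segment of the model string is used as the key.
print(models_tokens["fireworks"]["mixtral-8x7b-instruct"])           # 4096

# Embedder path: everything after "fireworks/" is kept, including "nomic-ai/".
print(models_tokens["fireworks"]["nomic-ai/nomic-embed-text-v1.5"])  # 8192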

scrapegraphai/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -14,3 +14,4 @@
 from .anthropic import Anthropic
 from .deepseek import DeepSeek
 from .oneapi import OneApi
+from .fireworks import Fireworks

scrapegraphai/models/fireworks.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+"""
+Fireworks Module
+"""
+from langchain_fireworks import ChatFireworks
+
+
+class Fireworks(ChatFireworks):
+    """
+    Initializes the Fireworks class.
+
+    Args:
+        llm_config (dict): A dictionary containing configuration parameters for the LLM (required).
+            The specific keys and values will depend on the LLM implementation
+            used by the underlying `ChatFireworks` class. Consult its documentation
+            for details.
+
+    Raises:
+        ValueError: If required keys are missing from the llm_config dictionary.
+    """
+
+    def __init__(self, llm_config: dict):
+        """
+        Initializes the Fireworks class.
+
+        Args:
+            llm_config (dict): A dictionary containing configuration parameters for the LLM.
+                The specific keys and values will depend on the LLM implementation.
+
+        Raises:
+            ValueError: If required keys are missing from the llm_config dictionary.
+        """
+
+        super().__init__(**llm_config)
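
A hypothetical, minimal usage sketch of the new wrapper on its own (not part of the commit). It mirrors the keys used in the example config and assumes that `ChatFireworks` in langchain-fireworks 0.1.3 accepts `model` and `api_key` keyword arguments, as `_create_llm` implies by forwarding the config dict unchanged:

import os
from scrapegraphai.models import Fireworks

# Assumption: ChatFireworks accepts these keyword arguments; the "fireworks/" provider
# prefix has already been stripped, as _create_llm does before constructing the wrapper.
llm = Fireworks({
    "model": "accounts/fireworks/models/mixtral-8x7b-instruct",
    "api_key": os.getenv("FIREWORKS_APIKEY"),
})
# llm.invoke("Hello!")  # standard LangChain chat-model interface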
