Skip to content

Commit 63a5d18

Browse files
committed
fix(AbstractGraph): Bedrock init issues
Closes #633
1 parent 50c9c6b commit 63a5d18

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

scrapegraphai/graphs/abstract_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _create_llm(self, llm_config: dict) -> object:
128128
return llm_params["model_instance"]
129129

130130
known_providers = {"openai", "azure_openai", "google_genai", "google_vertexai",
131-
"ollama", "oneapi", "nvidia", "groq", "anthropic" "bedrock", "mistralai",
131+
"ollama", "oneapi", "nvidia", "groq", "anthropic", "bedrock", "mistralai",
132132
"hugging_face", "deepseek", "ernie", "fireworks", "togetherai"}
133133

134134
split_model_provider = llm_params["model"].split("/", 1)
@@ -146,6 +146,8 @@ def _create_llm(self, llm_config: dict) -> object:
146146

147147
try:
148148
if llm_params["model_provider"] not in {"oneapi", "nvidia", "ernie", "deepseek", "togetherai"}:
149+
if llm_params["model_provider"] == "bedrock":
150+
llm_params["model_kwargs"] = { "temperature" : llm_params.pop("temperature") }
149151
with warnings.catch_warnings():
150152
warnings.simplefilter("ignore")
151153
return init_chat_model(**llm_params)

tests/graphs/abstract_graph_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from langchain_openai import ChatOpenAI, AzureChatOpenAI
1313
from langchain_ollama import ChatOllama
1414
from langchain_google_genai import ChatGoogleGenerativeAI
15+
from langchain_aws import ChatBedrock
1516

1617

1718

@@ -71,6 +72,7 @@ class TestAbstractGraph:
7172
({"model": "ollama/llama2"}, ChatOllama),
7273
({"model": "oneapi/qwen-turbo", "api_key": "oneapi-api-key"}, OneApi),
7374
({"model": "deepseek/deepseek-coder", "api_key": "deepseek-api-key"}, DeepSeek),
75+
({"model": "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "region_name": "IDK"}, ChatBedrock),
7476
])
7577

7678
def test_create_llm(self, llm_config, expected_model):

0 commit comments

Comments
 (0)