Skip to content

Commit 02f3a82

Browse files
committed
chore: use ValueError
1 parent 8bb9cba commit 02f3a82

File tree

3 files changed

+5
-9
lines changed

3 files changed

+5
-9
lines changed

src/sagemaker/jumpstart/artifacts.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,6 @@ def _retrieve_model_uri(
208208
VulnerableJumpStartModelError: If any of the dependencies required by the script have
209209
known security vulnerabilities.
210210
DeprecatedJumpStartModelError: If the version of the model is deprecated.
211-
NotImplementedError: If the combination of arguments doesn't support combined model
212-
and script artifact.
213211
"""
214212
if region is None:
215213
region = JUMPSTART_DEFAULT_REGION_NAME
@@ -234,13 +232,13 @@ def _retrieve_model_uri(
234232
else:
235233
model_artifact_key = getattr(model_specs, "hosting_prepacked_artifact_key", None)
236234
if model_artifact_key is None:
237-
raise NotImplementedError(error_msg_no_combined_artifact)
235+
raise ValueError(error_msg_no_combined_artifact)
238236

239237
elif model_scope == JumpStartScriptScope.TRAINING:
240238
if not include_script:
241239
model_artifact_key = model_specs.training_artifact_key
242240
else:
243-
raise NotImplementedError(error_msg_no_combined_artifact)
241+
raise ValueError(error_msg_no_combined_artifact)
244242

245243
bucket = os.environ.get(
246244
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE

src/sagemaker/model_uris.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,6 @@ def retrieve(
6060
VulnerableJumpStartModelError: If any of the dependencies required by the script have
6161
known security vulnerabilities.
6262
DeprecatedJumpStartModelError: If the version of the model is deprecated.
63-
NotImplementedError: If the combination of arguments doesn't support combined model
64-
and script artifact.
6563
"""
6664
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
6765
raise ValueError("Must specify `model_id` and `model_version` when retrieving model URIs.")

tests/unit/sagemaker/model_uris/jumpstart/test_combined_artifact.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,18 @@ def test_jumpstart_combined_artifacts(patched_get_model_specs):
3939
"prepack/v1.0.0/infer-prepack-huggingface-text2text-flan-t5-xxl-fp16.tar.gz"
4040
)
4141

42-
with pytest.raises(NotImplementedError):
42+
with pytest.raises(ValueError):
4343
model_uris.retrieve(
4444
region="us-west-2",
45-
model_scope="transfer_learning",
45+
model_scope="training",
4646
model_id=model_id_combined_model_artifact,
4747
model_version="*",
4848
include_script=True,
4949
)
5050

5151
model_id_combined_model_artifact_unsupported = "xgboost-classification-model"
5252

53-
with pytest.raises(NotImplementedError):
53+
with pytest.raises(ValueError):
5454
model_uris.retrieve(
5555
region="us-west-2",
5656
model_scope="inference",

0 commit comments

Comments
 (0)