Skip to content

Commit f85b7f0

Browse files
committed
change: retrieve script uri argument name, comments
1 parent fea1020 commit f85b7f0

File tree

3 files changed

+16
-15
lines changed

3 files changed

+16
-15
lines changed

src/sagemaker/model_uris.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def retrieve(
3636
model_id (str): JumpStart model id for which to retrieve model S3 URI.
3737
model_version (str): JumpStart model version for which to retrieve model S3 URI.
3838
model_scope (str): The model type, i.e. what it is used for.
39-
Valid values: "training", "inference", "eia".
39+
Valid values: "training" and "inference".
4040
Returns:
4141
str: the model artifact URI for the corresponding model.
4242

src/sagemaker/script_uris.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,16 @@ def retrieve(
2626
region=jumpstart_constants.JUMPSTART_DEFAULT_REGION_NAME,
2727
model_id=None,
2828
model_version=None,
29-
model_scope=None,
29+
script_scope=None,
3030
):
3131
"""Retrieves the model script s3 URI for the model matching the given arguments.
3232
3333
Args:
3434
region (str): Region for which to retrieve model script S3 URI.
3535
model_id (str): JumpStart model id for which to retrieve model script S3 URI.
3636
model_version (str): JumpStart model version for which to retrieve model script S3 URI.
37-
model_scope (str): The model type, i.e. what it is used for.
38-
Valid values: "training", "inference", "eia".
37+
script_scope (str): The script type, i.e. what it is used for.
38+
Valid values: "training" and "inference".
3939
Returns:
4040
str: the model script URI for the corresponding model.
4141
@@ -50,13 +50,14 @@ def retrieve(
5050
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
5151
region, model_id, model_version
5252
)
53-
if model_scope is None:
53+
if script_scope is None:
5454
raise ValueError(
55-
"Must specify `model_scope` argument to retrieve model script uri for JumpStart models."
55+
"Must specify `script_scope` argument to retrieve model script uri for "
56+
"JumpStart models."
5657
)
57-
if model_scope == "inference":
58+
if script_scope == "inference":
5859
model_script_key = model_specs.hosting_script_key
59-
elif model_scope == "training":
60+
elif script_scope == "training":
6061
if not model_specs.training_supported:
6162
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
6263
model_script_key = model_specs.training_script_key

tests/unit/sagemaker/test_script_uris.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_jumpstart_script_uri(patched_get_model_specs):
2828
patched_get_model_specs.side_effect = get_spec_from_base_spec
2929
uri = script_uris.retrieve(
3030
region="us-west-2",
31-
model_scope="inference",
31+
script_scope="inference",
3232
model_id="pytorch-ic-mobilenet-v2",
3333
model_version="*",
3434
)
@@ -42,7 +42,7 @@ def test_jumpstart_script_uri(patched_get_model_specs):
4242

4343
uri = script_uris.retrieve(
4444
region="us-west-2",
45-
model_scope="training",
45+
script_scope="training",
4646
model_id="pytorch-ic-mobilenet-v2",
4747
model_version="*",
4848
)
@@ -54,7 +54,7 @@ def test_jumpstart_script_uri(patched_get_model_specs):
5454
patched_get_model_specs.reset_mock()
5555

5656
script_uris.retrieve(
57-
model_scope="training",
57+
script_scope="training",
5858
model_id="pytorch-ic-mobilenet-v2",
5959
model_version="*",
6060
)
@@ -65,15 +65,15 @@ def test_jumpstart_script_uri(patched_get_model_specs):
6565
with pytest.raises(ValueError):
6666
script_uris.retrieve(
6767
region="us-west-2",
68-
model_scope="BAD_SCOPE",
68+
script_scope="BAD_SCOPE",
6969
model_id="pytorch-ic-mobilenet-v2",
7070
model_version="*",
7171
)
7272

7373
with pytest.raises(ValueError):
7474
script_uris.retrieve(
7575
region="mars-south-1",
76-
model_scope="training",
76+
script_scope="training",
7777
model_id="pytorch-ic-mobilenet-v2",
7878
model_version="*",
7979
)
@@ -86,12 +86,12 @@ def test_jumpstart_script_uri(patched_get_model_specs):
8686

8787
with pytest.raises(ValueError):
8888
script_uris.retrieve(
89-
model_scope="training",
89+
script_scope="training",
9090
model_version="*",
9191
)
9292

9393
with pytest.raises(ValueError):
9494
script_uris.retrieve(
95-
model_scope="training",
95+
script_scope="training",
9696
model_id="pytorch-ic-mobilenet-v2",
9797
)

0 commit comments

Comments
 (0)