Skip to content

Commit 3b8e600

Browse files
committed
chore: use gated instead of private keyword
1 parent bda612e commit 3b8e600

File tree

8 files changed

+61
-61
lines changed

8 files changed

+61
-61
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ class JumpStartModelsAccessor(object):
127127
_curr_region = JUMPSTART_DEFAULT_REGION_NAME
128128

129129
_content_bucket: Optional[str] = None
130-
_private_content_bucket: Optional[str] = None
130+
_gated_content_bucket: Optional[str] = None
131131

132132
_cache_kwargs: Dict[str, Any] = {}
133133

@@ -142,14 +142,14 @@ def get_jumpstart_content_bucket() -> Optional[str]:
142142
return JumpStartModelsAccessor._content_bucket
143143

144144
@staticmethod
145-
def set_jumpstart_private_content_bucket(private_content_bucket: str) -> None:
146-
"""Sets JumpStart private content bucket."""
147-
JumpStartModelsAccessor._private_content_bucket = private_content_bucket
145+
def set_jumpstart_gated_content_bucket(gated_content_bucket: str) -> None:
146+
"""Sets JumpStart gated content bucket."""
147+
JumpStartModelsAccessor._gated_content_bucket = gated_content_bucket
148148

149149
@staticmethod
150-
def get_jumpstart_private_content_bucket() -> Optional[str]:
151-
"""Returns JumpStart private content bucket."""
152-
return JumpStartModelsAccessor._private_content_bucket
150+
def get_jumpstart_gated_content_bucket() -> Optional[str]:
151+
"""Returns JumpStart gated content bucket."""
152+
return JumpStartModelsAccessor._gated_content_bucket
153153

154154
@staticmethod
155155
def _validate_and_mutate_region_cache_kwargs(

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
from sagemaker.jumpstart.utils import (
2727
get_jumpstart_content_bucket,
28-
get_jumpstart_private_content_bucket,
28+
get_jumpstart_gated_content_bucket,
2929
verify_model_region_and_return_specs,
3030
)
3131
from sagemaker.session import Session
@@ -159,8 +159,8 @@ def _retrieve_model_uri(
159159
model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)
160160

161161
default_jumpstart_bucket: str = (
162-
get_jumpstart_private_content_bucket(region)
163-
if model_specs.private_bucket
162+
get_jumpstart_gated_content_bucket(region)
163+
if model_specs.gated_bucket
164164
else get_jumpstart_content_bucket(region)
165165
)
166166

src/sagemaker/jumpstart/constants.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -38,102 +38,102 @@
3838
JumpStartLaunchedRegionInfo(
3939
region_name="us-west-2",
4040
content_bucket="jumpstart-cache-prod-us-west-2",
41-
private_content_bucket="jumpstart-private-cache-prod-us-west-2",
41+
gated_content_bucket="jumpstart-private-cache-prod-us-west-2",
4242
),
4343
JumpStartLaunchedRegionInfo(
4444
region_name="us-east-1",
4545
content_bucket="jumpstart-cache-prod-us-east-1",
46-
private_content_bucket="jumpstart-private-cache-prod-us-east-1",
46+
gated_content_bucket="jumpstart-private-cache-prod-us-east-1",
4747
),
4848
JumpStartLaunchedRegionInfo(
4949
region_name="us-east-2",
5050
content_bucket="jumpstart-cache-prod-us-east-2",
51-
private_content_bucket="jumpstart-private-cache-prod-us-east-2",
51+
gated_content_bucket="jumpstart-private-cache-prod-us-east-2",
5252
),
5353
JumpStartLaunchedRegionInfo(
5454
region_name="eu-west-1",
5555
content_bucket="jumpstart-cache-prod-eu-west-1",
56-
private_content_bucket="jumpstart-private-cache-prod-eu-west-1",
56+
gated_content_bucket="jumpstart-private-cache-prod-eu-west-1",
5757
),
5858
JumpStartLaunchedRegionInfo(
5959
region_name="eu-central-1",
6060
content_bucket="jumpstart-cache-prod-eu-central-1",
61-
private_content_bucket="jumpstart-private-cache-prod-eu-central-1",
61+
gated_content_bucket="jumpstart-private-cache-prod-eu-central-1",
6262
),
6363
JumpStartLaunchedRegionInfo(
6464
region_name="eu-north-1",
6565
content_bucket="jumpstart-cache-prod-eu-north-1",
66-
private_content_bucket="jumpstart-private-cache-prod-eu-north-1",
66+
gated_content_bucket="jumpstart-private-cache-prod-eu-north-1",
6767
),
6868
JumpStartLaunchedRegionInfo(
6969
region_name="me-south-1",
7070
content_bucket="jumpstart-cache-prod-me-south-1",
71-
private_content_bucket="jumpstart-private-cache-prod-me-south-1",
71+
gated_content_bucket="jumpstart-private-cache-prod-me-south-1",
7272
),
7373
JumpStartLaunchedRegionInfo(
7474
region_name="ap-south-1",
7575
content_bucket="jumpstart-cache-prod-ap-south-1",
76-
private_content_bucket="jumpstart-private-cache-prod-ap-south-1",
76+
gated_content_bucket="jumpstart-private-cache-prod-ap-south-1",
7777
),
7878
JumpStartLaunchedRegionInfo(
7979
region_name="eu-west-3",
8080
content_bucket="jumpstart-cache-prod-eu-west-3",
81-
private_content_bucket="jumpstart-private-cache-prod-eu-west-3",
81+
gated_content_bucket="jumpstart-private-cache-prod-eu-west-3",
8282
),
8383
JumpStartLaunchedRegionInfo(
8484
region_name="af-south-1",
8585
content_bucket="jumpstart-cache-prod-af-south-1",
86-
private_content_bucket="jumpstart-private-cache-prod-af-south-1",
86+
gated_content_bucket="jumpstart-private-cache-prod-af-south-1",
8787
),
8888
JumpStartLaunchedRegionInfo(
8989
region_name="sa-east-1",
9090
content_bucket="jumpstart-cache-prod-sa-east-1",
91-
private_content_bucket="jumpstart-private-cache-prod-sa-east-1",
91+
gated_content_bucket="jumpstart-private-cache-prod-sa-east-1",
9292
),
9393
JumpStartLaunchedRegionInfo(
9494
region_name="ap-east-1",
9595
content_bucket="jumpstart-cache-prod-ap-east-1",
96-
private_content_bucket="jumpstart-private-cache-prod-ap-east-1",
96+
gated_content_bucket="jumpstart-private-cache-prod-ap-east-1",
9797
),
9898
JumpStartLaunchedRegionInfo(
9999
region_name="ap-northeast-2",
100100
content_bucket="jumpstart-cache-prod-ap-northeast-2",
101-
private_content_bucket="jumpstart-private-cache-prod-ap-northeast-2",
101+
gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-2",
102102
),
103103
JumpStartLaunchedRegionInfo(
104104
region_name="eu-west-2",
105105
content_bucket="jumpstart-cache-prod-eu-west-2",
106-
private_content_bucket="jumpstart-private-cache-prod-eu-west-2",
106+
gated_content_bucket="jumpstart-private-cache-prod-eu-west-2",
107107
),
108108
JumpStartLaunchedRegionInfo(
109109
region_name="eu-south-1",
110110
content_bucket="jumpstart-cache-prod-eu-south-1",
111-
private_content_bucket="jumpstart-private-cache-prod-eu-south-1",
111+
gated_content_bucket="jumpstart-private-cache-prod-eu-south-1",
112112
),
113113
JumpStartLaunchedRegionInfo(
114114
region_name="ap-northeast-1",
115115
content_bucket="jumpstart-cache-prod-ap-northeast-1",
116-
private_content_bucket="jumpstart-private-cache-prod-ap-northeast-1",
116+
gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-1",
117117
),
118118
JumpStartLaunchedRegionInfo(
119119
region_name="us-west-1",
120120
content_bucket="jumpstart-cache-prod-us-west-1",
121-
private_content_bucket="jumpstart-private-cache-prod-us-west-1",
121+
gated_content_bucket="jumpstart-private-cache-prod-us-west-1",
122122
),
123123
JumpStartLaunchedRegionInfo(
124124
region_name="ap-southeast-1",
125125
content_bucket="jumpstart-cache-prod-ap-southeast-1",
126-
private_content_bucket="jumpstart-private-cache-prod-ap-southeast-1",
126+
gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-1",
127127
),
128128
JumpStartLaunchedRegionInfo(
129129
region_name="ap-southeast-2",
130130
content_bucket="jumpstart-cache-prod-ap-southeast-2",
131-
private_content_bucket="jumpstart-private-cache-prod-ap-southeast-2",
131+
gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-2",
132132
),
133133
JumpStartLaunchedRegionInfo(
134134
region_name="ca-central-1",
135135
content_bucket="jumpstart-cache-prod-ca-central-1",
136-
private_content_bucket="jumpstart-private-cache-prod-ca-central-1",
136+
gated_content_bucket="jumpstart-private-cache-prod-ca-central-1",
137137
),
138138
JumpStartLaunchedRegionInfo(
139139
region_name="cn-north-1",
@@ -148,10 +148,10 @@
148148
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}
149149

150150
JUMPSTART_BUCKET_NAME_SET = {region.content_bucket for region in JUMPSTART_LAUNCHED_REGIONS}
151-
JUMPSTART_PRIVATE_BUCKET_NAME_SET = {
152-
region.private_content_bucket
151+
JUMPSTART_GATED_BUCKET_NAME_SET = {
152+
region.gated_content_bucket
153153
for region in JUMPSTART_LAUNCHED_REGIONS
154-
if region.private_content_bucket is not None
154+
if region.gated_content_bucket is not None
155155
}
156156

157157
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"

src/sagemaker/jumpstart/types.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,21 +107,21 @@ class JumpStartS3FileType(str, Enum):
107107
class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
108108
"""Data class for launched region info."""
109109

110-
__slots__ = ["content_bucket", "region_name", "private_content_bucket"]
110+
__slots__ = ["content_bucket", "region_name", "gated_content_bucket"]
111111

112112
def __init__(
113-
self, content_bucket: str, region_name: str, private_content_bucket: Optional[str] = None
113+
self, content_bucket: str, region_name: str, gated_content_bucket: Optional[str] = None
114114
):
115115
"""Instantiates JumpStartLaunchedRegionInfo object.
116116
117117
Args:
118118
content_bucket (str): Name of JumpStart s3 content bucket associated with region.
119119
region_name (str): Name of JumpStart launched region.
120-
private_content_bucket (Optional[str[]): Name of JumpStart private s3 content bucket
120+
gated_content_bucket (Optional[str[]): Name of JumpStart gated s3 content bucket
121121
optionally associated with region.
122122
"""
123123
self.content_bucket = content_bucket
124-
self.private_content_bucket = private_content_bucket
124+
self.gated_content_bucket = gated_content_bucket
125125
self.region_name = region_name
126126

127127

@@ -692,7 +692,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
692692
"hosting_instance_type_variants",
693693
"training_instance_type_variants",
694694
"default_payloads",
695-
"private_bucket",
695+
"gated_bucket",
696696
]
697697

698698
def __init__(self, spec: Dict[str, Any]):
@@ -769,7 +769,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
769769
if json_obj.get("default_payloads")
770770
else None
771771
)
772-
self.private_bucket = json_obj.get("private_bucket", False)
772+
self.gated_bucket = json_obj.get("gated_bucket", False)
773773
self.inference_volume_size: Optional[int] = json_obj.get("inference_volume_size")
774774
self.inference_enable_network_isolation: bool = json_obj.get(
775775
"inference_enable_network_isolation", False

src/sagemaker/jumpstart/utils.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def get_jumpstart_launched_regions_message() -> str:
6363
return f"JumpStart is available in {formatted_launched_regions_str} regions."
6464

6565

66-
def get_jumpstart_private_content_bucket(
66+
def get_jumpstart_gated_content_bucket(
6767
region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
6868
) -> str:
6969
"""Returns regionalized private content bucket name for JumpStart.
@@ -73,27 +73,27 @@ def get_jumpstart_private_content_bucket(
7373
unavailable in that region.
7474
"""
7575

76-
old_private_content_bucket: Optional[
76+
old_gated_content_bucket: Optional[
7777
str
78-
] = accessors.JumpStartModelsAccessor.get_jumpstart_private_content_bucket()
78+
] = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket()
7979

8080
info_logs: List[str] = []
8181

82-
private_bucket_to_return: Optional[str] = None
82+
gated_bucket_to_return: Optional[str] = None
8383
if (
8484
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ
8585
and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
8686
):
87-
private_bucket_to_return = os.environ[
87+
gated_bucket_to_return = os.environ[
8888
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE
8989
]
90-
info_logs.append(f"Using JumpStart private bucket override: '{private_bucket_to_return}'")
90+
info_logs.append(f"Using JumpStart private bucket override: '{gated_bucket_to_return}'")
9191
else:
9292
try:
93-
private_bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[
93+
gated_bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[
9494
region
95-
].private_content_bucket
96-
if private_bucket_to_return is None:
95+
].gated_content_bucket
96+
if gated_bucket_to_return is None:
9797
raise ValueError(
9898
f"No private content bucket for JumpStart exists in {region} region."
9999
)
@@ -104,14 +104,14 @@ def get_jumpstart_private_content_bucket(
104104
f"{formatted_launched_regions_str}"
105105
)
106106

107-
accessors.JumpStartModelsAccessor.set_jumpstart_private_content_bucket(private_bucket_to_return)
107+
accessors.JumpStartModelsAccessor.set_jumpstart_gated_content_bucket(gated_bucket_to_return)
108108

109-
if private_bucket_to_return != old_private_content_bucket:
109+
if gated_bucket_to_return != old_gated_content_bucket:
110110
accessors.JumpStartModelsAccessor.reset_cache()
111111
for info_log in info_logs:
112112
constants.JUMPSTART_LOGGER.info(info_log)
113113

114-
return private_bucket_to_return
114+
return gated_bucket_to_return
115115

116116

117117
def get_jumpstart_content_bucket(

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -809,7 +809,7 @@
809809
},
810810
"private-model": {
811811
"model_id": "pytorch-ic-mobilenet-v2",
812-
"private_bucket": True,
812+
"gated_bucket": True,
813813
"url": "https://pytorch.org/hub/pytorch_vision_mobilenet_v2/",
814814
"version": "1.0.0",
815815
"min_sdk_version": "2.49.0",
@@ -4017,7 +4017,7 @@
40174017
"min_sdk_version": "2.49.0",
40184018
"training_supported": True,
40194019
"incremental_training_supported": True,
4020-
"private_bucket": False,
4020+
"gated_bucket": False,
40214021
"default_payloads": None,
40224022
"hosting_ecr_specs": {
40234023
"framework": "pytorch",

tests/unit/sagemaker/jumpstart/test_artifacts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ class PrivateJumpStartBucketTest(unittest.TestCase):
438438
mock_session = Mock(s3_client=mock_client)
439439

440440
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
441-
def test_retrieve_uri_from_private_bucket(self, patched_get_model_specs):
441+
def test_retrieve_uri_from_gated_bucket(self, patched_get_model_specs):
442442
patched_get_model_specs.side_effect = get_special_model_spec
443443

444444
model_id = "private-model"

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,25 +63,25 @@ def test_get_jumpstart_content_bucket_override():
6363
mocked_info_log.assert_called_once_with("Using JumpStart bucket override: 'some-val'")
6464

6565

66-
def test_get_jumpstart_private_content_bucket():
66+
def test_get_jumpstart_gated_content_bucket():
6767
bad_region = "bad_region"
6868
assert bad_region not in JUMPSTART_REGION_NAME_SET
6969
with pytest.raises(ValueError):
70-
utils.get_jumpstart_private_content_bucket(bad_region)
70+
utils.get_jumpstart_gated_content_bucket(bad_region)
7171

7272

73-
def test_get_jumpstart_private_content_bucket_no_args():
73+
def test_get_jumpstart_gated_content_bucket_no_args():
7474
assert (
75-
utils.get_jumpstart_private_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
76-
== utils.get_jumpstart_private_content_bucket()
75+
utils.get_jumpstart_gated_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
76+
== utils.get_jumpstart_gated_content_bucket()
7777
)
7878

7979

80-
def test_get_jumpstart_private_content_bucket_override():
80+
def test_get_jumpstart_gated_content_bucket_override():
8181
with patch.dict(os.environ, {ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE: "some-val"}):
8282
with patch("logging.Logger.info") as mocked_info_log:
8383
random_region = "random_region"
84-
assert "some-val" == utils.get_jumpstart_private_content_bucket(random_region)
84+
assert "some-val" == utils.get_jumpstart_gated_content_bucket(random_region)
8585
mocked_info_log.assert_called_once_with(
8686
"Using JumpStart private bucket override: 'some-val'"
8787
)

0 commit comments

Comments
 (0)