Skip to content

Commit 488adba

Browse files
authored
feat: jumpstart gated model artifacts (#4215)
* feat: jumpstart private model artifacts * chore: use gated instead of private keyword
1 parent 4befd93 commit 488adba

File tree

8 files changed

+386
-6
lines changed

8 files changed

+386
-6
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ class JumpStartModelsAccessor(object):
127127
_curr_region = JUMPSTART_DEFAULT_REGION_NAME
128128

129129
_content_bucket: Optional[str] = None
130+
_gated_content_bucket: Optional[str] = None
130131

131132
_cache_kwargs: Dict[str, Any] = {}
132133

@@ -140,6 +141,16 @@ def get_jumpstart_content_bucket() -> Optional[str]:
140141
"""Returns JumpStart content bucket."""
141142
return JumpStartModelsAccessor._content_bucket
142143

144+
@staticmethod
def set_jumpstart_gated_content_bucket(gated_content_bucket: str) -> None:
    """Record the JumpStart gated content bucket on the accessor class.

    Args:
        gated_content_bucket (str): Bucket name stored in the class-level
            ``_gated_content_bucket`` attribute.
    """
    accessor_cls = JumpStartModelsAccessor
    accessor_cls._gated_content_bucket = gated_content_bucket
148+
149+
@staticmethod
def get_jumpstart_gated_content_bucket() -> Optional[str]:
    """Return the JumpStart gated content bucket held on the accessor class.

    Returns:
        Optional[str]: Current value of the class-level
            ``_gated_content_bucket`` attribute (its declared default is ``None``).
    """
    current_bucket: Optional[str] = JumpStartModelsAccessor._gated_content_bucket
    return current_bucket
153+
143154
@staticmethod
144155
def _validate_and_mutate_region_cache_kwargs(
145156
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None

src/sagemaker/jumpstart/artifacts/model_uris.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from sagemaker.jumpstart.utils import (
2727
get_jumpstart_content_bucket,
28+
get_jumpstart_gated_content_bucket,
2829
verify_model_region_and_return_specs,
2930
)
3031
from sagemaker.session import Session
@@ -157,9 +158,16 @@ def _retrieve_model_uri(
157158

158159
model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)
159160

160-
bucket = os.environ.get(
161-
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE
162-
) or get_jumpstart_content_bucket(region)
161+
default_jumpstart_bucket: str = (
162+
get_jumpstart_gated_content_bucket(region)
163+
if model_specs.gated_bucket
164+
else get_jumpstart_content_bucket(region)
165+
)
166+
167+
bucket = (
168+
os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE)
169+
or default_jumpstart_bucket
170+
)
163171

164172
model_s3_uri = f"s3://{bucket}/{model_artifact_key}"
165173

src/sagemaker/jumpstart/constants.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,82 +38,102 @@
3838
JumpStartLaunchedRegionInfo(
3939
region_name="us-west-2",
4040
content_bucket="jumpstart-cache-prod-us-west-2",
41+
gated_content_bucket="jumpstart-private-cache-prod-us-west-2",
4142
),
4243
JumpStartLaunchedRegionInfo(
4344
region_name="us-east-1",
4445
content_bucket="jumpstart-cache-prod-us-east-1",
46+
gated_content_bucket="jumpstart-private-cache-prod-us-east-1",
4547
),
4648
JumpStartLaunchedRegionInfo(
4749
region_name="us-east-2",
4850
content_bucket="jumpstart-cache-prod-us-east-2",
51+
gated_content_bucket="jumpstart-private-cache-prod-us-east-2",
4952
),
5053
JumpStartLaunchedRegionInfo(
5154
region_name="eu-west-1",
5255
content_bucket="jumpstart-cache-prod-eu-west-1",
56+
gated_content_bucket="jumpstart-private-cache-prod-eu-west-1",
5357
),
5458
JumpStartLaunchedRegionInfo(
5559
region_name="eu-central-1",
5660
content_bucket="jumpstart-cache-prod-eu-central-1",
61+
gated_content_bucket="jumpstart-private-cache-prod-eu-central-1",
5762
),
5863
JumpStartLaunchedRegionInfo(
5964
region_name="eu-north-1",
6065
content_bucket="jumpstart-cache-prod-eu-north-1",
66+
gated_content_bucket="jumpstart-private-cache-prod-eu-north-1",
6167
),
6268
JumpStartLaunchedRegionInfo(
6369
region_name="me-south-1",
6470
content_bucket="jumpstart-cache-prod-me-south-1",
71+
gated_content_bucket="jumpstart-private-cache-prod-me-south-1",
6572
),
6673
JumpStartLaunchedRegionInfo(
6774
region_name="ap-south-1",
6875
content_bucket="jumpstart-cache-prod-ap-south-1",
76+
gated_content_bucket="jumpstart-private-cache-prod-ap-south-1",
6977
),
7078
JumpStartLaunchedRegionInfo(
7179
region_name="eu-west-3",
7280
content_bucket="jumpstart-cache-prod-eu-west-3",
81+
gated_content_bucket="jumpstart-private-cache-prod-eu-west-3",
7382
),
7483
JumpStartLaunchedRegionInfo(
7584
region_name="af-south-1",
7685
content_bucket="jumpstart-cache-prod-af-south-1",
86+
gated_content_bucket="jumpstart-private-cache-prod-af-south-1",
7787
),
7888
JumpStartLaunchedRegionInfo(
7989
region_name="sa-east-1",
8090
content_bucket="jumpstart-cache-prod-sa-east-1",
91+
gated_content_bucket="jumpstart-private-cache-prod-sa-east-1",
8192
),
8293
JumpStartLaunchedRegionInfo(
8394
region_name="ap-east-1",
8495
content_bucket="jumpstart-cache-prod-ap-east-1",
96+
gated_content_bucket="jumpstart-private-cache-prod-ap-east-1",
8597
),
8698
JumpStartLaunchedRegionInfo(
8799
region_name="ap-northeast-2",
88100
content_bucket="jumpstart-cache-prod-ap-northeast-2",
101+
gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-2",
89102
),
90103
JumpStartLaunchedRegionInfo(
91104
region_name="eu-west-2",
92105
content_bucket="jumpstart-cache-prod-eu-west-2",
106+
gated_content_bucket="jumpstart-private-cache-prod-eu-west-2",
93107
),
94108
JumpStartLaunchedRegionInfo(
95109
region_name="eu-south-1",
96110
content_bucket="jumpstart-cache-prod-eu-south-1",
111+
gated_content_bucket="jumpstart-private-cache-prod-eu-south-1",
97112
),
98113
JumpStartLaunchedRegionInfo(
99114
region_name="ap-northeast-1",
100115
content_bucket="jumpstart-cache-prod-ap-northeast-1",
116+
gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-1",
101117
),
102118
JumpStartLaunchedRegionInfo(
103119
region_name="us-west-1",
104120
content_bucket="jumpstart-cache-prod-us-west-1",
121+
gated_content_bucket="jumpstart-private-cache-prod-us-west-1",
105122
),
106123
JumpStartLaunchedRegionInfo(
107124
region_name="ap-southeast-1",
108125
content_bucket="jumpstart-cache-prod-ap-southeast-1",
126+
gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-1",
109127
),
110128
JumpStartLaunchedRegionInfo(
111129
region_name="ap-southeast-2",
112130
content_bucket="jumpstart-cache-prod-ap-southeast-2",
131+
gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-2",
113132
),
114133
JumpStartLaunchedRegionInfo(
115134
region_name="ca-central-1",
116135
content_bucket="jumpstart-cache-prod-ca-central-1",
136+
gated_content_bucket="jumpstart-private-cache-prod-ca-central-1",
117137
),
118138
JumpStartLaunchedRegionInfo(
119139
region_name="cn-north-1",
@@ -128,6 +148,11 @@
128148
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}
129149

130150
JUMPSTART_BUCKET_NAME_SET = {region.content_bucket for region in JUMPSTART_LAUNCHED_REGIONS}
151+
JUMPSTART_GATED_BUCKET_NAME_SET = {
152+
region.gated_content_bucket
153+
for region in JUMPSTART_LAUNCHED_REGIONS
154+
if region.gated_content_bucket is not None
155+
}
131156

132157
JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"
133158

src/sagemaker/jumpstart/types.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,21 @@ class JumpStartS3FileType(str, Enum):
107107
class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
108108
"""Data class for launched region info."""
109109

110-
__slots__ = ["content_bucket", "region_name"]
110+
__slots__ = ["content_bucket", "region_name", "gated_content_bucket"]
111111

112-
def __init__(self, content_bucket: str, region_name: str):
112+
def __init__(
113+
self, content_bucket: str, region_name: str, gated_content_bucket: Optional[str] = None
114+
):
113115
"""Instantiates JumpStartLaunchedRegionInfo object.
114116
115117
Args:
116118
content_bucket (str): Name of JumpStart s3 content bucket associated with region.
117119
region_name (str): Name of JumpStart launched region.
120+
gated_content_bucket (Optional[str]): Name of JumpStart gated s3 content bucket
121+
optionally associated with region.
118122
"""
119123
self.content_bucket = content_bucket
124+
self.gated_content_bucket = gated_content_bucket
120125
self.region_name = region_name
121126

122127

@@ -691,6 +696,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
691696
"hosting_instance_type_variants",
692697
"training_instance_type_variants",
693698
"default_payloads",
699+
"gated_bucket",
694700
]
695701

696702
def __init__(self, spec: Dict[str, Any]):
@@ -767,6 +773,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
767773
if json_obj.get("default_payloads")
768774
else None
769775
)
776+
self.gated_bucket = json_obj.get("gated_bucket", False)
770777
self.inference_volume_size: Optional[int] = json_obj.get("inference_volume_size")
771778
self.inference_enable_network_isolation: bool = json_obj.get(
772779
"inference_enable_network_isolation", False

src/sagemaker/jumpstart/utils.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,64 @@ def get_jumpstart_launched_regions_message() -> str:
6363
return f"JumpStart is available in {formatted_launched_regions_str} regions."
6464

6565

66+
def get_jumpstart_gated_content_bucket(
    region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
) -> str:
    """Returns regionalized gated content bucket name for JumpStart.

    A non-empty value of the environment variable named by
    ``constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE`` takes
    precedence. Otherwise the bucket is read from the ``gated_content_bucket``
    attribute of the ``constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT``
    entry for ``region``. The resolved name is stored on
    ``JumpStartModelsAccessor``; when it differs from the previously stored
    name, the accessor cache is reset.

    Args:
        region (str): Region whose gated content bucket is requested.
            Defaults to ``constants.JUMPSTART_DEFAULT_REGION_NAME``.

    Returns:
        str: Name of the gated content bucket to use.

    Raises:
        ValueError: If JumpStart is not launched in ``region`` or gated content
            is unavailable in that region.
    """
    previous_bucket: Optional[str] = (
        accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket()
    )

    info_logs: List[str] = []

    # NOTE(review): the override constant is not gated-specific by name —
    # confirm that reusing it for the gated bucket is intended.
    override_bucket = os.environ.get(constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE)
    if override_bucket:
        gated_bucket_to_return = override_bucket
        info_logs.append(f"Using JumpStart private bucket override: '{gated_bucket_to_return}'")
    else:
        # Keep the try body to the single lookup that can raise KeyError.
        try:
            region_info = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region]
        except KeyError as err:
            formatted_launched_regions_str = get_jumpstart_launched_regions_message()
            raise ValueError(
                f"Unable to get private content bucket for JumpStart in {region} region. "
                f"{formatted_launched_regions_str}"
            ) from err

        gated_bucket_to_return = region_info.gated_content_bucket
        if gated_bucket_to_return is None:
            # Region is launched but has no gated bucket configured.
            raise ValueError(f"No private content bucket for JumpStart exists in {region} region.")

    accessors.JumpStartModelsAccessor.set_jumpstart_gated_content_bucket(gated_bucket_to_return)

    if gated_bucket_to_return != previous_bucket:
        accessors.JumpStartModelsAccessor.reset_cache()
        # NOTE(review): the original indentation was lost; logging is assumed
        # to happen only when the bucket changed — confirm.
        for info_log in info_logs:
            constants.JUMPSTART_LOGGER.info(info_log)

    return gated_bucket_to_return
115+
116+
66117
def get_jumpstart_content_bucket(
67118
region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
68119
) -> str:
69120
"""Returns regionalized content bucket name for JumpStart.
70121
71122
Raises:
72-
RuntimeError: If JumpStart is not launched in ``region``.
123+
ValueError: If JumpStart is not launched in ``region``.
73124
"""
74125

75126
old_content_bucket: Optional[

0 commit comments

Comments
 (0)