Skip to content

feat: jumpstart gated model artifacts #4215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class JumpStartModelsAccessor(object):
_curr_region = JUMPSTART_DEFAULT_REGION_NAME

_content_bucket: Optional[str] = None
_gated_content_bucket: Optional[str] = None

_cache_kwargs: Dict[str, Any] = {}

Expand All @@ -140,6 +141,16 @@ def get_jumpstart_content_bucket() -> Optional[str]:
"""Returns JumpStart content bucket."""
return JumpStartModelsAccessor._content_bucket

@staticmethod
def set_jumpstart_gated_content_bucket(gated_content_bucket: str) -> None:
    """Records the JumpStart gated content bucket on the accessor class.

    Args:
        gated_content_bucket (str): Bucket name to store in the class-level
            ``_gated_content_bucket`` attribute.
    """
    # The value lives on the class itself, so it is shared by every caller.
    JumpStartModelsAccessor._gated_content_bucket = gated_content_bucket

@staticmethod
def get_jumpstart_gated_content_bucket() -> Optional[str]:
    """Reads back the JumpStart gated content bucket held by the accessor class.

    Returns:
        Optional[str]: The current value of the class-level
            ``_gated_content_bucket`` attribute, which defaults to ``None``.
    """
    current_gated_bucket = JumpStartModelsAccessor._gated_content_bucket
    return current_gated_bucket

@staticmethod
def _validate_and_mutate_region_cache_kwargs(
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
Expand Down
14 changes: 11 additions & 3 deletions src/sagemaker/jumpstart/artifacts/model_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
get_jumpstart_gated_content_bucket,
verify_model_region_and_return_specs,
)
from sagemaker.session import Session
Expand Down Expand Up @@ -157,9 +158,16 @@ def _retrieve_model_uri(

model_artifact_key = _retrieve_training_artifact_key(model_specs, instance_type)

bucket = os.environ.get(
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE
) or get_jumpstart_content_bucket(region)
default_jumpstart_bucket: str = (
get_jumpstart_gated_content_bucket(region)
if model_specs.gated_bucket
else get_jumpstart_content_bucket(region)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not add support for env var override here?


bucket = (
os.environ.get(ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE)
or default_jumpstart_bucket
)

model_s3_uri = f"s3://{bucket}/{model_artifact_key}"

Expand Down
25 changes: 25 additions & 0 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,82 +38,102 @@
JumpStartLaunchedRegionInfo(
region_name="us-west-2",
content_bucket="jumpstart-cache-prod-us-west-2",
gated_content_bucket="jumpstart-private-cache-prod-us-west-2",
),
JumpStartLaunchedRegionInfo(
region_name="us-east-1",
content_bucket="jumpstart-cache-prod-us-east-1",
gated_content_bucket="jumpstart-private-cache-prod-us-east-1",
),
JumpStartLaunchedRegionInfo(
region_name="us-east-2",
content_bucket="jumpstart-cache-prod-us-east-2",
gated_content_bucket="jumpstart-private-cache-prod-us-east-2",
),
JumpStartLaunchedRegionInfo(
region_name="eu-west-1",
content_bucket="jumpstart-cache-prod-eu-west-1",
gated_content_bucket="jumpstart-private-cache-prod-eu-west-1",
),
JumpStartLaunchedRegionInfo(
region_name="eu-central-1",
content_bucket="jumpstart-cache-prod-eu-central-1",
gated_content_bucket="jumpstart-private-cache-prod-eu-central-1",
),
JumpStartLaunchedRegionInfo(
region_name="eu-north-1",
content_bucket="jumpstart-cache-prod-eu-north-1",
gated_content_bucket="jumpstart-private-cache-prod-eu-north-1",
),
JumpStartLaunchedRegionInfo(
region_name="me-south-1",
content_bucket="jumpstart-cache-prod-me-south-1",
gated_content_bucket="jumpstart-private-cache-prod-me-south-1",
),
JumpStartLaunchedRegionInfo(
region_name="ap-south-1",
content_bucket="jumpstart-cache-prod-ap-south-1",
gated_content_bucket="jumpstart-private-cache-prod-ap-south-1",
),
JumpStartLaunchedRegionInfo(
region_name="eu-west-3",
content_bucket="jumpstart-cache-prod-eu-west-3",
gated_content_bucket="jumpstart-private-cache-prod-eu-west-3",
),
JumpStartLaunchedRegionInfo(
region_name="af-south-1",
content_bucket="jumpstart-cache-prod-af-south-1",
gated_content_bucket="jumpstart-private-cache-prod-af-south-1",
),
JumpStartLaunchedRegionInfo(
region_name="sa-east-1",
content_bucket="jumpstart-cache-prod-sa-east-1",
gated_content_bucket="jumpstart-private-cache-prod-sa-east-1",
),
JumpStartLaunchedRegionInfo(
region_name="ap-east-1",
content_bucket="jumpstart-cache-prod-ap-east-1",
gated_content_bucket="jumpstart-private-cache-prod-ap-east-1",
),
JumpStartLaunchedRegionInfo(
region_name="ap-northeast-2",
content_bucket="jumpstart-cache-prod-ap-northeast-2",
gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-2",
),
JumpStartLaunchedRegionInfo(
region_name="eu-west-2",
content_bucket="jumpstart-cache-prod-eu-west-2",
gated_content_bucket="jumpstart-private-cache-prod-eu-west-2",
),
JumpStartLaunchedRegionInfo(
region_name="eu-south-1",
content_bucket="jumpstart-cache-prod-eu-south-1",
gated_content_bucket="jumpstart-private-cache-prod-eu-south-1",
),
JumpStartLaunchedRegionInfo(
region_name="ap-northeast-1",
content_bucket="jumpstart-cache-prod-ap-northeast-1",
gated_content_bucket="jumpstart-private-cache-prod-ap-northeast-1",
),
JumpStartLaunchedRegionInfo(
region_name="us-west-1",
content_bucket="jumpstart-cache-prod-us-west-1",
gated_content_bucket="jumpstart-private-cache-prod-us-west-1",
),
JumpStartLaunchedRegionInfo(
region_name="ap-southeast-1",
content_bucket="jumpstart-cache-prod-ap-southeast-1",
gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-1",
),
JumpStartLaunchedRegionInfo(
region_name="ap-southeast-2",
content_bucket="jumpstart-cache-prod-ap-southeast-2",
gated_content_bucket="jumpstart-private-cache-prod-ap-southeast-2",
),
JumpStartLaunchedRegionInfo(
region_name="ca-central-1",
content_bucket="jumpstart-cache-prod-ca-central-1",
gated_content_bucket="jumpstart-private-cache-prod-ca-central-1",
),
JumpStartLaunchedRegionInfo(
region_name="cn-north-1",
Expand All @@ -128,6 +148,11 @@
JUMPSTART_REGION_NAME_SET = {region.region_name for region in JUMPSTART_LAUNCHED_REGIONS}

JUMPSTART_BUCKET_NAME_SET = {region.content_bucket for region in JUMPSTART_LAUNCHED_REGIONS}
# Names of every gated content bucket declared for a launched region; regions
# whose ``gated_content_bucket`` is ``None`` contribute nothing to the set.
JUMPSTART_GATED_BUCKET_NAME_SET = {
    gated_bucket
    for gated_bucket in (
        launched_region.gated_content_bucket for launched_region in JUMPSTART_LAUNCHED_REGIONS
    )
    if gated_bucket is not None
}

JUMPSTART_DEFAULT_REGION_NAME = boto3.session.Session().region_name or "us-west-2"

Expand Down
11 changes: 9 additions & 2 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,21 @@ class JumpStartS3FileType(str, Enum):
class JumpStartLaunchedRegionInfo(JumpStartDataHolderType):
"""Data class for launched region info."""

__slots__ = ["content_bucket", "region_name"]
__slots__ = ["content_bucket", "region_name", "gated_content_bucket"]

def __init__(self, content_bucket: str, region_name: str):
def __init__(
self, content_bucket: str, region_name: str, gated_content_bucket: Optional[str] = None
):
"""Instantiates JumpStartLaunchedRegionInfo object.

Args:
content_bucket (str): Name of JumpStart s3 content bucket associated with region.
region_name (str): Name of JumpStart launched region.
gated_content_bucket (Optional[str]): Name of JumpStart gated s3 content bucket
optionally associated with region.
"""
self.content_bucket = content_bucket
self.gated_content_bucket = gated_content_bucket
self.region_name = region_name


Expand Down Expand Up @@ -691,6 +696,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
"hosting_instance_type_variants",
"training_instance_type_variants",
"default_payloads",
"gated_bucket",
]

def __init__(self, spec: Dict[str, Any]):
Expand Down Expand Up @@ -767,6 +773,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
if json_obj.get("default_payloads")
else None
)
self.gated_bucket = json_obj.get("gated_bucket", False)
self.inference_volume_size: Optional[int] = json_obj.get("inference_volume_size")
self.inference_enable_network_isolation: bool = json_obj.get(
"inference_enable_network_isolation", False
Expand Down
53 changes: 52 additions & 1 deletion src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,64 @@ def get_jumpstart_launched_regions_message() -> str:
return f"JumpStart is available in {formatted_launched_regions_str} regions."


def get_jumpstart_gated_content_bucket(
region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
) -> str:
"""Returns regionalized private content bucket name for JumpStart.

Raises:
ValueError: If JumpStart is not launched in ``region`` or private content
is unavailable in that region.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: missing 'is'; [...] or private content is unavailable [...]

"""

old_gated_content_bucket: Optional[
str
] = accessors.JumpStartModelsAccessor.get_jumpstart_gated_content_bucket()

info_logs: List[str] = []

gated_bucket_to_return: Optional[str] = None
if (
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE in os.environ
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not if os.environ.get(constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE):

and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
):
gated_bucket_to_return = os.environ[
constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE
]
info_logs.append(f"Using JumpStart private bucket override: '{gated_bucket_to_return}'")
else:
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be cleaner to use os.environ.get and then throw a value error if the result is None

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the interest of consistency with the other method for public buckets, I'm going to keep as is. Later, I agree it can be refactored

gated_bucket_to_return = constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[
region
].gated_content_bucket
if gated_bucket_to_return is None:
raise ValueError(
f"No private content bucket for JumpStart exists in {region} region."
)
except KeyError:
formatted_launched_regions_str = get_jumpstart_launched_regions_message()
raise ValueError(
f"Unable to get private content bucket for JumpStart in {region} region. "
f"{formatted_launched_regions_str}"
)

accessors.JumpStartModelsAccessor.set_jumpstart_gated_content_bucket(gated_bucket_to_return)

if gated_bucket_to_return != old_gated_content_bucket:
accessors.JumpStartModelsAccessor.reset_cache()
for info_log in info_logs:
constants.JUMPSTART_LOGGER.info(info_log)

return gated_bucket_to_return


def get_jumpstart_content_bucket(
region: str = constants.JUMPSTART_DEFAULT_REGION_NAME,
) -> str:
"""Returns regionalized content bucket name for JumpStart.

Raises:
RuntimeError: If JumpStart is not launched in ``region``.
ValueError: If JumpStart is not launched in ``region``.
"""

old_content_bucket: Optional[
Expand Down
Loading