Skip to content

Commit 8ea3727

Browse files
author
Jonathan Makunga
committed
Restrict JS Gated nodels only in SM Endpoint mode
1 parent 48234ae commit 8ea3727

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
LocalModelInvocationException,
3232
LocalModelLoadException,
3333
SkipTuningComboException,
34+
JumpStartGatedModelNotSupported,
3435
)
3536
from sagemaker.serve.utils.predictors import (
3637
DjlLocalModePredictor,
@@ -443,6 +444,11 @@ def _build_for_jumpstart(self):
443444

444445
logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)
445446

447+
if self._is_gated_model() and self.mode != Mode.SAGEMAKER_ENDPOINT:
448+
raise JumpStartGatedModelNotSupported(
449+
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode"
450+
)
451+
446452
if "djl-inference" in image_uri:
447453
logger.info("Building for DJL JumpStart Model ID...")
448454
self.model_server = ModelServer.DJL_SERVING
@@ -469,3 +475,12 @@ def _build_for_jumpstart(self):
469475
)
470476

471477
return self.pysdk_model
478+
479+
def _is_gated_model(self) -> bool:
480+
"""Determine if ``this`` Model is Gated"""
481+
482+
s3_uri = self.pysdk_model.model_data
483+
if isinstance(s3_uri, dict):
484+
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
485+
486+
return "private" in s3_uri

src/sagemaker/serve/utils/exceptions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,12 @@ class TaskNotFoundException(ModelBuilderException):
6969

7070
def __init__(self, message):
7171
super().__init__(message=message)
72+
73+
74+
class JumpStartGatedModelNotSupported(ModelBuilderException):
75+
"""Raise when deploying JumpStart gated model locally"""
76+
77+
fmt = "Error Message: {message}"
78+
79+
def __init__(self, message):
80+
super().__init__(message=message)

0 commit comments

Comments
 (0)