Skip to content

JumpStart Gated Model Support in ModelBuilder Local Modes #4567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Apr 11, 2024
21 changes: 21 additions & 0 deletions src/sagemaker/serve/builder/jumpstart_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,11 @@ def _build_for_jumpstart(self):

logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)

if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
raise ValueError(
"JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
)

if "djl-inference" in image_uri:
logger.info("Building for DJL JumpStart Model ID...")
self.model_server = ModelServer.DJL_SERVING
Expand All @@ -469,3 +474,19 @@ def _build_for_jumpstart(self):
)

return self.pysdk_model

def _is_gated_model(self, model) -> bool:
"""Determine if ``this`` Model is Gated

Args:
model (Model): Jumpstart Model
Returns:
bool: ``True`` if ``this`` Model is Gated
"""
s3_uri = model.model_data
if isinstance(s3_uri, dict):
s3_uri = s3_uri.get("S3DataSource").get("S3Uri")

if s3_uri is None:
return False
return "private" in s3_uri
8 changes: 5 additions & 3 deletions tests/integ/sagemaker/serve/test_serve_js_happy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
SAMPLE_RESPONSE = [
{"generated_text": "Hello, I'm a language model, and I'm here to help you with your English."}
]
JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16"
JS_GATED_MODEL_ID = "meta-textgeneration-llama-2-7b-f"
ROLE_NAME = "SageMakerRole"


@pytest.fixture
def happy_model_builder(sagemaker_session):
iam_client = sagemaker_session.boto_session.client("iam")
return ModelBuilder(
model=JS_MODEL_ID,
model=JS_GATED_MODEL_ID,
schema_builder=SchemaBuilder(SAMPLE_PROMPT, SAMPLE_RESPONSE),
role_arn=iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"],
sagemaker_session=sagemaker_session,
Expand All @@ -59,7 +59,9 @@ def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type):
with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
try:
logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
predictor = model.deploy(instance_type=gpu_instance_type, endpoint_logging=False)
predictor = model.deploy(
instance_type=gpu_instance_type, endpoint_logging=False, accept_eula=True
)
logger.info("Endpoint successfully deployed.")

updated_sample_input = happy_model_builder.schema_builder.sample_input
Expand Down
111 changes: 111 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_js_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@
"123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1"
)

# Artifact prefix of a gated (private-bucket) JumpStart model, used by the
# gated-model unit tests below.
mock_model_data_str = (
    "s3://jumpstart-private-cache-prod-us-west-2/huggingface-llm/huggingface-llm-zephyr-7b-gemma"
    "/artifacts/inference-prepack/v1.0.0/"
)
# The same location in the dict shape that Model.model_data may take.
mock_model_data = {
    "S3DataSource": {
        "S3Uri": mock_model_data_str,
        "S3DataType": "S3Prefix",
        "CompressionType": "None",
    }
}


class TestJumpStartBuilder(unittest.TestCase):
@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
Expand Down Expand Up @@ -527,3 +540,101 @@ def test_tune_for_djl_js_endpoint_mode_ex(

tuned_model = model.tune()
assert tuned_model == model

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
    return_value=True,
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
    return_value=MagicMock(),
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
    return_value=({"model_type": "t5", "n_head": 71}, True),
)
@patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
@patch(
    "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
)
def test_js_gated_model_in_endpoint_mode(
    self,
    # Mock arguments are injected bottom-up relative to the decorator stack.
    mock_get_nb_instance,
    mock_get_ram_usage_mb,
    mock_prepare_for_tgi,
    mock_pre_trained_model,
    mock_is_jumpstart_model,
    mock_telemetry,
):
    """A gated JumpStart model builds without error in SAGEMAKER_ENDPOINT mode."""
    builder = ModelBuilder(
        model="facebook/galactica-mock-model-id",
        schema_builder=mock_schema_builder,
        mode=Mode.SAGEMAKER_ENDPOINT,
    )

    # Dict-form model_data whose S3Uri contains "private" marks the model
    # as gated; SAGEMAKER_ENDPOINT is the one mode where that is allowed.
    mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
    mock_pre_trained_model.return_value.model_data = mock_model_data

    model = builder.build()

    assert model is not None

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
    return_value=True,
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
    return_value=MagicMock(),
)
def test_js_gated_model_in_local_mode(
    self,
    mock_pre_trained_model,
    mock_is_jumpstart_model,
    mock_telemetry,
):
    """Building a gated JumpStart model in LOCAL_CONTAINER mode raises ValueError."""
    gated_builder = ModelBuilder(
        model="huggingface-llm-zephyr-7b-gemma",
        schema_builder=mock_schema_builder,
        mode=Mode.LOCAL_CONTAINER,
    )

    # String-form model_data containing "private" marks the model as gated.
    mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
    mock_pre_trained_model.return_value.model_data = mock_model_data_str

    with self.assertRaisesRegex(
        ValueError,
        "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode.",
    ):
        gated_builder.build()

@patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
    return_value=True,
)
@patch(
    "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
    return_value=MagicMock(),
)
def test_js_gated_model_ex(
    self,
    mock_pre_trained_model,
    mock_is_jumpstart_model,
    mock_telemetry,
):
    """Building with model_data=None in LOCAL_CONTAINER mode raises ValueError."""
    local_builder = ModelBuilder(
        model="huggingface-llm-zephyr-7b-gemma",
        schema_builder=mock_schema_builder,
        mode=Mode.LOCAL_CONTAINER,
    )

    # No model_data at all: the builder still refuses local mode with ValueError.
    mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
    mock_pre_trained_model.return_value.model_data = None

    with self.assertRaises(ValueError):
        local_builder.build()