Commit 718777b

Author: Jonathan Makunga
Commit message: Add tests
1 parent 8ea3727 commit 718777b

File tree: 4 files changed (+130 -17 lines)


src/sagemaker/serve/builder/jumpstart_builder.py
Lines changed: 5 additions & 8 deletions

@@ -31,7 +31,6 @@
     LocalModelInvocationException,
     LocalModelLoadException,
     SkipTuningComboException,
-    JumpStartGatedModelNotSupported,
 )
 from sagemaker.serve.utils.predictors import (
     DjlLocalModePredictor,
@@ -444,9 +443,9 @@ def _build_for_jumpstart(self):

         logger.info("JumpStart ID %s is packaged with Image URI: %s", self.model, image_uri)

-        if self._is_gated_model() and self.mode != Mode.SAGEMAKER_ENDPOINT:
-            raise JumpStartGatedModelNotSupported(
-                "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode"
+        if self._is_gated_model(pysdk_model) and self.mode != Mode.SAGEMAKER_ENDPOINT:
+            raise ValueError(
+                "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode."
             )

         if "djl-inference" in image_uri:
@@ -476,11 +475,9 @@ def _build_for_jumpstart(self):

         return self.pysdk_model

-    def _is_gated_model(self) -> bool:
+    def _is_gated_model(self, model) -> bool:
         """Determine if ``this`` Model is Gated"""
-
-        s3_uri = self.pysdk_model.model_data
+        s3_uri = model.model_data
         if isinstance(s3_uri, dict):
             s3_uri = s3_uri.get("S3DataSource").get("S3Uri")
-
         return "private" in s3_uri

src/sagemaker/serve/utils/exceptions.py
Lines changed: 0 additions & 9 deletions

@@ -69,12 +69,3 @@ class TaskNotFoundException(ModelBuilderException):

     def __init__(self, message):
         super().__init__(message=message)
-
-
-class JumpStartGatedModelNotSupported(ModelBuilderException):
-    """Raise when deploying JumpStart gated model locally"""
-
-    fmt = "Error Message: {message}"
-
-    def __init__(self, message):
-        super().__init__(message=message)

tests/integ/sagemaker/serve/test_serve_js_happy.py
Lines changed: 43 additions & 0 deletions

@@ -14,6 +14,7 @@

 import pytest

+from sagemaker.serve import Mode
 from sagemaker.serve.builder.model_builder import ModelBuilder
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
 from tests.integ.sagemaker.serve.constants import (
@@ -32,6 +33,7 @@
     {"generated_text": "Hello, I'm a language model, and I'm here to help you with your English."}
 ]
 JS_MODEL_ID = "huggingface-textgeneration1-gpt-neo-125m-fp16"
+JS_Gated_MODEL_ID = "huggingface-llm-zephyr-7b-gemma"
 ROLE_NAME = "SageMakerRole"


@@ -46,6 +48,17 @@ def happy_model_builder(sagemaker_session):
     )


+@pytest.fixture
+def happy_model_builder_gated_model(sagemaker_session):
+    iam_client = sagemaker_session.boto_session.client("iam")
+    return ModelBuilder(
+        model=JS_Gated_MODEL_ID,
+        schema_builder=SchemaBuilder(SAMPLE_PROMPT, SAMPLE_RESPONSE),
+        role_arn=iam_client.get_role(RoleName=ROLE_NAME)["Role"]["Arn"],
+        sagemaker_session=sagemaker_session,
+    )
+
+
 @pytest.mark.skipif(
     PYTHON_VERSION_IS_NOT_310,
     reason="The goal of these test are to test the serving components of our feature",
@@ -75,3 +88,33 @@ def test_happy_tgi_sagemaker_endpoint(happy_model_builder, gpu_instance_type):
         )
     if caught_ex:
         raise caught_ex
+
+
+@pytest.mark.skipif(
+    PYTHON_VERSION_IS_NOT_310,
+    reason="The goal of these test are to test the serving components of our feature",
+)
+@pytest.mark.slow_test
+def test_happy_js_gated_model(happy_model_builder_gated_model, gpu_instance_type):
+    logger.info("Running in SAGEMAKER_ENDPOINT mode...")
+    happy_model_builder_gated_model.build()
+
+
+@pytest.mark.skipif(
+    PYTHON_VERSION_IS_NOT_310,
+    reason="The goal of these test are to test the serving components of our feature",
+)
+@pytest.mark.slow_test
+def test_js_gated_model_throws(happy_model_builder_gated_model, gpu_instance_type):
+    logger.info("Running in Local mode...")
+    model_builder = ModelBuilder(
+        model=JS_Gated_MODEL_ID,
+        schema_builder=SchemaBuilder(SAMPLE_PROMPT, SAMPLE_RESPONSE),
+        mode=Mode.LOCAL_CONTAINER,
+    )
+
+    with pytest.raises(
+        ValueError,
+        match="JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode.",
+    ):
+        model_builder.build()
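
The user-facing contract these integration tests pin down is that building the gated Zephyr model in the default SAGEMAKER_ENDPOINT mode succeeds, while LOCAL_CONTAINER mode raises ValueError at build time. A hedged sketch of how calling code might handle that error (the prompt/response samples here are illustrative, and an endpoint build would additionally need role_arn and sagemaker_session, as in the fixture above):

from sagemaker.serve import Mode
from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder

# Illustrative sample input/output; the tests use SAMPLE_PROMPT / SAMPLE_RESPONSE constants.
schema = SchemaBuilder("Hello, I'm a language model,", {"generated_text": "Hello back."})

local_builder = ModelBuilder(
    model="huggingface-llm-zephyr-7b-gemma",  # gated JumpStart model ID
    schema_builder=schema,
    mode=Mode.LOCAL_CONTAINER,
)

try:
    local_builder.build()
except ValueError as err:
    # Raised for gated models outside SAGEMAKER_ENDPOINT mode; fall back to an
    # endpoint build (default mode) with role_arn/sagemaker_session supplied.
    print(err)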

tests/unit/sagemaker/serve/builder/test_js_builder.py
Lines changed: 82 additions & 0 deletions

@@ -67,6 +67,19 @@
     "123456789712.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.24.0-neuronx-sdk2.14.1"
 )

+mock_model_data = {
+    "S3DataSource": {
+        "S3Uri": "s3://jumpstart-private-cache-prod-us-west-2/huggingface-llm/huggingface-llm-zephyr-7b-gemma"
+        "/artifacts/inference-prepack/v1.0.0/",
+        "S3DataType": "S3Prefix",
+        "CompressionType": "None",
+    }
+}
+mock_model_data_str = (
+    "s3://jumpstart-private-cache-prod-us-west-2/huggingface-llm/huggingface-llm-zephyr-7b-gemma"
+    "/artifacts/inference-prepack/v1.0.0/"
+)
+

 class TestJumpStartBuilder(unittest.TestCase):
     @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
@@ -527,3 +540,72 @@ def test_tune_for_djl_js_endpoint_mode_ex(

         tuned_model = model.tune()
         assert tuned_model == model
+
+    @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+        return_value=True,
+    )
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+        return_value=MagicMock(),
+    )
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.prepare_tgi_js_resources",
+        return_value=({"model_type": "t5", "n_head": 71}, True),
+    )
+    @patch("sagemaker.serve.builder.jumpstart_builder._get_ram_usage_mb", return_value=1024)
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder._get_nb_instance", return_value="ml.g5.24xlarge"
+    )
+    def test_js_gated_model_in_endpoint_mode(
+        self,
+        mock_get_nb_instance,
+        mock_get_ram_usage_mb,
+        mock_prepare_for_tgi,
+        mock_pre_trained_model,
+        mock_is_jumpstart_model,
+        mock_telemetry,
+    ):
+        builder = ModelBuilder(
+            model="facebook/galactica-mock-model-id",
+            schema_builder=mock_schema_builder,
+            mode=Mode.SAGEMAKER_ENDPOINT,
+        )
+
+        mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
+        mock_pre_trained_model.return_value.model_data = mock_model_data
+
+        model = builder.build()
+
+        assert model is not None
+
+    @patch("sagemaker.serve.builder.jumpstart_builder._capture_telemetry", side_effect=None)
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._is_jumpstart_model_id",
+        return_value=True,
+    )
+    @patch(
+        "sagemaker.serve.builder.jumpstart_builder.JumpStart._create_pre_trained_js_model",
+        return_value=MagicMock(),
+    )
+    def test_js_gated_model_in_local_mode(
+        self,
+        mock_pre_trained_model,
+        mock_is_jumpstart_model,
+        mock_telemetry,
+    ):
+        builder = ModelBuilder(
+            model="huggingface-llm-zephyr-7b-gemma",
+            schema_builder=mock_schema_builder,
+            mode=Mode.LOCAL_CONTAINER,
+        )
+
+        mock_pre_trained_model.return_value.image_uri = mock_tgi_image_uri
+        mock_pre_trained_model.return_value.model_data = mock_model_data_str
+
+        self.assertRaisesRegex(
+            ValueError,
+            "JumpStart Gated Models are only supported in SAGEMAKER_ENDPOINT mode.",
+            lambda: builder.build(),
+        )
