Skip to content

Commit 12fe775

Browse files
authored
fix: attach jumpstart estimator for gated model (#4546)
1 parent cc5d0dc commit 12fe775

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

src/sagemaker/jumpstart/estimator.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from sagemaker.jumpstart.utils import (
3838
validate_model_id_and_get_type,
3939
resolve_model_sagemaker_config_field,
40+
verify_model_region_and_return_specs,
4041
)
4142
from sagemaker.utils import stringify_object, format_tags, Tags
4243
from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
@@ -729,11 +730,27 @@ def attach(
729730

730731
model_version = model_version or "*"
731732

733+
additional_kwargs = {"model_id": model_id, "model_version": model_version}
734+
735+
model_specs = verify_model_region_and_return_specs(
736+
model_id=model_id,
737+
version=model_version,
738+
region=sagemaker_session.boto_region_name,
739+
scope=JumpStartScriptScope.TRAINING,
740+
tolerate_deprecated_model=True, # model is already trained, so tolerate if deprecated
741+
tolerate_vulnerable_model=True, # model is already trained, so tolerate if vulnerable
742+
sagemaker_session=sagemaker_session,
743+
)
744+
745+
# eula was already accepted if the model was successfully trained
746+
if model_specs.is_gated_model():
747+
additional_kwargs.update({"environment": {"accept_eula": "true"}})
748+
732749
return cls._attach(
733750
training_job_name=training_job_name,
734751
sagemaker_session=sagemaker_session,
735752
model_channel_name=model_channel_name,
736-
additional_kwargs={"model_id": model_id, "model_version": model_version},
753+
additional_kwargs=additional_kwargs,
737754
)
738755

739756
def deploy(

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,16 @@ def test_gated_model_training_v2(setup):
160160
}
161161
)
162162

163+
# test that we can create a JumpStartEstimator from existing job with `attach`
164+
attached_estimator = JumpStartEstimator.attach(
165+
training_job_name=estimator.latest_training_job.name,
166+
model_id=model_id,
167+
model_version=model_version,
168+
sagemaker_session=get_sm_session(),
169+
)
170+
163171
# uses ml.g5.2xlarge instance
164-
predictor = estimator.deploy(
172+
predictor = attached_estimator.deploy(
165173
tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
166174
role=get_sm_session().get_caller_identity_arn(),
167175
sagemaker_session=get_sm_session(),
@@ -172,7 +180,7 @@ def test_gated_model_training_v2(setup):
172180
"parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
173181
}
174182

175-
response = predictor.predict(payload, custom_attributes="accept_eula=true")
183+
response = predictor.predict(payload)
176184

177185
assert response is not None
178186

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,6 +979,40 @@ def test_jumpstart_estimator_tags(
979979
[{"Key": "blah", "Value": "blahagain"}] + js_tags,
980980
)
981981

982+
@mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach")
983+
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
984+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
985+
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
986+
def test_jumpstart_estimator_attach_eula_model(
987+
self,
988+
mock_get_model_specs: mock.Mock,
989+
mock_validate_model_id_and_get_type: mock.Mock,
990+
mock_attach: mock.Mock,
991+
):
992+
993+
mock_validate_model_id_and_get_type.return_value = JumpStartModelType.OPEN_WEIGHTS
994+
995+
mock_get_model_specs.side_effect = get_special_model_spec
996+
997+
mock_session = mock.MagicMock(sagemaker_config={}, boto_region_name="us-west-2")
998+
999+
JumpStartEstimator.attach(
1000+
training_job_name="some-training-job-name",
1001+
model_id="gemma-model",
1002+
sagemaker_session=mock_session,
1003+
)
1004+
1005+
mock_attach.assert_called_once_with(
1006+
training_job_name="some-training-job-name",
1007+
sagemaker_session=mock_session,
1008+
model_channel_name="model",
1009+
additional_kwargs={
1010+
"model_id": "gemma-model",
1011+
"model_version": "*",
1012+
"environment": {"accept_eula": "true"},
1013+
},
1014+
)
1015+
9821016
@mock.patch("sagemaker.jumpstart.estimator.JumpStartEstimator._attach")
9831017
@mock.patch("sagemaker.jumpstart.estimator.get_model_id_version_from_training_job")
9841018
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")

0 commit comments

Comments
 (0)