Skip to content

Commit d9c8588

Browse files
committed
chore: integration test for gated jumpstart training model
1 parent 67d8faa commit d9c8588

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

tests/integ/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def _to_s3_path(filename: str, s3_prefix: Optional[str]) -> str:
4545
("huggingface-spc-bert-base-cased", "1.2.3"): ("training-datasets/QNLI-tiny/"),
4646
("huggingface-spc-bert-base-cased", "*"): ("training-datasets/QNLI-tiny/"),
4747
("js-trainable-model", "*"): ("training-datasets/QNLI-tiny/"),
48+
("meta-textgeneration-llama-2-7b", "*"): ("training-datasets/sec_amazon/"),
4849
}
4950

5051

tests/integ/sagemaker/jumpstart/estimator/test_jumpstart_estimator.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from __future__ import absolute_import
1414
import os
1515
import time
16+
17+
import pytest
1618
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
1719

1820
from sagemaker.jumpstart.estimator import JumpStartEstimator
@@ -63,6 +65,46 @@ def test_jumpstart_estimator(setup):
6365
assert response is not None
6466

6567

68+
# instance capacity errors require retries
@pytest.mark.flaky(reruns=5, reruns_delay=60)
def test_gated_model_training(setup):
    """Train, deploy and invoke a gated JumpStart model end to end.

    Fits a ``JumpStartEstimator`` for ``meta-textgeneration-llama-2-7b`` on the
    training dataset registered for that model id and version, deploys the
    fitted estimator, and checks that a prediction request returns a response.
    The EULA is accepted explicitly both for the training job (through the
    ``environment`` argument) and for the inference call (through
    ``custom_attributes``).
    """

    def _suite_tags():
        # Build a fresh list on every call so the training job and the
        # endpoint never share the same tag list object.
        return [{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}]

    model_id = "meta-textgeneration-llama-2-7b"
    model_version = "*"

    estimator_kwargs = dict(
        model_id=model_id,
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
        tags=_suite_tags(),
        environment={"accept_eula": "true"},
        max_run=259200,  # avoid exceeding resource limits
    )
    estimator = JumpStartEstimator(**estimator_kwargs)

    # Location of the training data inside the JumpStart content bucket.
    content_bucket = get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)
    dataset_prefix = get_training_dataset_for_model_and_version(model_id, model_version)
    training_channel = f"s3://{content_bucket}/{dataset_prefix}"

    # uses ml.g5.12xlarge instance
    estimator.fit({"training": training_channel})

    # uses ml.g5.2xlarge instance
    deploy_kwargs = dict(
        tags=_suite_tags(),
        role=get_sm_session().get_caller_identity_arn(),
        sagemaker_session=get_sm_session(),
    )
    predictor = estimator.deploy(**deploy_kwargs)

    generation_parameters = {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}
    payload = {"inputs": "some-payload", "parameters": generation_parameters}

    # The gated model also requires EULA acceptance at inference time.
    response = predictor.predict(payload, custom_attributes="accept_eula=true")

    assert response is not None
106+
107+
66108
def test_instatiating_estimator_not_too_slow(setup):
67109

68110
model_id = "xgboost-classification-model"

0 commit comments

Comments
 (0)