|
13 | 13 | from __future__ import absolute_import
|
14 | 14 | import os
|
15 | 15 | import time
|
| 16 | + |
| 17 | +import pytest |
16 | 18 | from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
|
17 | 19 |
|
18 | 20 | from sagemaker.jumpstart.estimator import JumpStartEstimator
|
@@ -63,6 +65,46 @@ def test_jumpstart_estimator(setup):
|
63 | 65 | assert response is not None
|
64 | 66 |
|
65 | 67 |
|
| 68 | +# instance capacity errors require retries |
| 69 | +@pytest.mark.flaky(reruns=5, reruns_delay=60) |
| 70 | +def test_gated_model_training(setup): |
| 71 | + |
| 72 | + model_id, model_version = "meta-textgeneration-llama-2-7b", "*" |
| 73 | + |
| 74 | + estimator = JumpStartEstimator( |
| 75 | + model_id=model_id, |
| 76 | + role=get_sm_session().get_caller_identity_arn(), |
| 77 | + sagemaker_session=get_sm_session(), |
| 78 | + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], |
| 79 | + environment={"accept_eula": "true"}, |
| 80 | + max_run=259200, # avoid exceeding resource limits |
| 81 | + ) |
| 82 | + |
| 83 | + # uses ml.g5.12xlarge instance |
| 84 | + estimator.fit( |
| 85 | + { |
| 86 | + "training": f"s3://{get_jumpstart_content_bucket(JUMPSTART_DEFAULT_REGION_NAME)}/" |
| 87 | + f"{get_training_dataset_for_model_and_version(model_id, model_version)}", |
| 88 | + } |
| 89 | + ) |
| 90 | + |
| 91 | + # uses ml.g5.2xlarge instance |
| 92 | + predictor = estimator.deploy( |
| 93 | + tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}], |
| 94 | + role=get_sm_session().get_caller_identity_arn(), |
| 95 | + sagemaker_session=get_sm_session(), |
| 96 | + ) |
| 97 | + |
| 98 | + payload = { |
| 99 | + "inputs": "some-payload", |
| 100 | + "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6}, |
| 101 | + } |
| 102 | + |
| 103 | + response = predictor.predict(payload, custom_attributes="accept_eula=true") |
| 104 | + |
| 105 | + assert response is not None |
| 106 | + |
| 107 | + |
66 | 108 | def test_instatiating_estimator_not_too_slow(setup):
|
67 | 109 |
|
68 | 110 | model_id = "xgboost-classification-model"
|
|
0 commit comments