Commit 9369e99

chore: add jumpstart llama 2 tests
1 parent 1d886c4 commit 9369e99

File tree

3 files changed: +166 -0 lines

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 35 additions & 0 deletions
@@ -14,6 +14,10 @@
 import os
 import time

+import pytest
+
+import tests.integ
+
 from sagemaker.jumpstart.model import JumpStartModel
 from tests.integ.sagemaker.jumpstart.constants import (
     ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID,
@@ -29,6 +33,8 @@

 MAX_INIT_TIME_SECONDS = 5

+MODEL_PACKAGE_ARN_SUPPORTED_REGIONS = {"us-west-2", "us-east-1"}
+

 def test_non_prepacked_jumpstart_model(setup):

@@ -73,6 +79,35 @@ def test_prepacked_jumpstart_model(setup):
     assert response is not None


+@pytest.mark.skipif(
+    tests.integ.test_region() not in MODEL_PACKAGE_ARN_SUPPORTED_REGIONS,
+    reason=f"JumpStart Model Package models unavailable in {tests.integ.test_region()}.",
+)
+def test_model_package_arn_jumpstart_model(setup):
+
+    model_id = "meta-textgeneration-llama-2-7b"
+
+    model = JumpStartModel(
+        model_id=model_id,
+        role=get_sm_session().get_caller_identity_arn(),
+        sagemaker_session=get_sm_session(),
+    )
+
+    # uses ml.g5.2xlarge instance
+    predictor = model.deploy(
+        tags=[{"Key": JUMPSTART_TAG, "Value": os.environ[ENV_VAR_JUMPSTART_SDK_TEST_SUITE_ID]}],
+    )
+
+    payload = {
+        "inputs": "some-payload",
+        "parameters": {"max_new_tokens": 256, "top_p": 0.9, "temperature": 0.6},
+    }
+
+    response = predictor.predict(payload, custom_attributes="accept_eula=true")
+
+    assert response is not None
+
+
 def test_instatiating_model_not_too_slow(setup):

     model_id = "catboost-regression-model"

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 57 additions & 0 deletions
@@ -14,6 +14,63 @@


 SPECIAL_MODEL_SPECS_DICT = {
+    "js-model-package-arn": {
+        "model_id": "meta-textgeneration-llama-2-7b-f",
+        "url": "https://ai.meta.com/resources/models-and-libraries/llama-downloads/",
+        "version": "1.0.0",
+        "min_sdk_version": "2.173.0",
+        "training_supported": False,
+        "incremental_training_supported": False,
+        "hosting_ecr_specs": {
+            "framework": "pytorch",
+            "framework_version": "1.12.0",
+            "py_version": "py38",
+        },
+        "hosting_artifact_key": "meta-infer/infer-meta-textgeneration-llama-2-7b-f.tar.gz",
+        "hosting_script_key": "source-directory-tarballs/meta/inference/textgeneration/v1.0.0/sourcedir.tar.gz",
+        "hosting_eula_key": "fmhMetadata/eula/llamaEula.txt",
+        "hosting_model_package_arns": {
+            "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/"
+            "llama2-7b-f-e46eb8a833643ed58aaccd81498972c3",
+            "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/"
+            "llama2-7b-f-e46eb8a833643ed58aaccd81498972c3",
+        },
+        "inference_vulnerable": False,
+        "inference_dependencies": [],
+        "inference_vulnerabilities": [],
+        "training_vulnerable": False,
+        "training_dependencies": [],
+        "training_vulnerabilities": [],
+        "deprecated": False,
+        "inference_environment_variables": [],
+        "metrics": [],
+        "default_inference_instance_type": "ml.g5.2xlarge",
+        "supported_inference_instance_types": [
+            "ml.g5.2xlarge",
+            "ml.g5.4xlarge",
+            "ml.g5.8xlarge",
+            "ml.g5.12xlarge",
+            "ml.g5.24xlarge",
+            "ml.g5.48xlarge",
+            "ml.p4d.24xlarge",
+        ],
+        "model_kwargs": {},
+        "deploy_kwargs": {
+            "model_data_download_timeout": 3600,
+            "container_startup_health_check_timeout": 3600,
+        },
+        "predictor_specs": {
+            "supported_content_types": ["application/json"],
+            "supported_accept_types": ["application/json"],
+            "default_content_type": "application/json",
+            "default_accept_type": "application/json",
+        },
+        "inference_volume_size": 256,
+        "inference_enable_network_isolation": True,
+        "validation_supported": False,
+        "fine_tuning_supported": False,
+        "resource_name_base": "meta-textgeneration-llama-2-7b-f",
+    },
     "js-trainable-model-prepacked": {
         "model_id": "huggingface-text2text-flan-t5-base",
         "url": "https://huggingface.co/google/flan-t5-base",

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 74 additions & 0 deletions
@@ -15,6 +15,7 @@
 from typing import Optional, Set
 from unittest import mock
 import unittest
+from mock import MagicMock
 import pytest
 from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
 from sagemaker.jumpstart.enums import JumpStartScriptScope
@@ -567,6 +568,79 @@ def test_model_id_not_found_refeshes_cach_inference(
             ]
         )

+    @mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
+    @mock.patch("sagemaker.jumpstart.factory.model.Session")
+    @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+    @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+    def test_jumpstart_model_package_arn(
+        self,
+        mock_get_model_specs: mock.Mock,
+        mock_session: mock.Mock,
+        mock_is_valid_model_id: mock.Mock,
+    ):
+
+        mock_is_valid_model_id.return_value = True
+
+        model_id, _ = "js-model-package-arn", "*"
+
+        mock_get_model_specs.side_effect = get_special_model_spec
+
+        mock_session.return_value = MagicMock(sagemaker_config={})
+
+        model = JumpStartModel(model_id=model_id)
+
+        model.deploy()
+
+        self.assertEqual(
+            mock_session.return_value.create_model.call_args[0][2],
+            {
+                "ModelPackageName": "arn:aws:sagemaker:us-west-2:594846645681:model-package"
+                "/llama2-7b-f-e46eb8a833643ed58aaccd81498972c3"
+            },
+        )
+
+    @mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
+    @mock.patch("sagemaker.jumpstart.factory.model.Session")
+    @mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
+    @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
+    def test_jumpstart_model_package_arn_override(
+        self,
+        mock_get_model_specs: mock.Mock,
+        mock_session: mock.Mock,
+        mock_is_valid_model_id: mock.Mock,
+    ):
+
+        mock_is_valid_model_id.return_value = True
+
+        # arbitrary model without a model package arn
+        model_id, _ = "js-trainable-model", "*"
+
+        mock_get_model_specs.side_effect = get_special_model_spec
+
+        mock_session.return_value = MagicMock(sagemaker_config={})
+
+        model_package_arn = (
+            "arn:aws:sagemaker:us-west-2:867530986753:model-package/"
+            "llama2-ynnej-f-e46eb8a833643ed58aaccd81498972c3"
+        )
+        model = JumpStartModel(model_id=model_id, model_package_arn=model_package_arn)
+
+        model.deploy()
+
+        self.assertEqual(
+            mock_session.return_value.create_model.call_args[0][2],
+            {
+                "ModelPackageName": model_package_arn,
+                "Environment": {
+                    "ENDPOINT_SERVER_TIMEOUT": "3600",
+                    "MODEL_CACHE_ROOT": "/opt/ml/model",
+                    "SAGEMAKER_ENV": "1",
+                    "SAGEMAKER_MODEL_SERVER_WORKERS": "1",
+                    "SAGEMAKER_PROGRAM": "inference.py",
+                },
+            },
+        )
+

 def test_jumpstart_model_requires_model_id():
     with pytest.raises(ValueError):
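
Both unit tests assert against the third positional argument of the mocked session's create_model call. A minimal sketch of that pattern, assuming Session.create_model takes (name, role, container_defs, ...) positionally; every name and ARN below is made up for illustration:

from unittest import mock

session = mock.MagicMock()

# Stand-in for what JumpStartModel.deploy() triggers internally.
session.create_model(
    "hypothetical-model-name",
    "arn:aws:iam::123456789012:role/hypothetical-role",
    {"ModelPackageName": "arn:aws:sagemaker:us-west-2:123456789012:model-package/example"},
)

# call_args[0] holds the positional args, so index 2 is the container
# definition, where the ModelPackageName surfaces.
args, _ = session.create_model.call_args
assert args[2]["ModelPackageName"].endswith("model-package/example")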
