Skip to content

Commit 29243ae

Browse files
committed
fix: jumpstart unit tests
1 parent 38f8634 commit 29243ae

File tree

3 files changed

+103
-84
lines changed

3 files changed

+103
-84
lines changed

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 86 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,91 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
SPECIAL_MODEL_SPECS_DICT = {
16+
"huggingface-text2text-flan-t5-xxl-fp16": {
17+
"model_id": "huggingface-text2text-flan-t5-xxl-fp16",
18+
"url": "https://huggingface.co/google/flan-t5-xxl",
19+
"version": "1.0.0",
20+
"min_sdk_version": "2.130.0",
21+
"training_supported": False,
22+
"incremental_training_supported": False,
23+
"hosting_ecr_specs": {
24+
"framework": "pytorch",
25+
"framework_version": "1.12.0",
26+
"py_version": "py38",
27+
"huggingface_transformers_version": "4.17.0",
28+
},
29+
"hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz",
30+
"hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.2/sourcedir.tar.gz",
31+
"hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-"
32+
"text2text-flan-t5-xxl-fp16.tar.gz",
33+
"hosting_prepacked_artifact_version": "1.0.0",
34+
"inference_vulnerable": False,
35+
"inference_dependencies": [
36+
"accelerate==0.16.0",
37+
"bitsandbytes==0.37.0",
38+
"filelock==3.9.0",
39+
"huggingface-hub==0.12.0",
40+
"regex==2022.7.9",
41+
"tokenizers==0.13.2",
42+
"transformers==4.26.0",
43+
],
44+
"inference_vulnerabilities": [],
45+
"training_vulnerable": False,
46+
"training_dependencies": [],
47+
"training_vulnerabilities": [],
48+
"deprecated": False,
49+
"inference_environment_variables": [
50+
{
51+
"name": "SAGEMAKER_PROGRAM",
52+
"type": "text",
53+
"default": "inference.py",
54+
"scope": "container",
55+
},
56+
{
57+
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
58+
"type": "text",
59+
"default": "/opt/ml/model/code",
60+
"scope": "container",
61+
},
62+
{
63+
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
64+
"type": "text",
65+
"default": "20",
66+
"scope": "container",
67+
},
68+
{
69+
"name": "MODEL_CACHE_ROOT",
70+
"type": "text",
71+
"default": "/opt/ml/model",
72+
"scope": "container",
73+
},
74+
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
75+
{
76+
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
77+
"type": "text",
78+
"default": "1",
79+
"scope": "container",
80+
},
81+
{
82+
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
83+
"type": "text",
84+
"default": "3600",
85+
"scope": "container",
86+
},
87+
],
88+
"metrics": [],
89+
"default_inference_instance_type": "ml.g5.12xlarge",
90+
"supported_inference_instance_types": [
91+
"ml.g5.12xlarge",
92+
"ml.g5.24xlarge",
93+
"ml.p3.8xlarge",
94+
"ml.p3.16xlarge",
95+
"ml.g4dn.12xlarge",
96+
],
97+
}
98+
}
99+
15100
PROTOTYPICAL_MODEL_SPECS_DICT = {
16101
"pytorch-eqa-bert-base-cased": {
17102
"model_id": "pytorch-eqa-bert-base-cased",
@@ -1070,88 +1155,6 @@
10701155
},
10711156
],
10721157
},
1073-
"huggingface-text2text-flan-t5-xxl-fp16": {
1074-
"model_id": "huggingface-text2text-flan-t5-xxl-fp16",
1075-
"url": "https://huggingface.co/google/flan-t5-xxl",
1076-
"version": "1.0.0",
1077-
"min_sdk_version": "2.130.0",
1078-
"training_supported": False,
1079-
"incremental_training_supported": False,
1080-
"hosting_ecr_specs": {
1081-
"framework": "pytorch",
1082-
"framework_version": "1.12.0",
1083-
"py_version": "py38",
1084-
"huggingface_transformers_version": "4.17.0",
1085-
},
1086-
"hosting_artifact_key": "huggingface-infer/infer-huggingface-text2text-flan-t5-xxl-fp16.tar.gz",
1087-
"hosting_script_key": "source-directory-tarballs/huggingface/inference/text2text/v1.0.2/sourcedir.tar.gz",
1088-
"hosting_prepacked_artifact_key": "huggingface-infer/prepack/v1.0.0/infer-prepack-huggingface-"
1089-
"text2text-flan-t5-xxl-fp16.tar.gz",
1090-
"hosting_prepacked_artifact_version": "1.0.0",
1091-
"inference_vulnerable": False,
1092-
"inference_dependencies": [
1093-
"accelerate==0.16.0",
1094-
"bitsandbytes==0.37.0",
1095-
"filelock==3.9.0",
1096-
"huggingface-hub==0.12.0",
1097-
"regex==2022.7.9",
1098-
"tokenizers==0.13.2",
1099-
"transformers==4.26.0",
1100-
],
1101-
"inference_vulnerabilities": [],
1102-
"training_vulnerable": False,
1103-
"training_dependencies": [],
1104-
"training_vulnerabilities": [],
1105-
"deprecated": False,
1106-
"inference_environment_variables": [
1107-
{
1108-
"name": "SAGEMAKER_PROGRAM",
1109-
"type": "text",
1110-
"default": "inference.py",
1111-
"scope": "container",
1112-
},
1113-
{
1114-
"name": "SAGEMAKER_SUBMIT_DIRECTORY",
1115-
"type": "text",
1116-
"default": "/opt/ml/model/code",
1117-
"scope": "container",
1118-
},
1119-
{
1120-
"name": "SAGEMAKER_CONTAINER_LOG_LEVEL",
1121-
"type": "text",
1122-
"default": "20",
1123-
"scope": "container",
1124-
},
1125-
{
1126-
"name": "MODEL_CACHE_ROOT",
1127-
"type": "text",
1128-
"default": "/opt/ml/model",
1129-
"scope": "container",
1130-
},
1131-
{"name": "SAGEMAKER_ENV", "type": "text", "default": "1", "scope": "container"},
1132-
{
1133-
"name": "SAGEMAKER_MODEL_SERVER_WORKERS",
1134-
"type": "text",
1135-
"default": "1",
1136-
"scope": "container",
1137-
},
1138-
{
1139-
"name": "SAGEMAKER_MODEL_SERVER_TIMEOUT",
1140-
"type": "text",
1141-
"default": "3600",
1142-
"scope": "container",
1143-
},
1144-
],
1145-
"metrics": [],
1146-
"default_inference_instance_type": "ml.g5.12xlarge",
1147-
"supported_inference_instance_types": [
1148-
"ml.g5.12xlarge",
1149-
"ml.g5.24xlarge",
1150-
"ml.p3.8xlarge",
1151-
"ml.p3.16xlarge",
1152-
"ml.g4dn.12xlarge",
1153-
],
1154-
},
11551158
}
11561159

11571160
BASE_SPEC = {
@@ -1175,6 +1178,7 @@
11751178
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
11761179
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
11771180
"training_script_key": "source-directory-tarballs/pytorch/transfer_learning/ic/v1.0.0/sourcedir.tar.gz",
1181+
"hosting_prepacked_artifact_key": None,
11781182
"hyperparameters": [
11791183
{
11801184
"name": "epochs",

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
BASE_MANIFEST,
3030
BASE_SPEC,
3131
BASE_HEADER,
32+
SPECIAL_MODEL_SPECS_DICT,
3233
)
3334

3435

@@ -92,6 +93,18 @@ def get_prototype_model_spec(
9293
return specs
9394

9495

96+
def get_special_model_spec(
97+
region: str = None, model_id: str = None, version: str = None
98+
) -> JumpStartModelSpecs:
99+
"""This function mocks cache accessor functions. For this mock,
100+
we only retrieve model specs based on the model ID. This is reserved
101+
for special specs.
102+
"""
103+
104+
specs = JumpStartModelSpecs(SPECIAL_MODEL_SPECS_DICT[model_id])
105+
return specs
106+
107+
95108
def get_spec_from_base_spec(
96109
_obj: JumpStartModelsCache = None,
97110
region: str = None,

tests/unit/sagemaker/model_uris/jumpstart/test_combined_artifact.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
from sagemaker import model_uris
1818
import pytest
1919

20-
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
20+
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec, get_special_model_spec
2121

2222

2323
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2424
def test_jumpstart_combined_artifacts(patched_get_model_specs):
2525

26-
patched_get_model_specs.side_effect = get_prototype_model_spec
26+
patched_get_model_specs.side_effect = get_special_model_spec
2727

2828
model_id_combined_model_artifact = "huggingface-text2text-flan-t5-xxl-fp16"
2929

@@ -48,6 +48,8 @@ def test_jumpstart_combined_artifacts(patched_get_model_specs):
4848
include_script=True,
4949
)
5050

51+
patched_get_model_specs.side_effect = get_prototype_model_spec
52+
5153
model_id_combined_model_artifact_unsupported = "xgboost-classification-model"
5254

5355
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)