@@ -13,7 +13,7 @@
from __future__ import absolute_import

import pytest
-from mock import Mock
+from mock import Mock, ANY

from sagemaker.tensorflow import TensorFlow


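ANY is imported because the updated create_model assertion further down matches the auto-generated model name with a wildcard instead of estimator._current_job_name. A minimal standalone illustration of how mock.ANY behaves (illustration only, not part of the changed test module; the stdlib unittest.mock is used so the snippet runs without the mock backport):

# Illustration only: mock.ANY compares equal to any value.
from unittest.mock import Mock, ANY

session = Mock(name="sagemaker_session")
session.create_model("tensorflow-training-2020-01-01-00-00-00-000", "SagemakerRole")

# Passes even though the job-derived model name is not known in advance.
session.create_model.assert_called_with(ANY, "SagemakerRole")
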
@@ -24,13 +24,15 @@
INSTANCE_COUNT = 1
INSTANCE_TYPE_GPU = "ml.p2.xlarge"
INSTANCE_TYPE_CPU = "ml.m4.xlarge"
-CPU_IMAGE_NAME = "sagemaker-tensorflow-py2-cpu"
-GPU_IMAGE_NAME = "sagemaker-tensorflow-py2-gpu"
+REPOSITORY = "tensorflow-inference"
+PROCESSOR = "cpu"
REGION = "us-west-2"
-IMAGE_URI_FORMAT_STRING = "520713654638.dkr.ecr.{}.amazonaws.com/{}:{}-{}-{}"
+IMAGE_URI_FORMAT_STRING = "763104351884.dkr.ecr.{}.amazonaws.com/{}:{}-{}"
REGION = "us-west-2"
ROLE = "SagemakerRole"
SOURCE_DIR = "s3://fefergerger"
+ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}
+ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}


@pytest.fixture()
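With the constants above, and the 2.3.0 framework version that test_deploy passes below, the new IMAGE_URI_FORMAT_STRING resolves to a serving image in the 763104351884 registry. A quick sketch of the substitution (illustration only, not part of the changed file):

# Illustration only: what the updated format string produces for the values used in test_deploy.
REGION = "us-west-2"
REPOSITORY = "tensorflow-inference"
PROCESSOR = "cpu"
IMAGE_URI_FORMAT_STRING = "763104351884.dkr.ecr.{}.amazonaws.com/{}:{}-{}"

image = IMAGE_URI_FORMAT_STRING.format(REGION, REPOSITORY, "2.3.0", PROCESSOR)
assert image == "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:2.3.0-cpu"
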
@@ -39,48 +41,64 @@ def sagemaker_session():
    ims = Mock(
        name="sagemaker_session",
        boto_session=boto_mock,
+        boto_region_name=REGION,
        config=None,
        local_mode=False,
-        region_name=REGION,
+        s3_resource=None,
+        s3_client=None,
    )
    ims.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
    ims.expand_role = Mock(name="expand_role", return_value=ROLE)
    ims.sagemaker_client.describe_training_job = Mock(
        return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
    )
+    ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
+    ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
    return ims


+def test_model_dir_false(sagemaker_session):
+    estimator = TensorFlow(
+        entry_point=SCRIPT,
+        source_dir=SOURCE_DIR,
+        role=ROLE,
+        framework_version="2.3.0",
+        py_version="py37",
+        instance_type="ml.m4.xlarge",
+        instance_count=1,
+        model_dir=False,
+    )
+    estimator.hyperparameters()
+    assert estimator.model_dir is False
+
+
# Test that we pass all necessary fields from estimator to the session when we call deploy
-def test_deploy(sagemaker_session, tf_version):
+def test_deploy(sagemaker_session):
    estimator = TensorFlow(
        entry_point=SCRIPT,
        source_dir=SOURCE_DIR,
        role=ROLE,
-        framework_version=tf_version,
-        train_instance_count=2,
-        train_instance_type=INSTANCE_TYPE_CPU,
+        framework_version="2.3.0",
+        py_version="py37",
+        instance_count=2,
+        instance_type=INSTANCE_TYPE_CPU,
        sagemaker_session=sagemaker_session,
        base_job_name="test-cifar",
    )

    estimator.fit("s3://mybucket/train")
-    print("job succeeded: {}".format(estimator.latest_training_job.name))

    estimator.deploy(initial_instance_count=1, instance_type=INSTANCE_TYPE_CPU)
-    image = IMAGE_URI_FORMAT_STRING.format(REGION, CPU_IMAGE_NAME, tf_version, "cpu", "py2")
+    image = IMAGE_URI_FORMAT_STRING.format(REGION, REPOSITORY, "2.3.0", PROCESSOR)
    sagemaker_session.create_model.assert_called_with(
-        estimator._current_job_name,
+        ANY,
        ROLE,
        {
-            "Environment": {
-                "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
-                "SAGEMAKER_SUBMIT_DIRECTORY": SOURCE_DIR,
-                "SAGEMAKER_REQUIREMENTS": "",
-                "SAGEMAKER_REGION": REGION,
-                "SAGEMAKER_PROGRAM": SCRIPT,
-            },
            "Image": image,
+            "Environment": {"SAGEMAKER_TFS_NGINX_LOGLEVEL": "info"},
            "ModelDataUrl": "s3://m/m.tar.gz",
        },
+        vpc_config=None,
+        enable_network_isolation=False,
+        tags=None,
    )
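The sagemaker_session fixture leans on Mock auto-creating nested attributes: only the responses the tests depend on are pinned, while everything else, create_model included, simply records its arguments for the assert_called_with checks. A minimal sketch of that pattern (illustration only, independent of the module in this diff):

# Illustration only: the Mock-based session pattern used by the fixture above.
from unittest.mock import Mock

session = Mock(name="sagemaker_session")
# Pin only the responses a test actually reads back.
session.sagemaker_client.describe_training_job = Mock(
    return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
)

# Any other attribute access yields another recording Mock.
session.create_model("some-model", "SagemakerRole", {"Image": "img"})
session.create_model.assert_called_with("some-model", "SagemakerRole", {"Image": "img"})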