26
26
MODEL_NAME = "{}-{}" .format (MODEL_IMAGE , TIMESTAMP )
27
27
28
28
INSTANCE_COUNT = 2
29
- INSTANCE_TYPE = "c4.4xlarge"
29
+ INSTANCE_TYPE = "ml. c4.4xlarge"
30
30
ROLE = "some-role"
31
31
32
32
BASE_PRODUCTION_VARIANT = {
@@ -43,17 +43,119 @@ def sagemaker_session():
43
43
return Mock ()
44
44
45
45
46
- @patch ("sagemaker.production_variant" )
46
+ def test_prepare_container_def ():
47
+ env = {"FOO" : "BAR" }
48
+ model = Model (MODEL_DATA , MODEL_IMAGE , env = env )
49
+
50
+ container_def = model .prepare_container_def (INSTANCE_TYPE , "ml.eia.medium" )
51
+
52
+ expected = {"Image" : MODEL_IMAGE , "Environment" : env , "ModelDataUrl" : MODEL_DATA }
53
+ assert expected == container_def
54
+
55
+
47
56
@patch ("sagemaker.model.Model.prepare_container_def" )
48
57
@patch ("sagemaker.utils.name_from_image" )
49
- def test_deploy (name_from_image , prepare_container_def , production_variant , sagemaker_session ):
58
+ def test_create_sagemaker_model (name_from_image , prepare_container_def , sagemaker_session ):
59
+ name_from_image .return_value = MODEL_NAME
60
+
61
+ container_def = {"Image" : MODEL_IMAGE , "Environment" : {}, "ModelDataUrl" : MODEL_DATA }
62
+ prepare_container_def .return_value = container_def
63
+
64
+ model = Model (MODEL_DATA , MODEL_IMAGE , sagemaker_session = sagemaker_session )
65
+ model ._create_sagemaker_model (INSTANCE_TYPE )
66
+
67
+ prepare_container_def .assert_called_with (INSTANCE_TYPE , accelerator_type = None )
68
+ name_from_image .assert_called_with (MODEL_IMAGE )
69
+
70
+ sagemaker_session .create_model .assert_called_with (
71
+ MODEL_NAME , None , container_def , vpc_config = None , enable_network_isolation = False , tags = None
72
+ )
73
+
74
+
75
+ @patch ("sagemaker.utils.name_from_image" , Mock ())
76
+ @patch ("sagemaker.model.Model.prepare_container_def" )
77
+ def test_create_sagemaker_model_accelerator_type (prepare_container_def , sagemaker_session ):
78
+ model = Model (MODEL_DATA , MODEL_IMAGE , sagemaker_session = sagemaker_session )
79
+
80
+ accelerator_type = "ml.eia.medium"
81
+ model ._create_sagemaker_model (INSTANCE_TYPE , accelerator_type = accelerator_type )
82
+
83
+ prepare_container_def .assert_called_with (INSTANCE_TYPE , accelerator_type = accelerator_type )
84
+
85
+
86
+ @patch ("sagemaker.model.Model.prepare_container_def" )
87
+ @patch ("sagemaker.utils.name_from_image" )
88
+ def test_create_sagemaker_model_tags (name_from_image , prepare_container_def , sagemaker_session ):
89
+ container_def = {"Image" : MODEL_IMAGE , "Environment" : {}, "ModelDataUrl" : MODEL_DATA }
90
+ prepare_container_def .return_value = container_def
91
+
50
92
name_from_image .return_value = MODEL_NAME
51
93
94
+ model = Model (MODEL_DATA , MODEL_IMAGE , sagemaker_session = sagemaker_session )
95
+
96
+ tags = {"Key" : "foo" , "Value" : "bar" }
97
+ model ._create_sagemaker_model (INSTANCE_TYPE , tags = tags )
98
+
99
+ sagemaker_session .create_model .assert_called_with (
100
+ MODEL_NAME , None , container_def , vpc_config = None , enable_network_isolation = False , tags = tags
101
+ )
102
+
103
+
104
+ @patch ("sagemaker.model.Model.prepare_container_def" )
105
+ @patch ("sagemaker.utils.name_from_image" )
106
+ def test_create_sagemaker_model_optional_model_params (
107
+ name_from_image , prepare_container_def , sagemaker_session
108
+ ):
52
109
container_def = {"Image" : MODEL_IMAGE , "Environment" : {}, "ModelDataUrl" : MODEL_DATA }
53
110
prepare_container_def .return_value = container_def
54
111
112
+ vpc_config = {"Subnets" : ["123" ], "SecurityGroupIds" : ["456" , "789" ]}
113
+
114
+ model = Model (
115
+ MODEL_DATA ,
116
+ MODEL_IMAGE ,
117
+ name = MODEL_NAME ,
118
+ role = ROLE ,
119
+ vpc_config = vpc_config ,
120
+ enable_network_isolation = True ,
121
+ sagemaker_session = sagemaker_session ,
122
+ )
123
+ model ._create_sagemaker_model (INSTANCE_TYPE )
124
+
125
+ name_from_image .assert_not_called ()
126
+
127
+ sagemaker_session .create_model .assert_called_with (
128
+ MODEL_NAME ,
129
+ ROLE ,
130
+ container_def ,
131
+ vpc_config = vpc_config ,
132
+ enable_network_isolation = True ,
133
+ tags = None ,
134
+ )
135
+
136
+
137
+ @patch ("sagemaker.session.Session" )
138
+ @patch ("sagemaker.local.LocalSession" )
139
+ def test_create_sagemaker_model_creates_correct_session (local_session , session ):
140
+ model = Model (MODEL_DATA , MODEL_IMAGE )
141
+ model ._create_sagemaker_model ("local" )
142
+ assert model .sagemaker_session == local_session .return_value
143
+
144
+ model = Model (MODEL_DATA , MODEL_IMAGE )
145
+ model ._create_sagemaker_model ("ml.m5.xlarge" )
146
+ assert model .sagemaker_session == session .return_value
147
+
148
+
149
+ @patch ("sagemaker.production_variant" )
150
+ @patch ("sagemaker.model.Model.prepare_container_def" )
151
+ @patch ("sagemaker.utils.name_from_image" )
152
+ def test_deploy (name_from_image , prepare_container_def , production_variant , sagemaker_session ):
153
+ name_from_image .return_value = MODEL_NAME
55
154
production_variant .return_value = BASE_PRODUCTION_VARIANT
56
155
156
+ container_def = {"Image" : MODEL_IMAGE , "Environment" : {}, "ModelDataUrl" : MODEL_DATA }
157
+ prepare_container_def .return_value = container_def
158
+
57
159
model = Model (MODEL_DATA , MODEL_IMAGE , role = ROLE , sagemaker_session = sagemaker_session )
58
160
model .deploy (instance_type = INSTANCE_TYPE , initial_instance_count = INSTANCE_COUNT )
59
161
@@ -223,7 +325,7 @@ def test_deploy_data_capture_config(production_variant, sagemaker_session):
223
325
224
326
@patch ("sagemaker.session.Session" )
225
327
@patch ("sagemaker.local.LocalSession" )
226
- def test_deploy_creates_correct_session (local_session , session , tmpdir ):
328
+ def test_deploy_creates_correct_session (local_session , session ):
227
329
# We expect a LocalSession when deploying to instance_type = 'local'
228
330
model = Model (MODEL_DATA , MODEL_IMAGE , role = ROLE )
229
331
model .deploy (endpoint_name = "blah" , instance_type = "local" , initial_instance_count = 1 )
@@ -356,7 +458,6 @@ def test_model_create_transformer_network_isolation(create_sagemaker_model, sage
356
458
357
459
@patch ("sagemaker.session.Session" )
358
460
@patch ("sagemaker.local.LocalSession" )
359
- @patch ("sagemaker.fw_utils.tar_and_upload_dir" , Mock ())
360
461
def test_transformer_creates_correct_session (local_session , session ):
361
462
model = Model (MODEL_DATA , MODEL_IMAGE , sagemaker_session = None )
362
463
transformer = model .transformer (instance_count = 1 , instance_type = "local" )
0 commit comments