29
29
REGION = "us-west-2"
30
30
SCRIPT_PATH = "script.py"
31
31
TIMESTAMP = "2017-10-10-14-14-15"
32
+ ECR_PREFIX_FORMAT = "{}.dkr.ecr.mars-south-3.amazonaws.com"
32
33
34
+ MOCK_ACCOUNT = "520713654638"
33
35
MOCK_FRAMEWORK = "mlfw"
34
36
MOCK_REGION = "mars-south-3"
35
37
MOCK_ACCELERATOR = "eia1.medium"
@@ -165,7 +167,9 @@ def sagemaker_session():
165
167
return session_mock
166
168
167
169
168
- def test_create_image_uri_cpu ():
170
+ @patch ("sagemaker.fw_utils.get_ecr_image_uri_prefix" )
171
+ def test_create_image_uri_cpu (ecr_prefix ):
172
+ ecr_prefix .return_value = ECR_PREFIX_FORMAT .format ("23" )
169
173
image_uri = fw_utils .create_image_uri (
170
174
MOCK_REGION , MOCK_FRAMEWORK , "ml.c4.large" , "1.0rc" , "py2" , "23"
171
175
)
@@ -176,20 +180,23 @@ def test_create_image_uri_cpu():
176
180
)
177
181
assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2"
178
182
183
+ ecr_prefix .return_value = "246785580436.dkr.ecr.us-gov-west-1.amazonaws.com"
179
184
image_uri = fw_utils .create_image_uri (
180
185
"us-gov-west-1" , MOCK_FRAMEWORK , "ml.c4.large" , "1.0rc" , "py2"
181
186
)
182
187
assert (
183
188
image_uri == "246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2"
184
189
)
185
190
191
+ ecr_prefix .return_value = "744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov"
186
192
image_uri = fw_utils .create_image_uri (
187
193
"us-iso-east-1" , MOCK_FRAMEWORK , "ml.c4.large" , "1.0rc" , "py2"
188
194
)
189
195
assert image_uri == "744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov/sagemaker-mlfw:1.0rc-cpu-py2"
190
196
191
197
192
- def test_create_image_uri_no_python ():
198
+ @patch ("sagemaker.fw_utils.get_ecr_image_uri_prefix" , return_value = ECR_PREFIX_FORMAT .format ("23" ))
199
+ def test_create_image_uri_no_python (ecr_prefix ):
193
200
image_uri = fw_utils .create_image_uri (
194
201
MOCK_REGION , MOCK_FRAMEWORK , "ml.c4.large" , "1.0rc" , account = "23"
195
202
)
@@ -201,7 +208,8 @@ def test_create_image_uri_bad_python():
201
208
fw_utils .create_image_uri (MOCK_REGION , MOCK_FRAMEWORK , "ml.c4.large" , "1.0rc" , "py0" )
202
209
203
210
204
- def test_create_image_uri_gpu ():
211
+ @patch ("sagemaker.fw_utils.get_ecr_image_uri_prefix" , return_value = ECR_PREFIX_FORMAT .format ("23" ))
212
+ def test_create_image_uri_gpu (ecr_prefix ):
205
213
image_uri = fw_utils .create_image_uri (
206
214
MOCK_REGION , MOCK_FRAMEWORK , "ml.p3.2xlarge" , "1.0rc" , "py3" , "23"
207
215
)
@@ -213,7 +221,8 @@ def test_create_image_uri_gpu():
213
221
assert image_uri == "23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3"
214
222
215
223
216
- def test_create_image_uri_accelerator_tfs ():
224
+ @patch ("sagemaker.fw_utils.get_ecr_image_uri_prefix" , return_value = ECR_PREFIX_FORMAT .format ("23" ))
225
+ def test_create_image_uri_accelerator_tfs (ecr_prefix ):
217
226
image_uri = fw_utils .create_image_uri (
218
227
MOCK_REGION ,
219
228
"tensorflow-serving" ,
@@ -228,7 +237,11 @@ def test_create_image_uri_accelerator_tfs():
228
237
)
229
238
230
239
231
- def test_create_image_uri_default_account ():
240
+ @patch (
241
+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
242
+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
243
+ )
244
+ def test_create_image_uri_default_account (ecr_prefix ):
232
245
image_uri = fw_utils .create_image_uri (
233
246
MOCK_REGION , MOCK_FRAMEWORK , "ml.p3.2xlarge" , "1.0rc" , "py3"
234
247
)
@@ -511,7 +524,11 @@ def test_create_image_uri_tensorflow(tf_version):
511
524
)
512
525
513
526
514
- def test_create_image_uri_accelerator_tf ():
527
+ @patch (
528
+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
529
+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
530
+ )
531
+ def test_create_image_uri_accelerator_tf (ecr_prefix ):
515
532
image_uri = fw_utils .create_image_uri (
516
533
MOCK_REGION , "tensorflow" , "ml.p3.2xlarge" , "1.0" , "py3" , accelerator_type = "ml.eia1.medium"
517
534
)
@@ -521,7 +538,11 @@ def test_create_image_uri_accelerator_tf():
521
538
)
522
539
523
540
524
- def test_create_image_uri_accelerator_mxnet_serving ():
541
+ @patch (
542
+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
543
+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
544
+ )
545
+ def test_create_image_uri_accelerator_mxnet_serving (ecr_prefix ):
525
546
image_uri = fw_utils .create_image_uri (
526
547
MOCK_REGION ,
527
548
"mxnet-serving" ,
@@ -536,7 +557,11 @@ def test_create_image_uri_accelerator_mxnet_serving():
536
557
)
537
558
538
559
539
- def test_create_image_uri_local_sagemaker_notebook_accelerator ():
560
+ @patch (
561
+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
562
+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
563
+ )
564
+ def test_create_image_uri_local_sagemaker_notebook_accelerator (ecr_prefix ):
540
565
image_uri = fw_utils .create_image_uri (
541
566
MOCK_REGION ,
542
567
"mxnet" ,
@@ -608,7 +633,11 @@ def test_invalid_instance_type():
608
633
fw_utils .create_image_uri (MOCK_REGION , MOCK_FRAMEWORK , "p3.2xlarge" , "1.0.0" , "py3" )
609
634
610
635
611
- def test_optimized_family ():
636
+ @patch (
637
+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
638
+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
639
+ )
640
+ def test_optimized_family (ecr_prefix ):
612
641
image_uri = fw_utils .create_image_uri (
613
642
MOCK_REGION ,
614
643
MOCK_FRAMEWORK ,
@@ -622,7 +651,11 @@ def test_optimized_family():
622
651
)
623
652
624
653
625
- def test_unoptimized_cpu_family ():
654
+ @patch (
655
+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
656
+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
657
+ )
658
+ def test_unoptimized_cpu_family (ecr_prefix ):
626
659
image_uri = fw_utils .create_image_uri (
627
660
MOCK_REGION , MOCK_FRAMEWORK , "ml.m4.xlarge" , "1.0.0" , "py3" , optimized_families = ["c5" , "p3" ]
628
661
)
@@ -631,7 +664,11 @@ def test_unoptimized_cpu_family():
631
664
)
632
665
633
666
634
- def test_unoptimized_gpu_family ():
667
+ @patch (
668
+ "sagemaker.fw_utils.get_ecr_image_uri_prefix" ,
669
+ return_value = ECR_PREFIX_FORMAT .format (MOCK_ACCOUNT ),
670
+ )
671
+ def test_unoptimized_gpu_family (ecr_prefix ):
635
672
image_uri = fw_utils .create_image_uri (
636
673
MOCK_REGION , MOCK_FRAMEWORK , "ml.p2.xlarge" , "1.0.0" , "py3" , optimized_families = ["c5" , "p3" ]
637
674
)
0 commit comments