File tree Expand file tree Collapse file tree 3 files changed +16
-2
lines changed
tests/unit/sagemaker/tensorflow Expand file tree Collapse file tree 3 files changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -652,7 +652,7 @@ def compile(
652
652
self .model_data = job_status ["ModelArtifacts" ]["S3ModelArtifacts" ]
653
653
if target_instance_family is not None :
654
654
if target_instance_family == "ml_eia2" :
655
- LOGGER . info ( "You are using target device ml_eia2..." )
655
+ pass
656
656
elif target_instance_family .startswith ("ml_" ):
657
657
self .image_uri = self ._compilation_image_uri (
658
658
self .sagemaker_session .boto_region_name ,
Original file line number Diff line number Diff line change @@ -289,7 +289,12 @@ def deploy(
289
289
290
290
def _eia_supported (self ):
291
291
"""Return true if TF version is EIA enabled"""
292
- return [int (s ) for s in self .framework_version .split ("." )][:2 ] <= self .LATEST_EIA_VERSION
292
+ framework_version = [int (s ) for s in self .framework_version .split ("." )][:2 ]
293
+ return (
294
+ framework_version != [2 , 1 ]
295
+ and framework_version != [2 , 2 ]
296
+ and framework_version <= self .LATEST_EIA_VERSION
297
+ )
293
298
294
299
def prepare_container_def (self , instance_type = None , accelerator_type = None ):
295
300
"""Prepare the container definition.
Original file line number Diff line number Diff line change @@ -142,6 +142,15 @@ def test_tfs_model_image_accelerator_not_supported(sagemaker_session):
142
142
143
143
model .deploy (instance_type = "ml.c4.xlarge" , initial_instance_count = 1 )
144
144
145
+ with pytest .raises (AttributeError ) as e :
146
+ model .deploy (
147
+ instance_type = "ml.c4.xlarge" ,
148
+ accelerator_type = "ml.eia1.medium" ,
149
+ initial_instance_count = 1 ,
150
+ )
151
+
152
+ assert str (e .value ) == "The TensorFlow version 2.1 doesn't support EIA."
153
+
145
154
146
155
def test_tfs_model_with_log_level (sagemaker_session , tensorflow_inference_version ):
147
156
model = TensorFlowModel (
You can’t perform that action at this time.
0 commit comments