Skip to content

Commit 15deec1

Browse files
author
wanyixia
committed
change: revise unit test
1 parent 99fa6cd commit 15deec1

File tree

3 files changed

+16
-2
lines changed

3 files changed

+16
-2
lines changed

src/sagemaker/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def compile(
652652
self.model_data = job_status["ModelArtifacts"]["S3ModelArtifacts"]
653653
if target_instance_family is not None:
654654
if target_instance_family == "ml_eia2":
655-
LOGGER.info("You are using target device ml_eia2...")
655+
pass
656656
elif target_instance_family.startswith("ml_"):
657657
self.image_uri = self._compilation_image_uri(
658658
self.sagemaker_session.boto_region_name,

src/sagemaker/tensorflow/model.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,12 @@ def deploy(
289289

290290
def _eia_supported(self):
291291
"""Return true if TF version is EIA enabled"""
292-
return [int(s) for s in self.framework_version.split(".")][:2] <= self.LATEST_EIA_VERSION
292+
framework_version = [int(s) for s in self.framework_version.split(".")][:2]
293+
return (
294+
framework_version != [2, 1]
295+
and framework_version != [2, 2]
296+
and framework_version <= self.LATEST_EIA_VERSION
297+
)
293298

294299
def prepare_container_def(self, instance_type=None, accelerator_type=None):
295300
"""Prepare the container definition.

tests/unit/sagemaker/tensorflow/test_tfs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,15 @@ def test_tfs_model_image_accelerator_not_supported(sagemaker_session):
142142

143143
model.deploy(instance_type="ml.c4.xlarge", initial_instance_count=1)
144144

145+
with pytest.raises(AttributeError) as e:
146+
model.deploy(
147+
instance_type="ml.c4.xlarge",
148+
accelerator_type="ml.eia1.medium",
149+
initial_instance_count=1,
150+
)
151+
152+
assert str(e.value) == "The TensorFlow version 2.1 doesn't support EIA."
153+
145154

146155
def test_tfs_model_with_log_level(sagemaker_session, tensorflow_inference_version):
147156
model = TensorFlowModel(

0 commit comments

Comments
 (0)