Skip to content

Commit 76add39

Browse files
committed
fix: updating neo multiversion support to include edge devices
1 parent e2624af commit 76add39

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

src/sagemaker/model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,17 @@
6969

7070
NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"]
7171

72+
NEO_MULTIVERSION_UNSUPPORTED = [
73+
"imx8mplus",
74+
"jacinto_tda4vm",
75+
"coreml",
76+
"sitara_am57x",
77+
"amba_cv2",
78+
"amba_cv22",
79+
"amba_cv25",
80+
"lambda",
81+
]
82+
7283

7384
class ModelBase(abc.ABC):
7485
"""An object that encapsulates a trained model.
@@ -836,11 +847,14 @@ def multi_version_compilation_supported(
836847
multi_version_frameworks_support_mapping = {
837848
"inferentia": ["pytorch", "tensorflow", "mxnet"],
838849
"neo_ioc_targets": ["pytorch", "tensorflow"],
850+
"neo_edge_targets": ["pytorch", "tensorflow"],
839851
}
840852
if target_instance_type in NEO_IOC_TARGET_DEVICES:
841853
return framework in multi_version_frameworks_support_mapping["neo_ioc_targets"]
842854
if target_instance_type == "ml_inf1":
843855
return framework in multi_version_frameworks_support_mapping["inferentia"]
856+
if target_instance_type not in NEO_UNSUPPORTED_MULTI_VERSION:
857+
return framework in multi_version_frameworks_support_mapping["neo_edge_targets"]
844858
return False
845859

846860
if multi_version_compilation_supported(target_instance_type, framework, framework_version):

tests/unit/sagemaker/model/test_neo.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,3 +382,42 @@ def test_compile_validates_framework_version(sagemaker_session):
382382
)
383383

384384
assert model.image_uri is None
385+
386+
model = _create_model(sagemaker_session)
387+
model.compile(
388+
target_instance_family="rasp3b",
389+
input_shape={"data": [1, 3, 1024, 1024]},
390+
output_path="s3://output",
391+
role="role",
392+
framework="pytorch",
393+
framework_version="1.6.1",
394+
job_name="compile-model",
395+
)
396+
397+
assert model.image_uri is None
398+
399+
config = model._compilation_job_config(
400+
target_instance_family="rasp3b",
401+
input_shape={"data": [1, 3, 1024, 1024]},
402+
output_path="s3://output",
403+
role="role",
404+
framework="pytorch",
405+
framework_version="1.6.1",
406+
job_name="compile-model",
407+
tags=None,
408+
)
409+
410+
assert config["input_model_config"]["FrameworkVersion"] == "1.6"
411+
412+
config = model._compilation_job_config(
413+
target_instance_family="amba_cv2",
414+
input_shape={"data": [1, 3, 1024, 1024]},
415+
output_path="s3://output",
416+
role="role",
417+
framework="pytorch",
418+
framework_version="1.6.1",
419+
job_name="compile-model",
420+
tags=None,
421+
)
422+
423+
assert config["input_model_config"].get("FrameworkVersion", None) is None

0 commit comments

Comments
 (0)