Skip to content

Commit 9f2a192

Browse files
authored
fix: Update neo multiversion support to include edge devices (#3875)
1 parent 57f3bb9 commit 9f2a192

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

src/sagemaker/model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,17 @@
6969

7070
NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"]
7171

72+
NEO_MULTIVERSION_UNSUPPORTED = [
73+
"imx8mplus",
74+
"jacinto_tda4vm",
75+
"coreml",
76+
"sitara_am57x",
77+
"amba_cv2",
78+
"amba_cv22",
79+
"amba_cv25",
80+
"lambda",
81+
]
82+
7283

7384
class ModelBase(abc.ABC):
7485
"""An object that encapsulates a trained model.
@@ -836,11 +847,14 @@ def multi_version_compilation_supported(
836847
multi_version_frameworks_support_mapping = {
837848
"inferentia": ["pytorch", "tensorflow", "mxnet"],
838849
"neo_ioc_targets": ["pytorch", "tensorflow"],
850+
"neo_edge_targets": ["pytorch", "tensorflow"],
839851
}
840852
if target_instance_type in NEO_IOC_TARGET_DEVICES:
841853
return framework in multi_version_frameworks_support_mapping["neo_ioc_targets"]
842854
if target_instance_type == "ml_inf1":
843855
return framework in multi_version_frameworks_support_mapping["inferentia"]
856+
if target_instance_type not in NEO_MULTIVERSION_UNSUPPORTED:
857+
return framework in multi_version_frameworks_support_mapping["neo_edge_targets"]
844858
return False
845859

846860
if multi_version_compilation_supported(target_instance_type, framework, framework_version):

tests/unit/sagemaker/model/test_neo.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,3 +382,39 @@ def test_compile_validates_framework_version(sagemaker_session):
382382
)
383383

384384
assert model.image_uri is None
385+
386+
sagemaker_session.wait_for_compilation_job = Mock(
387+
return_value={
388+
"CompilationJobStatus": "Completed",
389+
"ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"},
390+
"InferenceImage": None,
391+
}
392+
)
393+
394+
config = model._compilation_job_config(
395+
"rasp3b",
396+
{"data": [1, 3, 1024, 1024]},
397+
"s3://output",
398+
"role",
399+
900,
400+
"compile-model",
401+
"pytorch",
402+
None,
403+
framework_version="1.6.1",
404+
)
405+
406+
assert config["input_model_config"]["FrameworkVersion"] == "1.6"
407+
408+
config = model._compilation_job_config(
409+
"amba_cv2",
410+
{"data": [1, 3, 1024, 1024]},
411+
"s3://output",
412+
"role",
413+
900,
414+
"compile-model",
415+
"pytorch",
416+
None,
417+
framework_version="1.6.1",
418+
)
419+
420+
assert config["input_model_config"].get("FrameworkVersion", None) is None

0 commit comments

Comments
 (0)