Skip to content

Commit 42298df

Browse files
change: enable neo framework version support on ml_inf2 and ml_trn1 (#3909)
1 parent 95b38ce commit 42298df

File tree

2 files changed

+30
-49
lines changed

2 files changed

+30
-49
lines changed

src/sagemaker/model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -844,15 +844,20 @@ def multi_version_compilation_supported(
844844
):
845845
if target_instance_type and framework and framework_version:
846846
framework = framework.lower()
847+
847848
multi_version_frameworks_support_mapping = {
848-
"inferentia": ["pytorch", "tensorflow", "mxnet"],
849+
"ml_inf1": ["pytorch", "tensorflow", "mxnet"],
850+
"ml_inf2": ["pytorch", "tensorflow"],
851+
"ml_trn1": ["pytorch", "tensorflow"],
849852
"neo_ioc_targets": ["pytorch", "tensorflow"],
850853
"neo_edge_targets": ["pytorch", "tensorflow"],
851854
}
852855
if target_instance_type in NEO_IOC_TARGET_DEVICES:
853856
return framework in multi_version_frameworks_support_mapping["neo_ioc_targets"]
854-
if target_instance_type == "ml_inf1":
855-
return framework in multi_version_frameworks_support_mapping["inferentia"]
857+
if target_instance_type in ["ml_inf1", "ml_inf2", "ml_trn1"]:
858+
return (
859+
framework in multi_version_frameworks_support_mapping[target_instance_type]
860+
)
856861
if target_instance_type not in NEO_MULTIVERSION_UNSUPPORTED:
857862
return framework in multi_version_frameworks_support_mapping["neo_edge_targets"]
858863
return False

tests/unit/sagemaker/model/test_neo.py

Lines changed: 22 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -362,59 +362,35 @@ def test_compile_with_tensorflow_neo_in_ml_inf(session):
362362
)
363363

364364

365-
def test_compile_validates_framework_version(sagemaker_session):
366-
sagemaker_session.wait_for_compilation_job = Mock(
367-
return_value={
368-
"CompilationJobStatus": "Completed",
369-
"ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"},
370-
"InferenceImage": None,
371-
}
372-
)
365+
@pytest.mark.parametrize(
366+
"target,framework,fx_version,expected_fx_version",
367+
[
368+
("ml_c4", "pytorch", "1.6", "1.6"),
369+
("rasp3b", "pytorch", "1.6.1", "1.6"),
370+
("amba_cv2", "pytorch", "1.6.1", None),
371+
("ml_c4", "tensorflow", "1.15.1", "1.15"),
372+
("ml_c4", "tensorflow", "2.15.1", "2.15"),
373+
("ml_inf1", "tensorflow", "2.15.1", "2.15"),
374+
("ml_inf2", "pytorch", "2.0", "2.0"),
375+
("ml_inf2", "pytorch", "2.0.1", "2.0"),
376+
("ml_trn1", "pytorch", "2.0.1", "2.0"),
377+
("ml_trn1", "tensorflow", "2.0.1", "2.0"),
378+
],
379+
)
380+
def test_compile_validates_framework_version(
381+
sagemaker_session, target, framework, fx_version, expected_fx_version
382+
):
373383
model = _create_model(sagemaker_session)
374-
model.compile(
375-
target_instance_family="ml_c4",
376-
input_shape={"data": [1, 3, 1024, 1024]},
377-
output_path="s3://output",
378-
role="role",
379-
framework="pytorch",
380-
framework_version="1.6.1",
381-
job_name="compile-model",
382-
)
383-
384-
assert model.image_uri is None
385-
386-
sagemaker_session.wait_for_compilation_job = Mock(
387-
return_value={
388-
"CompilationJobStatus": "Completed",
389-
"ModelArtifacts": {"S3ModelArtifacts": "s3://output-path/model.tar.gz"},
390-
"InferenceImage": None,
391-
}
392-
)
393-
394-
config = model._compilation_job_config(
395-
"rasp3b",
396-
{"data": [1, 3, 1024, 1024]},
397-
"s3://output",
398-
"role",
399-
900,
400-
"compile-model",
401-
"pytorch",
402-
None,
403-
framework_version="1.6.1",
404-
)
405-
406-
assert config["input_model_config"]["FrameworkVersion"] == "1.6"
407-
408384
config = model._compilation_job_config(
409-
"amba_cv2",
385+
target,
410386
{"data": [1, 3, 1024, 1024]},
411387
"s3://output",
412388
"role",
413389
900,
414390
"compile-model",
415-
"pytorch",
391+
framework,
416392
None,
417-
framework_version="1.6.1",
393+
framework_version=fx_version,
418394
)
419395

420-
assert config["input_model_config"].get("FrameworkVersion", None) is None
396+
assert config["input_model_config"].get("FrameworkVersion", None) == expected_fx_version

0 commit comments

Comments
 (0)