Skip to content

Commit de69011

Browse files
author
Prateek Chauhan
committed
feature: Add Framework Version support for PyTorch compilation (Neo)
1 parent b28bb31 commit de69011

File tree

3 files changed

+128
-7
lines changed

3 files changed

+128
-7
lines changed

src/sagemaker/image_uri_config/neo-pytorch.json

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,73 @@
22
"processors": ["cpu", "gpu"],
33
"scope": ["inference"],
44
"version_aliases": {
5-
"0.4.0": "1.4.0",
6-
"1.0.0": "1.4.0",
7-
"1.1.0": "1.4.0",
8-
"1.2.0": "1.4.0",
9-
"1.3.0": "1.4.0"
5+
"0.4.0": "1.4",
6+
"1.0.0": "1.4",
7+
"1.1.0": "1.4",
8+
"1.2.0": "1.4",
9+
"1.3.0": "1.4",
10+
"1.4.0": "1.4"
1011
},
1112
"versions": {
12-
"1.4.0": {
13+
"1.4": {
14+
"py_versions": ["py3"],
15+
"registries": {
16+
"af-south-1": "774647643957",
17+
"ap-east-1": "110948597952",
18+
"ap-northeast-1": "941853720454",
19+
"ap-northeast-2": "151534178276",
20+
"ap-south-1": "763008648453",
21+
"ap-southeast-1": "324986816169",
22+
"ap-southeast-2": "355873309152",
23+
"ca-central-1": "464438896020",
24+
"cn-north-1": "472730292857",
25+
"cn-northwest-1": "474822919863",
26+
"eu-central-1": "746233611703",
27+
"eu-north-1": "601324751636",
28+
"eu-south-1": "966458181534",
29+
"eu-west-1": "802834080501",
30+
"eu-west-2": "205493899709",
31+
"eu-west-3": "254080097072",
32+
"me-south-1": "836785723513",
33+
"sa-east-1": "756306329178",
34+
"us-east-1": "785573368785",
35+
"us-east-2": "007439368137",
36+
"us-gov-west-1": "263933020539",
37+
"us-west-1": "710691900526",
38+
"us-west-2": "301217895009"
39+
},
40+
"repository": "sagemaker-inference-pytorch"
41+
},
42+
"1.5": {
43+
"py_versions": ["py3"],
44+
"registries": {
45+
"af-south-1": "774647643957",
46+
"ap-east-1": "110948597952",
47+
"ap-northeast-1": "941853720454",
48+
"ap-northeast-2": "151534178276",
49+
"ap-south-1": "763008648453",
50+
"ap-southeast-1": "324986816169",
51+
"ap-southeast-2": "355873309152",
52+
"ca-central-1": "464438896020",
53+
"cn-north-1": "472730292857",
54+
"cn-northwest-1": "474822919863",
55+
"eu-central-1": "746233611703",
56+
"eu-north-1": "601324751636",
57+
"eu-south-1": "966458181534",
58+
"eu-west-1": "802834080501",
59+
"eu-west-2": "205493899709",
60+
"eu-west-3": "254080097072",
61+
"me-south-1": "836785723513",
62+
"sa-east-1": "756306329178",
63+
"us-east-1": "785573368785",
64+
"us-east-2": "007439368137",
65+
"us-gov-west-1": "263933020539",
66+
"us-west-1": "710691900526",
67+
"us-west-2": "301217895009"
68+
},
69+
"repository": "sagemaker-inference-pytorch"
70+
},
71+
"1.6": {
1372
"py_versions": ["py3"],
1473
"registries": {
1574
"af-south-1": "774647643957",

src/sagemaker/model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def _compilation_job_config(
398398
target_platform_arch=None,
399399
target_platform_accelerator=None,
400400
compiler_options=None,
401+
framework_version=None,
401402
):
402403
"""Placeholder Docstring"""
403404
input_model_config = {
@@ -407,6 +408,9 @@ def _compilation_job_config(
407408
else input_shape,
408409
"Framework": framework.upper(),
409410
}
411+
if framework.upper() == "PYTORCH" and framework_version is not None:
412+
input_model_config["FrameworkVersion"] = utils.get_short_version(framework_version)
413+
410414
role = self.sagemaker_session.expand_role(role)
411415
output_model_config = {
412416
"S3OutputLocation": output_path,
@@ -572,7 +576,8 @@ def compile(
572576
framework (str): The framework that is used to train the original
573577
model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch',
574578
'onnx', 'xgboost'
575-
framework_version (str):
579+
framework_version (str):The version of framework, for example:
580+
'1.5' for PyTorch
576581
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
577582
For allowed strings see
578583
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
@@ -613,6 +618,7 @@ def compile(
613618
framework_version = framework_version or self._get_framework_version()
614619

615620
self._init_sagemaker_session_if_does_not_exist(target_instance_family)
621+
616622
config = self._compilation_job_config(
617623
target_instance_family,
618624
input_shape,
@@ -626,6 +632,7 @@ def compile(
626632
target_platform_arch,
627633
target_platform_accelerator,
628634
compiler_options,
635+
framework_version,
629636
)
630637
self.sagemaker_session.compile_model(**config)
631638
job_status = self.sagemaker_session.wait_for_compilation_job(job_name)

tests/unit/sagemaker/model/test_neo.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,58 @@ def test_deploy_add_compiled_model_suffix_to_endpoint_name_from_model_name(sagem
269269

270270
model.deploy(1, "ml.c4.xlarge")
271271
assert model.endpoint_name.startswith("{}-ml-c4".format(model_name))
272+
273+
274+
@patch("sagemaker.session.Session")
275+
def test_compile_with_framework_version_15(session):
276+
session.return_value.boto_region_name = REGION
277+
278+
model = _create_model()
279+
model.compile(
280+
target_instance_family="ml_c4",
281+
input_shape={"data": [1, 3, 1024, 1024]},
282+
output_path="s3://output",
283+
role="role",
284+
framework="pytorch",
285+
framework_version="1.5",
286+
job_name="compile-model",
287+
)
288+
289+
assert "1.5" in model.image_uri
290+
291+
292+
@patch("sagemaker.session.Session")
293+
def test_compile_with_framework_version_16(session):
294+
session.return_value.boto_region_name = REGION
295+
296+
model = _create_model()
297+
model.compile(
298+
target_instance_family="ml_c4",
299+
input_shape={"data": [1, 3, 1024, 1024]},
300+
output_path="s3://output",
301+
role="role",
302+
framework="pytorch",
303+
framework_version="1.6",
304+
job_name="compile-model",
305+
)
306+
307+
assert "1.6" in model.image_uri
308+
309+
310+
@patch("sagemaker.session.Session")
311+
def test_compile_validates_framework_version(session):
312+
session.return_value.boto_region_name = REGION
313+
314+
model = _create_model()
315+
with pytest.raises(ValueError) as e:
316+
model.compile(
317+
target_instance_family="ml_c4",
318+
input_shape={"data": [1, 3, 1024, 1024]},
319+
output_path="s3://output",
320+
role="role",
321+
framework="pytorch",
322+
framework_version="1.6.1",
323+
job_name="compile-model",
324+
)
325+
326+
assert "Unsupported neo-pytorch version: 1.6.1." in str(e)

0 commit comments

Comments
 (0)