Skip to content

Commit 20751c9

Browse files
authored
feature: Neo: Add Granular Target Description support for compilation (#1752)
1 parent 630edf6 commit 20751c9

File tree

3 files changed

+147
-23
lines changed

3 files changed

+147
-23
lines changed

src/sagemaker/estimator.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,10 @@ def compile_model(
512512
framework_version=None,
513513
compile_max_run=15 * 60,
514514
tags=None,
515+
target_platform_os=None,
516+
target_platform_arch=None,
517+
target_platform_accelerator=None,
518+
compiler_options=None,
515519
**kwargs
516520
):
517521
"""Compile a Neo model using the input model.
@@ -536,6 +540,21 @@ def compile_model(
536540
tags (list[dict]): List of tags for labeling a compilation job. For
537541
more, see
538542
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
543+
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
544+
For allowed strings see
545+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
546+
It can be used instead of target_instance_family.
547+
target_platform_arch (str): Target Platform Architecture, for example: 'X86_64'.
548+
For allowed strings see
549+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
550+
It can be used instead of target_instance_family.
551+
target_platform_accelerator (str, optional): Target Platform Accelerator,
552+
for example: 'NVIDIA'. For allowed strings see
553+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
554+
It can be used instead of target_instance_family.
555+
compiler_options (dict, optional): Additional parameters for compiler.
556+
Compiler Options are TargetPlatform / target_instance_family specific. See
557+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html for details.
539558
**kwargs: Passed to invocation of ``create_model()``.
540559
Implementations may customize ``create_model()`` to accept
541560
``**kwargs`` to customize model creation during deploy. For
@@ -565,6 +584,10 @@ def compile_model(
565584
compile_max_run,
566585
framework=framework,
567586
framework_version=framework_version,
587+
target_platform_os=target_platform_os,
588+
target_platform_arch=target_platform_arch,
589+
target_platform_accelerator=target_platform_accelerator,
590+
compiler_options=compiler_options,
568591
)
569592
return self._compiled_models[target_instance_family]
570593

src/sagemaker/model.py

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,10 @@ def _compilation_job_config(
215215
job_name,
216216
framework,
217217
tags,
218+
target_platform_os=None,
219+
target_platform_arch=None,
220+
target_platform_accelerator=None,
221+
compiler_options=None,
218222
):
219223
"""
220224
Args:
@@ -226,20 +230,46 @@ def _compilation_job_config(
226230
job_name:
227231
framework:
228232
tags:
233+
target_platform_os:
234+
target_platform_arch:
235+
target_platform_accelerator:
236+
compiler_options:
229237
"""
230238
input_model_config = {
231239
"S3Uri": self.model_data,
232-
"DataInputConfig": input_shape
233-
if not isinstance(input_shape, dict)
234-
else json.dumps(input_shape),
240+
"DataInputConfig": json.dumps(input_shape)
241+
if isinstance(input_shape, dict)
242+
else input_shape,
235243
"Framework": framework,
236244
}
237245
role = self.sagemaker_session.expand_role(role)
238246
output_model_config = {
239-
"TargetDevice": target_instance_type,
240247
"S3OutputLocation": output_path,
241248
}
242249

250+
if target_instance_type is not None:
251+
output_model_config["TargetDevice"] = target_instance_type
252+
else:
253+
if target_platform_os is None or target_platform_arch is None:
254+
raise ValueError(
255+
"target_instance_type or (target_platform_os and target_platform_arch) "
256+
"should be provided"
257+
)
258+
target_platform = {
259+
"Os": target_platform_os,
260+
"Arch": target_platform_arch,
261+
}
262+
if target_platform_accelerator is not None:
263+
target_platform["Accelerator"] = target_platform_accelerator
264+
output_model_config["TargetPlatform"] = target_platform
265+
266+
if compiler_options is not None:
267+
output_model_config["CompilerOptions"] = (
268+
json.dumps(compiler_options)
269+
if isinstance(compiler_options, dict)
270+
else compiler_options
271+
)
272+
243273
return {
244274
"input_model_config": input_model_config,
245275
"output_model_config": output_model_config,
@@ -320,6 +350,10 @@ def compile(
320350
compile_max_run=5 * 60,
321351
framework=None,
322352
framework_version=None,
353+
target_platform_os=None,
354+
target_platform_arch=None,
355+
target_platform_accelerator=None,
356+
compiler_options=None,
323357
):
324358
"""Compile this ``Model`` with SageMaker Neo.
325359
@@ -328,6 +362,9 @@ def compile(
328362
run your model after compilation, for example: ml_c5. For allowed
329363
strings see
330364
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
365+
Alternatively, you can select an OS, Architecture and Accelerator using
366+
``target_platform_os``, ``target_platform_arch``,
367+
and ``target_platform_accelerator``.
331368
input_shape (dict): Specifies the name and shape of the expected
332369
inputs for your trained model in json dictionary form, for
333370
example: {'data': [1,3,1024,1024]}, or {'var1': [1,1,28,28],
@@ -345,6 +382,21 @@ def compile(
345382
model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch',
346383
'onnx', 'xgboost'
347384
framework_version (str):
385+
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
386+
For allowed strings see
387+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
388+
It can be used instead of target_instance_family.
389+
target_platform_arch (str): Target Platform Architecture, for example: 'X86_64'.
390+
For allowed strings see
391+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
392+
It can be used instead of target_instance_family.
393+
target_platform_accelerator (str, optional): Target Platform Accelerator,
394+
for example: 'NVIDIA'. For allowed strings see
395+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
396+
It can be used instead of target_instance_family.
397+
compiler_options (dict, optional): Additional parameters for compiler.
398+
Compiler Options are TargetPlatform / target_instance_family specific. See
399+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html for details.
348400
349401
Returns:
350402
sagemaker.model.Model: A SageMaker ``Model`` object. See
@@ -375,31 +427,41 @@ def compile(
375427
job_name,
376428
framework,
377429
tags,
430+
target_platform_os,
431+
target_platform_arch,
432+
target_platform_accelerator,
433+
compiler_options,
378434
)
379435
self.sagemaker_session.compile_model(**config)
380436
job_status = self.sagemaker_session.wait_for_compilation_job(job_name)
381437
self.model_data = job_status["ModelArtifacts"]["S3ModelArtifacts"]
382-
if target_instance_family.startswith("ml_"):
383-
self.image = self._neo_image(
384-
self.sagemaker_session.boto_region_name,
385-
target_instance_family,
386-
framework,
387-
framework_version,
388-
)
389-
self._is_compiled_model = True
390-
elif target_instance_family.startswith(INFERENTIA_INSTANCE_PREFIX):
391-
self.image = self._inferentia_image(
392-
self.sagemaker_session.boto_region_name,
393-
target_instance_family,
394-
framework,
395-
framework_version,
396-
)
397-
self._is_compiled_model = True
438+
if target_instance_family is not None:
439+
if target_instance_family.startswith("ml_"):
440+
self.image = self._neo_image(
441+
self.sagemaker_session.boto_region_name,
442+
target_instance_family,
443+
framework,
444+
framework_version,
445+
)
446+
self._is_compiled_model = True
447+
elif target_instance_family.startswith(INFERENTIA_INSTANCE_PREFIX):
448+
self.image = self._inferentia_image(
449+
self.sagemaker_session.boto_region_name,
450+
target_instance_family,
451+
framework,
452+
framework_version,
453+
)
454+
self._is_compiled_model = True
455+
else:
456+
LOGGER.warning(
457+
"The instance type %s is not supported for deployment via SageMaker. "
458+
"Please deploy the model manually.",
459+
target_instance_family,
460+
)
398461
else:
399462
LOGGER.warning(
400-
"The instance type %s is not supported to deploy via SageMaker,"
401-
"please deploy the model manually.",
402-
target_instance_family,
463+
"Devices described by Target Platform OS, Architecture and Accelerator are not "
464+
"supported for deployment via SageMaker. Please deploy the model manually."
403465
)
404466
return self
405467

tests/unit/sagemaker/model/test_neo.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,45 @@ def test_compile_model_for_edge_device_tflite(sagemaker_session):
9595
assert model._is_compiled_model is False
9696

9797

98+
def test_compile_model_linux_arm64_nvidia(sagemaker_session):
99+
sagemaker_session.wait_for_compilation_job = Mock(
100+
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE
101+
)
102+
model = _create_model(sagemaker_session)
103+
model.compile(
104+
target_instance_family=None,
105+
input_shape={"data": [1, 3, 1024, 1024]},
106+
output_path="s3://output",
107+
role="role",
108+
framework="tensorflow",
109+
job_name="compile-model",
110+
target_platform_os="LINUX",
111+
target_platform_arch="ARM64",
112+
target_platform_accelerator="NVIDIA",
113+
compiler_options={"gpu-code": "sm_72", "trt-ver": "6.0.1", "cuda-ver": "10.1"},
114+
)
115+
assert model._is_compiled_model is False
116+
117+
118+
def test_compile_model_android_armv7(sagemaker_session):
119+
sagemaker_session.wait_for_compilation_job = Mock(
120+
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE
121+
)
122+
model = _create_model(sagemaker_session)
123+
model.compile(
124+
target_instance_family=None,
125+
input_shape={"data": [1, 3, 1024, 1024]},
126+
output_path="s3://output",
127+
role="role",
128+
framework="tensorflow",
129+
job_name="compile-model",
130+
target_platform_os="ANDROID",
131+
target_platform_arch="ARM_EABI",
132+
compiler_options={"ANDROID_PLATFORM": 25, "mattr": ["+neon"]},
133+
)
134+
assert model._is_compiled_model is False
135+
136+
98137
def test_compile_model_for_cloud(sagemaker_session):
99138
sagemaker_session.wait_for_compilation_job = Mock(
100139
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE

0 commit comments

Comments
 (0)