@@ -215,6 +215,10 @@ def _compilation_job_config(
215
215
job_name ,
216
216
framework ,
217
217
tags ,
218
+ target_platform_os = None ,
219
+ target_platform_arch = None ,
220
+ target_platform_accelerator = None ,
221
+ compiler_options = None ,
218
222
):
219
223
"""
220
224
Args:
@@ -226,20 +230,46 @@ def _compilation_job_config(
226
230
job_name:
227
231
framework:
228
232
tags:
233
+ target_platform_os:
234
+ target_platform_arch:
235
+ target_platform_accelerator:
236
+ compiler_options:
229
237
"""
230
238
input_model_config = {
231
239
"S3Uri" : self .model_data ,
232
- "DataInputConfig" : input_shape
233
- if not isinstance (input_shape , dict )
234
- else json . dumps ( input_shape ) ,
240
+ "DataInputConfig" : json . dumps ( input_shape )
241
+ if isinstance (input_shape , dict )
242
+ else input_shape ,
235
243
"Framework" : framework ,
236
244
}
237
245
role = self .sagemaker_session .expand_role (role )
238
246
output_model_config = {
239
- "TargetDevice" : target_instance_type ,
240
247
"S3OutputLocation" : output_path ,
241
248
}
242
249
250
+ if target_instance_type is not None :
251
+ output_model_config ["TargetDevice" ] = target_instance_type
252
+ else :
253
+ if target_platform_os is None and target_platform_arch is None :
254
+ raise ValueError (
255
+ "target_instance_type or (target_platform_os and target_platform_arch) "
256
+ "should be provided"
257
+ )
258
+ target_platform = {
259
+ "Os" : target_platform_os ,
260
+ "Arch" : target_platform_arch ,
261
+ }
262
+ if target_platform_accelerator is not None :
263
+ target_platform ["Accelerator" ] = target_platform_accelerator
264
+ output_model_config ["TargetPlatform" ] = target_platform
265
+
266
+ if compiler_options is not None :
267
+ output_model_config ["CompilerOptions" ] = (
268
+ json .dumps (compiler_options )
269
+ if isinstance (compiler_options , dict )
270
+ else compiler_options
271
+ )
272
+
243
273
return {
244
274
"input_model_config" : input_model_config ,
245
275
"output_model_config" : output_model_config ,
@@ -320,6 +350,10 @@ def compile(
320
350
compile_max_run = 5 * 60 ,
321
351
framework = None ,
322
352
framework_version = None ,
353
+ target_platform_os = None ,
354
+ target_platform_arch = None ,
355
+ target_platform_accelerator = None ,
356
+ compiler_options = None ,
323
357
):
324
358
"""Compile this ``Model`` with SageMaker Neo.
325
359
@@ -328,6 +362,9 @@ def compile(
328
362
run your model after compilation, for example: ml_c5. For allowed
329
363
strings see
330
364
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
365
+ Alternatively, you can select an OS, Architecture and Accelerator using
366
+ ``target_platform_os``, ``target_platform_arch``,
367
+ and ``target_platform_accelerator``.
331
368
input_shape (dict): Specifies the name and shape of the expected
332
369
inputs for your trained model in json dictionary form, for
333
370
example: {'data': [1,3,1024,1024]}, or {'var1': [1,1,28,28],
@@ -345,6 +382,21 @@ def compile(
345
382
model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch',
346
383
'onnx', 'xgboost'
347
384
framework_version (str):
385
+ target_platform_os (str): Target Platform OS, for example: 'LINUX'.
386
+ For allowed strings see
387
+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
388
+ It can be used instead of target_instance_family.
389
+ target_platform_arch (str): Target Platform Architecture, for example: 'X86_64'.
390
+ For allowed strings see
391
+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
392
+ It can be used instead of target_instance_family.
393
+ target_platform_accelerator (str, optional): Target Platform Accelerator,
394
+ for example: 'NVIDIA'. For allowed strings see
395
+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
396
+ It can be used instead of target_instance_family.
397
+ compiler_options (dict, optional): Additional parameters for compiler.
398
+ Compiler Options are TargetPlatform / target_instance_family specific. See
399
+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html for details.
348
400
349
401
Returns:
350
402
sagemaker.model.Model: A SageMaker ``Model`` object. See
@@ -375,31 +427,41 @@ def compile(
375
427
job_name ,
376
428
framework ,
377
429
tags ,
430
+ target_platform_os ,
431
+ target_platform_arch ,
432
+ target_platform_accelerator ,
433
+ compiler_options ,
378
434
)
379
435
self .sagemaker_session .compile_model (** config )
380
436
job_status = self .sagemaker_session .wait_for_compilation_job (job_name )
381
437
self .model_data = job_status ["ModelArtifacts" ]["S3ModelArtifacts" ]
382
- if target_instance_family .startswith ("ml_" ):
383
- self .image = self ._neo_image (
384
- self .sagemaker_session .boto_region_name ,
385
- target_instance_family ,
386
- framework ,
387
- framework_version ,
388
- )
389
- self ._is_compiled_model = True
390
- elif target_instance_family .startswith (INFERENTIA_INSTANCE_PREFIX ):
391
- self .image = self ._inferentia_image (
392
- self .sagemaker_session .boto_region_name ,
393
- target_instance_family ,
394
- framework ,
395
- framework_version ,
396
- )
397
- self ._is_compiled_model = True
438
+ if target_instance_family is not None :
439
+ if target_instance_family .startswith ("ml_" ):
440
+ self .image = self ._neo_image (
441
+ self .sagemaker_session .boto_region_name ,
442
+ target_instance_family ,
443
+ framework ,
444
+ framework_version ,
445
+ )
446
+ self ._is_compiled_model = True
447
+ elif target_instance_family .startswith (INFERENTIA_INSTANCE_PREFIX ):
448
+ self .image = self ._inferentia_image (
449
+ self .sagemaker_session .boto_region_name ,
450
+ target_instance_family ,
451
+ framework ,
452
+ framework_version ,
453
+ )
454
+ self ._is_compiled_model = True
455
+ else :
456
+ LOGGER .warning (
457
+ "The instance type %s is not supported for deployment via SageMaker."
458
+ "Please deploy the model manually." ,
459
+ target_instance_family ,
460
+ )
398
461
else :
399
462
LOGGER .warning (
400
- "The instance type %s is not supported to deploy via SageMaker,"
401
- "please deploy the model manually." ,
402
- target_instance_family ,
463
+ "Devices described by Target Platform OS, Architecture and Accelerator are not"
464
+ "supported for deployment via SageMaker. Please deploy the model manually."
403
465
)
404
466
return self
405
467
0 commit comments