Skip to content

Commit 4f98141

Browse files
feature: CompilationStep support for Sagemaker Pipelines
1 parent ac80ceb commit 4f98141

File tree

4 files changed

+238
-1
lines changed

4 files changed

+238
-1
lines changed

src/sagemaker/inputs.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,62 @@ class CreateModelInput(object):
136136
accelerator_type: str = attr.ib(default=None)
137137

138138

139+
@attr.s
140+
class CompilationInput(object):
141+
"""Create a class containing all the parameters.
142+
143+
It can be used when calling ``sagemaker.model.Model.compile_model()``
144+
145+
Parameters:
146+
147+
target_instance_type(str): Identifies the device that you want to
148+
run your model after compilation, for example: ml_c5. For allowed
149+
strings see
150+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
151+
input_shape(str): Specifies the name and shape of the expected
152+
inputs for your trained model in json dictionary form, for
153+
example: {'data': [1,3,1024,1024]}, or {'var1': [1,1,28,28],
154+
'var2': [1,1,28,28]}
155+
output_path(str): Specifies where to store the compiled model
156+
framework (str): The framework that is used to train the original
157+
model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch',
158+
'onnx', 'xgboost'
159+
framework_version (str): The version of the framework
160+
compile_max_run (int): Timeout in seconds for compilation (default:
161+
15 * 60). After this amount of time Amazon SageMaker Neo
162+
terminates the compilation job regardless of its current status.
163+
tags (list[dict]): List of tags for labeling a compilation job. For
164+
more, see
165+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
166+
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
167+
For allowed strings see
168+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
169+
It can be used instead of target_instance_family.
170+
target_platform_arch (str): Target Platform Architecture, for example: 'X86_64'.
171+
For allowed strings see
172+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
173+
It can be used instead of target_instance_family.
174+
target_platform_accelerator (str, optional): Target Platform Accelerator,
175+
for example: 'NVIDIA'. For allowed strings see
176+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
177+
It can be used instead of target_instance_family.
178+
compiler_options (dict, optional): Additional parameters for compiler.
179+
Compiler Options are TargetPlatform / target_instance_family specific. See
180+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html for details.
181+
"""
182+
183+
target_instance_type: str = attr.ib(default=None)
184+
input_shape: dict = attr.ib(factory=dict)
185+
output_path: str = attr.ib(default=None)
186+
framework: str = attr.ib(default=None)
187+
framework_version: str = attr.ib(default=None)
188+
compile_max_run: int = attr.ib(default=15 * 60)
189+
tags: list = attr.ib(factory=list)
190+
target_platform_os: str = attr.ib(default=None)
191+
target_platform_arch: str = attr.ib(default=None)
192+
target_platform_accelerator: str = attr.ib(default=None)
193+
compiler_options: dict = attr.ib(factory=dict)
194+
139195
@attr.s
140196
class TransformInput(object):
141197
"""Create a class containing all the parameters.

src/sagemaker/model.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
utils,
3030
git_utils,
3131
)
32+
from sagemaker.inputs import CompilationInput
3233
from sagemaker.deprecations import removed_kwargs
3334
from sagemaker.predictor import PredictorBase
3435
from sagemaker.transformer import Transformer
@@ -410,6 +411,57 @@ def _compilation_job_config(
410411
"job_name": job_name,
411412
}
412413

414+
def _get_compilation_args(self, estimator, inputs):
415+
"""Constructs a dict of arguments for an Amazon SageMaker compilation job from the estimator.
416+
Args:
417+
estimator (sagemaker.estimator.EstimatorBase): Estimator object
418+
created by the user.
419+
inputs (CompilationInput): class containing all the parameters that
420+
can be used when calling ``sagemaker.model.Model.compile_model()``
421+
"""
422+
423+
if not isinstance(inputs, CompilationInput):
424+
raise TypeError("Your inputs must be provided as ProcessingInput objects.")
425+
target_instance_family = inputs.target_instance_type
426+
input_shape = inputs.input_shape
427+
output_path = inputs.output_path
428+
role = estimator.role
429+
compile_max_run = inputs.compile_max_run
430+
job_name = estimator._compilation_job_name()
431+
framework = inputs.framework or self._framework()
432+
if framework is None:
433+
raise ValueError(
434+
"You must specify framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)
435+
)
436+
if framework not in NEO_ALLOWED_FRAMEWORKS:
437+
raise ValueError(
438+
"You must provide valid framework, allowed values {}".format(NEO_ALLOWED_FRAMEWORKS)
439+
)
440+
if self.model_data is None:
441+
raise ValueError("You must provide an S3 path to the compressed model artifacts.")
442+
tags = inputs.tags
443+
target_platform_os = inputs.target_platform_os
444+
target_platform_arch = inputs.target_platform_arch
445+
target_platform_accelerator = inputs.target_platform_accelerator
446+
compiler_options = inputs.compiler_options
447+
framework_version = inputs.framework_version or self._get_framework_version()
448+
449+
return self._compilation_job_config(
450+
target_instance_family,
451+
input_shape,
452+
output_path,
453+
role,
454+
compile_max_run,
455+
job_name,
456+
framework,
457+
tags,
458+
target_platform_os,
459+
target_platform_arch,
460+
target_platform_accelerator,
461+
compiler_options,
462+
framework_version,
463+
)
464+
413465
def _compilation_image_uri(self, region, target_instance_type, framework, framework_version):
414466
"""Retrieve the Neo or Inferentia image URI.
415467

src/sagemaker/session.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,6 +1840,57 @@ def compile_model(
18401840
LOGGER.info("Creating compilation-job with name: %s", job_name)
18411841
self.sagemaker_client.create_compilation_job(**compilation_job_request)
18421842

1843+
def _get_compilation_request(
1844+
self,
1845+
job_name,
1846+
input_model_config,
1847+
output_model_config,
1848+
role,
1849+
stop_condition,
1850+
tags=None,
1851+
vpc_config=None
1852+
):
1853+
"""Construct CreateCompilationJob request
1854+
1855+
Args:
1856+
input_model_config (dict): the trained model and the Amazon S3 location where it is
1857+
stored.
1858+
output_model_config (dict): Identifies the Amazon S3 location where you want Amazon
1859+
SageMaker Neo to save the results of compilation job
1860+
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker Neo
1861+
compilation jobs use this role to access model artifacts. You must grant
1862+
sufficient permissions to this role.
1863+
job_name (str): Name of the compilation job being created.
1864+
stop_condition (dict): Defines when compilation job shall finish. Contains entries
1865+
that can be understood by the service like ``MaxRuntimeInSeconds``.
1866+
tags (list[dict]): List of tags for labeling a compile model job. For more, see
1867+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
1868+
vpc_config (dict): Contains values for VpcConfig:
1869+
* subnets (list[str]): List of subnet ids.
1870+
The key in vpc_config is 'Subnets'.
1871+
* security_group_ids (list[str]): List of security group ids.
1872+
The key in vpc_config is 'SecurityGroupIds'.
1873+
Returns:
1874+
dict: A dictionary for CreateCompilationJob request
1875+
"""
1876+
1877+
compilation_request = {
1878+
"InputConfig": input_model_config,
1879+
"OutputConfig": output_model_config,
1880+
"RoleArn": role,
1881+
"StoppingCondition": stop_condition,
1882+
"CompilationJobName": job_name,
1883+
}
1884+
1885+
tags = _append_project_tags(tags)
1886+
if tags is not None:
1887+
compilation_request["Tags"] = tags
1888+
1889+
if vpc_config is not None:
1890+
compilation_request["VpcConfig"] = vpc_config
1891+
1892+
return compilation_request
1893+
18431894
def package_model_for_edge(
18441895
self,
18451896
output_model_config,

src/sagemaker/workflow/steps.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import attr
2222

2323
from sagemaker.estimator import EstimatorBase, _TrainingJob
24-
from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput, FileSystemInput
24+
from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput, FileSystemInput, CompilationInput
2525
from sagemaker.model import Model
2626
from sagemaker.processing import (
2727
ProcessingInput,
@@ -55,6 +55,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
5555
TRANSFORM = "Transform"
5656
CALLBACK = "Callback"
5757
TUNING = "Tuning"
58+
COMPILATION = "Compilation"
5859
LAMBDA = "Lambda"
5960

6061

@@ -681,3 +682,80 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
681682
"output/model.tar.gz",
682683
],
683684
)
685+
686+
687+
class CompilationStep(Step):
688+
"""Compilation step for workflow."""
689+
690+
def __init__(
691+
self,
692+
name: str,
693+
estimator: EstimatorBase,
694+
model: Model,
695+
inputs: CompilationInput = None,
696+
job_arguments: List[str] = None,
697+
depends_on: Union[List[str], List[Step]] = None,
698+
retry_policies: List[RetryPolicy] = None,
699+
display_name: str = None,
700+
description: str = None,
701+
cache_config: CacheConfig = None
702+
):
703+
"""Construct a CompilationStep, given an given an `EstimatorBase` instance and
704+
a `sagemaker.model.Model` instance.
705+
706+
In addition to the estimator and Model instances, the other arguments are those that are supplied to
707+
the `compile_model` method of the `sagemaker.model.Model.compile_model`.
708+
709+
Args:
710+
name (str): The name of the compilation step.
711+
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
712+
model (Model): A `sagemaker.model.Model` instance.
713+
inputs (CompilationInput): A `sagemaker.inputs.CompilationInput` instance.
714+
Defaults to `None`.
715+
job_arguments (List[str]): A list of strings to be passed into the processing job.
716+
Defaults to `None`.
717+
depends_on (List[str] or List[Step]): A list of step names or step instances
718+
this `sagemaker.workflow.steps.CompilationStep` depends on
719+
retry_policies (List[RetryPolicy]): A list of retry policy
720+
display_name (str): The display name of the compilation step.
721+
description (str): The description of the compilation step.
722+
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
723+
"""
724+
super(CompilationStep, self).__init__(
725+
name, StepTypeEnum.COMPILATION, display_name, description, depends_on, retry_policies
726+
)
727+
self.estimator = estimator
728+
self.model = model
729+
self.inputs = inputs
730+
self.job_arguments = job_arguments
731+
self._properties = Properties(
732+
path=f"Steps.{name}", shape_name="DescribeCompilationJobResponse"
733+
)
734+
self.cache_config = cache_config
735+
736+
@property
737+
def arguments(self) -> RequestType:
738+
"""The arguments dict that is used to call `create_compilation_job`.
739+
740+
NOTE: The CreateTrainingJob request is not quite the args list that workflow needs.
741+
The TrainingJobName and ExperimentConfig attributes cannot be included.
742+
"""
743+
744+
compilation_args = self.model._get_compilation_args(self.estimator, self.inputs)
745+
request_dict = self.model.sagemaker_session._get_compilation_request(**compilation_args)
746+
request_dict.pop("CompilationJobName")
747+
748+
return request_dict
749+
750+
@property
751+
def properties(self):
752+
"""A Properties object representing the DescribeTrainingJobResponse data model."""
753+
return self._properties
754+
755+
def to_request(self) -> RequestType:
756+
"""Updates the dictionary with cache configuration."""
757+
request_dict = super().to_request()
758+
if self.cache_config:
759+
request_dict.update(self.cache_config.config)
760+
761+
return request_dict

0 commit comments

Comments
 (0)