Skip to content

Commit 6cf53c9

Browse files
Added Unit test.
Bug fixes and changes for review comments
1 parent 16b2e94 commit 6cf53c9

File tree

4 files changed

+75
-14
lines changed

4 files changed

+75
-14
lines changed

src/sagemaker/inputs.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -152,29 +152,33 @@ class CompilationInput(object):
152152
example: {'data': [1,3,1024,1024]}, or {'var1': [1,1,28,28],
153153
'var2': [1,1,28,28]}
154154
output_path(str): Specifies where to store the compiled model
155-
framework (str): The framework that is used to train the original
155+
framework (str, optional): The framework that is used to train the original
156156
model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch',
157-
'onnx', 'xgboost'
158-
framework_version (str): The version of the framework
159-
compile_max_run (int): Timeout in seconds for compilation (default:
157+
'onnx', 'xgboost' (default: None)
158+
framework_version (str, optional): The version of the framework (default: None)
159+
compile_max_run (int, optional): Timeout in seconds for compilation (default:
160160
15 * 60). After this amount of time Amazon SageMaker Neo
161161
terminates the compilation job regardless of its current status.
162-
tags (list[dict]): List of tags for labeling a compilation job. For
163-
more, see
162+
tags (list[dict], optional): List of tags for labeling a compilation job.
163+
For more, see
164164
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
165-
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
165+
job_name (str, optional): The name of the compilation job (default: None)
166+
target_platform_os (str, optional): Target Platform OS, for example: 'LINUX'.
167+
(default: None)
166168
For allowed strings see
167169
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
168170
It can be used instead of target_instance_family.
169-
target_platform_arch (str): Target Platform Architecture, for example: 'X86_64'.
171+
target_platform_arch (str, optional): Target Platform Architecture, for example: 'X86_64'.
172+
(default: None)
170173
For allowed strings see
171174
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
172175
It can be used instead of target_instance_family.
173176
target_platform_accelerator (str, optional): Target Platform Accelerator,
174-
for example: 'NVIDIA'. For allowed strings see
177+
for example: 'NVIDIA'. (default: None)
178+
For allowed strings see
175179
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html.
176180
It can be used instead of target_instance_family.
177-
compiler_options (dict, optional): Additional parameters for compiler.
181+
compiler_options (dict, optional): Additional parameters for compiler. (default: None)
178182
Compiler Options are TargetPlatform / target_instance_family specific. See
179183
https://docs.aws.amazon.com/sagemaker/latest/dg/API_OutputConfig.html for details.
180184
"""
@@ -186,10 +190,11 @@ class CompilationInput(object):
186190
framework_version: str = attr.ib(default=None)
187191
compile_max_run: int = attr.ib(default=15 * 60)
188192
tags: list = attr.ib(factory=list)
193+
job_name: str = attr.ib(default=None)
189194
target_platform_os: str = attr.ib(default=None)
190195
target_platform_arch: str = attr.ib(default=None)
191196
target_platform_accelerator: str = attr.ib(default=None)
192-
compiler_options: dict = attr.ib(factory=dict)
197+
compiler_options: dict = attr.ib(default=None)
193198

194199

195200
@attr.s

src/sagemaker/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def _get_compilation_args(self, estimator, inputs):
422422
"""
423423

424424
if not isinstance(inputs, CompilationInput):
425-
raise TypeError("Your inputs must be provided as ProcessingInput objects.")
425+
raise TypeError("Your inputs must be provided as CompilationInput objects.")
426426
target_instance_family = inputs.target_instance_type
427427
input_shape = inputs.input_shape
428428
output_path = inputs.output_path

src/sagemaker/workflow/steps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = "") ->
682682
)
683683

684684

685-
class CompilationStep(Step):
685+
class CompilationStep(ConfigurableRetryStep):
686686
"""Compilation step for workflow."""
687687

688688
def __init__(

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sagemaker.debugger import DEBUGGER_FLAG, ProfilerConfig
2727
from sagemaker.estimator import Estimator
2828
from sagemaker.tensorflow import TensorFlow
29-
from sagemaker.inputs import TrainingInput, TransformInput, CreateModelInput
29+
from sagemaker.inputs import TrainingInput, TransformInput, CreateModelInput, CompilationInput
3030
from sagemaker.model import Model
3131
from sagemaker.processing import (
3232
Processor,
@@ -52,6 +52,7 @@
5252
)
5353
from sagemaker.workflow.steps import (
5454
ProcessingStep,
55+
CompilationStep,
5556
ConfigurableRetryStep,
5657
StepTypeEnum,
5758
TrainingStep,
@@ -1054,3 +1055,58 @@ def test_multi_algo_tuning_step(sagemaker_session):
10541055
],
10551056
},
10561057
}
1058+
1059+
1060+
def test_compilation_step(sagemaker_session):
1061+
estimator = Estimator(
1062+
image_uri=IMAGE_URI,
1063+
role=ROLE,
1064+
instance_count=1,
1065+
instance_type="ml.c5.4xlarge",
1066+
profiler_config=ProfilerConfig(system_monitor_interval_millis=500),
1067+
rules=[],
1068+
sagemaker_session=sagemaker_session,
1069+
)
1070+
1071+
model = Model(
1072+
image_uri=IMAGE_URI,
1073+
model_data="s3://output/tensorflow.tar.gz",
1074+
sagemaker_session=sagemaker_session,
1075+
)
1076+
1077+
compilation_input = CompilationInput(
1078+
target_instance_type="ml_inf",
1079+
input_shape={"data": [1, 3, 1024, 1024]},
1080+
output_path="s3://output",
1081+
compile_max_run=100,
1082+
framework="tensorflow",
1083+
job_name="compile-model",
1084+
compiler_options=None
1085+
)
1086+
compilation_step = CompilationStep(
1087+
name="MyCompilationStep",
1088+
estimator=estimator,
1089+
model=model,
1090+
inputs=compilation_input
1091+
)
1092+
1093+
assert compilation_step.to_request() == {
1094+
"Name": "MyCompilationStep",
1095+
"Type": "Compilation",
1096+
"Arguments": {
1097+
"InputConfig": {
1098+
"DataInputConfig": '{"data": [1, 3, 1024, 1024]}',
1099+
"Framework": "TENSORFLOW",
1100+
"S3Uri": "s3://output/tensorflow.tar.gz"
1101+
},
1102+
"OutputConfig": {
1103+
"S3OutputLocation": "s3://output",
1104+
"TargetDevice": "ml_inf"
1105+
},
1106+
"RoleArn": ROLE,
1107+
"StoppingCondition": {
1108+
"MaxRuntimeInSeconds": 100
1109+
},
1110+
'Tags': []
1111+
},
1112+
}

0 commit comments

Comments
 (0)