Skip to content

Commit 2cae789

Browse files
author
Payton Staub
committed
Add callback step
1 parent d51bda1 commit 2cae789

File tree

3 files changed

+159
-0
lines changed

3 files changed

+159
-0
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The step definitions for workflow."""
14+
from __future__ import absolute_import
15+
16+
from typing import List
17+
from enum import Enum
18+
19+
import attr
20+
21+
from sagemaker.workflow.entities import (
22+
RequestType,
23+
)
24+
from sagemaker.workflow.properties import (
25+
Properties,
26+
)
27+
from sagemaker.workflow.entities import (
28+
DefaultEnumMeta,
29+
)
30+
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
31+
32+
33+
class CallbackOutputTypeEnum(Enum, metaclass=DefaultEnumMeta):
34+
"""CallbackOutput type enum."""
35+
36+
String = "String"
37+
Integer = "Integer"
38+
Boolean = "Boolean"
39+
Float = "Float"
40+
41+
42+
@attr.s
43+
class CallbackOutput:
44+
"""Output for a callback step. The value of this output is provided at time of callback
45+
via the SendPipelineExecutionStepSuccess API
46+
47+
Attributes:
48+
output_name (str): The output name
49+
output_type (CallbackOutputTypeEnum): The output type
50+
"""
51+
52+
output_name: str = attr.ib(default=None)
53+
output_type: CallbackOutputTypeEnum = attr.ib(default=CallbackOutputTypeEnum.String.value)
54+
55+
def to_request(self) -> RequestType:
56+
"""Get the request structure for workflow service calls."""
57+
return {
58+
"OutputName": self.output_name,
59+
"OutputType": self.output_type.value,
60+
}
61+
62+
63+
class CallbackStep(Step):
64+
"""Callback step for workflow."""
65+
66+
def __init__(
67+
self,
68+
name: str,
69+
sqs_queue_url: str,
70+
inputs: dict,
71+
outputs: List[CallbackOutput],
72+
cache_config: CacheConfig = None,
73+
depends_on: List[str] = None,
74+
):
75+
"""Constructs a CallbackStep.
76+
77+
Args:
78+
name (str): The name of the callback step.
79+
sqs_queue_url (str): An SQS queue URL for receiving callback messages.
80+
inputs (dict): Input arguments that will be provided in the SQS message body of callback messages.
81+
outputs (List[CallbackOutput]): Outputs that can be provided when completing a callback.
82+
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
83+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep`
84+
depends on
85+
"""
86+
super(CallbackStep, self).__init__(name, StepTypeEnum.CALLBACK, depends_on)
87+
self.sqs_queue_url = sqs_queue_url
88+
self.outputs = outputs
89+
self.cache_config = cache_config
90+
self.inputs = inputs
91+
92+
root_path = f"Steps.{name}"
93+
root_prop = Properties(path=root_path)
94+
root_prop.__dict__["OutputParameters"] = Properties(f"{root_path}.OutputParameters")
95+
self._properties = root_prop
96+
97+
@property
98+
def arguments(self) -> RequestType:
99+
"""The arguments dict that is used to define the callback step."""
100+
return self.inputs
101+
102+
@property
103+
def properties(self):
104+
"""A Properties object representing the output parameters of the callback step."""
105+
return self._properties
106+
107+
def to_request(self) -> RequestType:
108+
"""Updates the dictionary with cache configuration."""
109+
request_dict = super().to_request()
110+
if self.cache_config:
111+
request_dict.update(self.cache_config.config)
112+
113+
request_dict["SqsQueueUrl"] = self.sqs_queue_url
114+
request_dict["OutputParameters"] = list(map(lambda op: op.to_request(), self.outputs))
115+
116+
return request_dict

src/sagemaker/workflow/steps.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
5454
REGISTER_MODEL = "RegisterModel"
5555
TRAINING = "Training"
5656
TRANSFORM = "Transform"
57+
CALLBACK = "Callback"
5758

5859

5960
@attr.s
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
# language governing permissions and limitations under the License.
14+
from __future__ import absolute_import
15+
16+
from sagemaker.workflow.parameters import ParameterInteger
17+
from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum
18+
19+
20+
def test_callback_step():
21+
param = ParameterInteger(name="MyInt")
22+
outputParam1 = CallbackOutput(output_name="output1", output_type=CallbackOutputTypeEnum.String)
23+
outputParam2 = CallbackOutput(output_name="output2", output_type=CallbackOutputTypeEnum.Boolean)
24+
cb_step = CallbackStep(
25+
name="MyCallbackStep",
26+
depends_on=["TestStep"],
27+
sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
28+
inputs={"arg1": "foo", "arg2": 5, "arg3": param},
29+
outputs=[outputParam1, outputParam2],
30+
)
31+
cb_step.add_depends_on(["SecondTestStep"])
32+
assert cb_step.to_request() == {
33+
"Name": "MyCallbackStep",
34+
"Type": "Callback",
35+
"DependsOn": ["TestStep", "SecondTestStep"],
36+
"SqsQueueUrl": "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
37+
"OutputParameters": [
38+
{"OutputName": "output1", "OutputType": "String"},
39+
{"OutputName": "output2", "OutputType": "Boolean"},
40+
],
41+
"Arguments": {"arg1": "foo", "arg2": 5, "arg3": param},
42+
}

0 commit comments

Comments
 (0)