Skip to content

Commit b39eef5

Browse files
authored
Merge branch 'master' into master
2 parents 967a0dd + 0a63eec commit b39eef5

File tree

5 files changed

+309
-1
lines changed

5 files changed

+309
-1
lines changed
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The step definitions for workflow."""
14+
from __future__ import absolute_import
15+
16+
from typing import List, Dict
17+
from enum import Enum
18+
19+
import attr
20+
21+
from sagemaker.workflow.entities import (
22+
RequestType,
23+
)
24+
from sagemaker.workflow.properties import (
25+
Properties,
26+
)
27+
from sagemaker.workflow.entities import (
28+
DefaultEnumMeta,
29+
)
30+
from sagemaker.workflow.steps import Step, StepTypeEnum, CacheConfig
31+
32+
33+
class CallbackOutputTypeEnum(Enum, metaclass=DefaultEnumMeta):
    """Value types that an output of a callback step can be declared with."""

    # NOTE: the declaration order below is kept exactly as originally written.
    String = "String"
    Integer = "Integer"
    Boolean = "Boolean"
    Float = "Float"
40+
41+
42+
@attr.s
class CallbackOutput:
    """Output for a callback step.

    Attributes:
        output_name (str): The output name
        output_type (CallbackOutputTypeEnum): The output type
    """

    output_name: str = attr.ib(default=None)
    # The default must be the enum member itself, not its ``.value`` string:
    # ``to_request`` reads ``self.output_type.value`` and would raise
    # ``AttributeError`` on a plain ``str``.
    output_type: CallbackOutputTypeEnum = attr.ib(default=CallbackOutputTypeEnum.String)

    def to_request(self) -> RequestType:
        """Get the request structure for workflow service calls.

        Returns:
            A dict with the keys ``OutputName`` and ``OutputType``, where
            ``OutputType`` is the string value of the ``output_type`` enum member.
        """
        return {
            "OutputName": self.output_name,
            "OutputType": self.output_type.value,
        }

    @property
    def expr(self) -> Dict[str, str]:
        """The 'Get' expression dict for a `CallbackOutput`."""
        return CallbackOutput._expr(self.output_name)

    @classmethod
    def _expr(cls, name):
        """An internal classmethod for the 'Get' expression dict for a `CallbackOutput`.

        Args:
            name (str): The name of the callback output.

        Returns:
            A single-key dict ``{"Get": ...}`` holding the expression path.
        """
        # NOTE(review): ``name`` is used both as the step segment and as the
        # output key of the path, whereas ``CallbackStep`` builds its property
        # paths from the *step* name. Left as-is to keep the interface
        # unchanged -- confirm whether the step name should be used here.
        return {"Get": f"Steps.{name}.OutputParameters['{name}']"}
74+
75+
76+
class CallbackStep(Step):
    """Callback step for workflow."""

    def __init__(
        self,
        name: str,
        sqs_queue_url: str,
        inputs: dict,
        outputs: List[CallbackOutput],
        cache_config: CacheConfig = None,
        depends_on: List[str] = None,
    ):
        """Constructs a CallbackStep.

        Args:
            name (str): The name of the callback step.
            sqs_queue_url (str): An SQS queue URL for receiving callback messages.
            inputs (dict): Input arguments that will be provided
                in the SQS message body of callback messages.
            outputs (List[CallbackOutput]): Outputs that can be provided when completing a callback.
            cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
            depends_on (List[str]): A list of step names this
                `sagemaker.workflow.callback_step.CallbackStep` depends on
        """
        super().__init__(name, StepTypeEnum.CALLBACK, depends_on)
        self.sqs_queue_url = sqs_queue_url
        self.outputs = outputs
        self.cache_config = cache_config
        self.inputs = inputs

        root_path = f"Steps.{name}"
        root_prop = Properties(path=root_path)

        # One Properties reference per declared output, keyed by output name and
        # stored directly in the root object's instance dict under "Outputs".
        root_prop.__dict__["Outputs"] = {
            output.output_name: Properties(
                f"{root_path}.OutputParameters['{output.output_name}']"
            )
            for output in outputs
        }
        self._properties = root_prop

    @property
    def arguments(self) -> RequestType:
        """The arguments dict that is used to define the callback step."""
        return self.inputs

    @property
    def properties(self):
        """A Properties object representing the output parameters of the callback step."""
        return self._properties

    def to_request(self) -> RequestType:
        """Gets the request structure for the callback step.

        Starts from the base step request, merges in the cache configuration
        when one was supplied, and adds the ``SqsQueueUrl`` and
        ``OutputParameters`` entries.
        """
        request_dict = super().to_request()
        if self.cache_config:
            request_dict.update(self.cache_config.config)

        request_dict["SqsQueueUrl"] = self.sqs_queue_url
        request_dict["OutputParameters"] = [output.to_request() for output in self.outputs]

        return request_dict

src/sagemaker/workflow/pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
from sagemaker._studio import _append_project_tags
2626
from sagemaker.session import Session
27+
from sagemaker.workflow.callback_step import CallbackOutput
2728
from sagemaker.workflow.entities import (
2829
Entity,
2930
Expression,
@@ -281,7 +282,7 @@ def _interpolate(obj: Union[RequestType, Any]):
281282
Args:
282283
obj (Union[RequestType, Any]): The request dict.
283284
"""
284-
if isinstance(obj, (Expression, Parameter, Properties)):
285+
if isinstance(obj, (Expression, Parameter, Properties, CallbackOutput)):
285286
return obj.expr
286287
if isinstance(obj, dict):
287288
new = obj.__class__()

src/sagemaker/workflow/steps.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
5454
REGISTER_MODEL = "RegisterModel"
5555
TRAINING = "Training"
5656
TRANSFORM = "Transform"
57+
CALLBACK = "Callback"
5758

5859

5960
@attr.s

tests/integ/test_workflow.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from sagemaker.spark.processing import PySparkProcessor, SparkJarProcessor
4646
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo, ConditionIn
4747
from sagemaker.workflow.condition_step import ConditionStep
48+
from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum
4849
from sagemaker.wrangler.processing import DataWranglerProcessor
4950
from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
5051
from sagemaker.workflow.execution_variables import ExecutionVariables
@@ -698,6 +699,46 @@ def test_one_step_sparkjar_processing_pipeline(
698699
pass
699700

700701

702+
def test_one_step_callback_pipeline(sagemaker_session, role, pipeline_name, region_name):
    """Create and then update a pipeline consisting of a single callback step."""
    callback_output = CallbackOutput(
        output_name="output1", output_type=CallbackOutputTypeEnum.String
    )
    callback_step = CallbackStep(
        name="callback-step",
        sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
        inputs={"arg1": "foo"},
        outputs=[callback_output],
    )
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[ParameterInteger(name="InstanceCount", default_value=2)],
        steps=[callback_step],
        sagemaker_session=sagemaker_session,
    )

    # Both the create and the update response are checked against this ARN shape.
    arn_pattern = fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}"

    try:
        created = pipeline.create(role)
        assert re.match(arn_pattern, created["PipelineArn"])

        pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
        updated = pipeline.update(role)
        assert re.match(arn_pattern, updated["PipelineArn"])
    finally:
        # Best-effort cleanup: a failed delete must not mask the test outcome.
        try:
            pipeline.delete()
        except Exception:
            pass
740+
741+
701742
def test_conditional_pytorch_training_model_registration(
702743
sagemaker_session,
703744
role,
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import json
16+
17+
import pytest
18+
19+
from mock import Mock
20+
21+
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
22+
from sagemaker.workflow.pipeline import Pipeline
23+
from sagemaker.workflow.callback_step import CallbackStep, CallbackOutput, CallbackOutputTypeEnum
24+
25+
26+
@pytest.fixture
def sagemaker_session_mock():
    """Provide a fresh ``Mock`` object standing in for a SageMaker session."""
    session = Mock()
    return session
29+
30+
31+
def test_callback_step():
    """Check the request dict produced by a fully configured callback step."""
    queue_url = "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue"
    int_param = ParameterInteger(name="MyInt")
    string_output = CallbackOutput(
        output_name="output1", output_type=CallbackOutputTypeEnum.String
    )
    boolean_output = CallbackOutput(
        output_name="output2", output_type=CallbackOutputTypeEnum.Boolean
    )

    step = CallbackStep(
        name="MyCallbackStep",
        depends_on=["TestStep"],
        sqs_queue_url=queue_url,
        inputs={"arg1": "foo", "arg2": 5, "arg3": int_param},
        outputs=[string_output, boolean_output],
    )
    step.add_depends_on(["SecondTestStep"])

    expected_request = {
        "Name": "MyCallbackStep",
        "Type": "Callback",
        "DependsOn": ["TestStep", "SecondTestStep"],
        "SqsQueueUrl": queue_url,
        "OutputParameters": [
            {"OutputName": "output1", "OutputType": "String"},
            {"OutputName": "output2", "OutputType": "Boolean"},
        ],
        "Arguments": {"arg1": "foo", "arg2": 5, "arg3": int_param},
    }
    assert step.to_request() == expected_request
54+
55+
56+
def test_callback_step_output_expr():
    """Check the 'Get' expressions exposed for each output of a callback step."""
    int_param = ParameterInteger(name="MyInt")
    declared_outputs = [
        CallbackOutput(output_name="output1", output_type=CallbackOutputTypeEnum.String),
        CallbackOutput(output_name="output2", output_type=CallbackOutputTypeEnum.Boolean),
    ]
    step = CallbackStep(
        name="MyCallbackStep",
        depends_on=["TestStep"],
        sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
        inputs={"arg1": "foo", "arg2": 5, "arg3": int_param},
        outputs=declared_outputs,
    )

    # Each declared output must resolve to a path rooted at the step's name.
    for output_name in ("output1", "output2"):
        assert step.properties.Outputs[output_name].expr == {
            "Get": f"Steps.MyCallbackStep.OutputParameters['{output_name}']"
        }
74+
75+
76+
def test_pipeline_interpolates_callback_outputs(sagemaker_session_mock):
    """Check that a callback output used as another step's input is interpolated.

    The second step consumes ``output1`` of the first step; in the rendered
    pipeline definition that reference must appear as a 'Get' expression.

    Args:
        sagemaker_session_mock: The mocked session fixture. It has to be
            requested as a parameter; otherwise the module-level fixture
            *function* would be handed to ``Pipeline`` instead of a ``Mock``.
    """
    parameter = ParameterString("MyStr")
    output_param1 = CallbackOutput(
        output_name="output1", output_type=CallbackOutputTypeEnum.String
    )
    output_param2 = CallbackOutput(
        output_name="output2", output_type=CallbackOutputTypeEnum.String
    )
    cb_step1 = CallbackStep(
        name="MyCallbackStep1",
        depends_on=["TestStep"],
        sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
        inputs={"arg1": "foo"},
        outputs=[output_param1],
    )
    cb_step2 = CallbackStep(
        name="MyCallbackStep2",
        depends_on=["TestStep"],
        sqs_queue_url="https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
        inputs={"arg1": cb_step1.properties.Outputs["output1"]},
        outputs=[output_param2],
    )

    pipeline = Pipeline(
        name="MyPipeline",
        parameters=[parameter],
        steps=[cb_step1, cb_step2],
        sagemaker_session=sagemaker_session_mock,
    )

    assert json.loads(pipeline.definition()) == {
        "Version": "2020-12-01",
        "Metadata": {},
        "Parameters": [{"Name": "MyStr", "Type": "String"}],
        "PipelineExperimentConfig": {
            "ExperimentName": {"Get": "Execution.PipelineName"},
            "TrialName": {"Get": "Execution.PipelineExecutionId"},
        },
        "Steps": [
            {
                "Name": "MyCallbackStep1",
                "Type": "Callback",
                "Arguments": {"arg1": "foo"},
                "DependsOn": ["TestStep"],
                "SqsQueueUrl": "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
                "OutputParameters": [{"OutputName": "output1", "OutputType": "String"}],
            },
            {
                "Name": "MyCallbackStep2",
                "Type": "Callback",
                "Arguments": {"arg1": {"Get": "Steps.MyCallbackStep1.OutputParameters['output1']"}},
                "DependsOn": ["TestStep"],
                "SqsQueueUrl": "https://sqs.us-east-2.amazonaws.com/123456789012/MyQueue",
                "OutputParameters": [{"OutputName": "output2", "OutputType": "String"}],
            },
        ],
    }

0 commit comments

Comments
 (0)