Skip to content

Commit 9211562

Browse files
feature: add support for Std:Join for pipelines (#2103)
* feature: add support for Std:Join for pipelines * Update src/sagemaker/workflow/functions.py * Update tests/unit/sagemaker/workflow/test_functions.py * fix: ensure region is specified for workflow client * feature: Map image name to image uri (#2100) * Map image name to image uri * fix bug in test Co-authored-by: Neelesh Dodda <[email protected]>
1 parent c6effe5 commit 9211562

File tree

4 files changed

+160
-21
lines changed

4 files changed

+160
-21
lines changed

src/sagemaker/processing.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from sagemaker.session import Session
3232
from sagemaker.network import NetworkConfig # noqa: F401 # pylint: disable=unused-import
3333
from sagemaker.workflow.properties import Properties
34+
from sagemaker.workflow.entities import Expression
3435
from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition
3536
from sagemaker.apiutils._base_types import ApiObject
3637

@@ -338,6 +339,10 @@ def _normalize_outputs(self, outputs=None):
338339
# Generate a name for the ProcessingOutput if it doesn't have one.
339340
if output.output_name is None:
340341
output.output_name = "output-{}".format(count)
342+
# if the output's destination is a workflow expression, do no normalization
343+
if isinstance(output.destination, Expression):
344+
normalized_outputs.append(output)
345+
continue
341346
# If the output's destination is not an s3_uri, create one.
342347
parse_result = urlparse(output.destination)
343348
if parse_result.scheme != "s3":

src/sagemaker/workflow/functions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The step definitions for workflow."""
14+
from __future__ import absolute_import
15+
16+
from typing import List
17+
18+
import attr
19+
20+
from sagemaker.workflow.entities import Expression
21+
22+
23+
@attr.s
24+
class Join(Expression):
25+
"""Join together properties.
26+
27+
Attributes:
28+
values (List[Union[PrimitiveType, Parameter]]): The primitive types
29+
and parameters to join.
30+
on_str (str): The string to join the values on (Defaults to "").
31+
"""
32+
33+
on: str = attr.ib(factory=str)
34+
values: List = attr.ib(factory=list)
35+
36+
@property
37+
def expr(self):
38+
"""The expression dict for a `Join` function."""
39+
return {
40+
"Std:Join": {
41+
"On": self.on,
42+
"Values": [
43+
value.expr if hasattr(value, "expr") else value for value in self.values
44+
],
45+
},
46+
}

tests/integ/test_workflow.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
3939
from sagemaker.workflow.condition_step import ConditionStep
4040
from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
41+
from sagemaker.workflow.execution_variables import ExecutionVariables
42+
from sagemaker.workflow.functions import Join
4143
from sagemaker.workflow.parameters import (
4244
ParameterInteger,
4345
ParameterString,
@@ -72,16 +74,9 @@ def role(sagemaker_session):
7274
return get_execution_role(sagemaker_session)
7375

7476

75-
# TODO-reinvent-2020: remove use of specific region and this session
7677
@pytest.fixture(scope="module")
77-
def region():
78-
return "us-east-2"
79-
80-
81-
# TODO-reinvent-2020: remove use of specific region and this session
82-
@pytest.fixture(scope="module")
83-
def workflow_session(region):
84-
boto_session = boto3.Session(region_name=region)
78+
def workflow_session(region_name):
79+
boto_session = boto3.Session(region_name=region_name)
8580

8681
sagemaker_client_config = dict()
8782
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=2)))
@@ -134,6 +129,7 @@ def test_three_step_definition(
134129
framework_version = "0.20.0"
135130
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
136131
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
132+
output_prefix = ParameterString(name="OutputPrefix", default_value="output")
137133

138134
input_data = f"s3://sagemaker-sample-data-{region_name}/processing/census/census-income.csv"
139135

@@ -154,7 +150,20 @@ def test_three_step_definition(
154150
],
155151
outputs=[
156152
ProcessingOutput(output_name="train_data", source="/opt/ml/processing/train"),
157-
ProcessingOutput(output_name="test_data", source="/opt/ml/processing/test"),
153+
ProcessingOutput(
154+
output_name="test_data",
155+
source="/opt/ml/processing/test",
156+
destination=Join(
157+
on="/",
158+
values=[
159+
"s3:/",
160+
sagemaker_session.default_bucket(),
161+
"test-sklearn",
162+
output_prefix,
163+
ExecutionVariables.PIPELINE_EXECUTION_ID,
164+
],
165+
),
166+
),
158167
],
159168
code=os.path.join(script_dir, "preprocessing.py"),
160169
)
@@ -194,7 +203,7 @@ def test_three_step_definition(
194203

195204
pipeline = Pipeline(
196205
name=pipeline_name,
197-
parameters=[instance_type, instance_count],
206+
parameters=[instance_type, instance_count, output_prefix],
198207
steps=[step_process, step_train, step_model],
199208
sagemaker_session=workflow_session,
200209
)
@@ -208,6 +217,7 @@ def test_three_step_definition(
208217
{"Name": "InstanceType", "Type": "String", "DefaultValue": "ml.m5.xlarge"}.items()
209218
),
210219
tuple({"Name": "InstanceCount", "Type": "Integer", "DefaultValue": 1}.items()),
220+
tuple({"Name": "OutputPrefix", "Type": "String", "DefaultValue": "output"}.items()),
211221
]
212222
)
213223

@@ -251,17 +261,28 @@ def test_three_step_definition(
251261
assert model_args["PrimaryContainer"]["ModelDataUrl"] == {
252262
"Get": "Steps.my-train.ModelArtifacts.S3ModelArtifacts"
253263
}
264+
try:
265+
response = pipeline.create(role)
266+
create_arn = response["PipelineArn"]
267+
assert re.match(
268+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
269+
create_arn,
270+
)
271+
finally:
272+
try:
273+
pipeline.delete()
274+
except Exception:
275+
pass
254276

255277

256-
# TODO-reinvent-2020: Modify use of the workflow client
257278
def test_one_step_sklearn_processing_pipeline(
258279
sagemaker_session,
259280
workflow_session,
260281
role,
261282
sklearn_latest_version,
262283
cpu_instance_type,
263284
pipeline_name,
264-
region,
285+
region_name,
265286
athena_dataset_definition,
266287
):
267288
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
@@ -305,21 +326,21 @@ def test_one_step_sklearn_processing_pipeline(
305326
response = pipeline.create(role)
306327
create_arn = response["PipelineArn"]
307328
assert re.match(
308-
fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}",
329+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
309330
create_arn,
310331
)
311332

312333
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
313334
response = pipeline.update(role)
314335
update_arn = response["PipelineArn"]
315336
assert re.match(
316-
fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}",
337+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
317338
update_arn,
318339
)
319340

320341
execution = pipeline.start(parameters={})
321342
assert re.match(
322-
fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}/execution/",
343+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
323344
execution.arn,
324345
)
325346

@@ -340,14 +361,13 @@ def test_one_step_sklearn_processing_pipeline(
340361
pass
341362

342363

343-
# TODO-reinvent-2020: Modify use of the workflow client
344364
def test_conditional_pytorch_training_model_registration(
345365
sagemaker_session,
346366
workflow_session,
347367
role,
348368
cpu_instance_type,
349369
pipeline_name,
350-
region,
370+
region_name,
351371
):
352372
base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
353373
entry_point = os.path.join(base_dir, "mnist.py")
@@ -420,18 +440,18 @@ def test_conditional_pytorch_training_model_registration(
420440
response = pipeline.create(role)
421441
create_arn = response["PipelineArn"]
422442
assert re.match(
423-
fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}", create_arn
443+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
424444
)
425445

426446
execution = pipeline.start(parameters={})
427447
assert re.match(
428-
fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}/execution/",
448+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
429449
execution.arn,
430450
)
431451

432452
execution = pipeline.start(parameters={"GoodEnoughInput": 0})
433453
assert re.match(
434-
fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}/execution/",
454+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
435455
execution.arn,
436456
)
437457
finally:
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
# language governing permissions and limitations under the License.
14+
from __future__ import absolute_import
15+
16+
from sagemaker.workflow.execution_variables import ExecutionVariables
17+
from sagemaker.workflow.functions import Join
18+
from sagemaker.workflow.parameters import (
19+
ParameterFloat,
20+
ParameterInteger,
21+
ParameterString,
22+
)
23+
from sagemaker.workflow.properties import Properties
24+
25+
26+
def test_join_primitives_default_on():
27+
assert Join(values=[1, "a", False, 1.1]).expr == {
28+
"Std:Join": {
29+
"On": "",
30+
"Values": [1, "a", False, 1.1],
31+
},
32+
}
33+
34+
35+
def test_join_primitives():
36+
assert Join(on=",", values=[1, "a", False, 1.1]).expr == {
37+
"Std:Join": {
38+
"On": ",",
39+
"Values": [1, "a", False, 1.1],
40+
},
41+
}
42+
43+
44+
def test_join_expressions():
45+
assert Join(
46+
values=[
47+
"foo",
48+
ParameterFloat(name="MyFloat"),
49+
ParameterInteger(name="MyInt"),
50+
ParameterString(name="MyStr"),
51+
Properties(path="Steps.foo.OutputPath.S3Uri"),
52+
ExecutionVariables.PIPELINE_EXECUTION_ID,
53+
Join(on=",", values=[1, "a", False, 1.1]),
54+
]
55+
).expr == {
56+
"Std:Join": {
57+
"On": "",
58+
"Values": [
59+
"foo",
60+
{"Get": "Parameters.MyFloat"},
61+
{"Get": "Parameters.MyInt"},
62+
{"Get": "Parameters.MyStr"},
63+
{"Get": "Steps.foo.OutputPath.S3Uri"},
64+
{"Get": "Execution.PipelineExecutionId"},
65+
{"Std:Join": {"On": ",", "Values": [1, "a", False, 1.1]}},
66+
],
67+
},
68+
}

0 commit comments

Comments
 (0)