Commit ca55485

add data wrangler processor
1 parent 7b1e5c1 commit ca55485

File tree

5 files changed: +380 −31 lines

src/sagemaker/workflow/processing.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The process definitions for workflow."""

from __future__ import absolute_import

from sagemaker.processing import (
    ProcessingInput,
    Processor,
)
from sagemaker import image_uris
from sagemaker.session import Session


class DataWranglerProcessor(Processor):
    """Handles Amazon SageMaker DataWrangler tasks."""

    def __init__(
        self,
        role,
        data_wrangler_recipe_source,
        instance_count,
        instance_type,
        volume_size_in_gb=30,
        volume_kms_key=None,
        output_kms_key=None,
        max_runtime_in_seconds=None,
        base_job_name=None,
        sagemaker_session=None,
        tags=None,
        network_config=None,
    ):
        """Initializes a ``DataWranglerProcessor`` instance.

        The ``DataWranglerProcessor`` handles Amazon SageMaker DataWrangler tasks.

        Args:
            role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
                uses this role to access AWS resources, such as
                data stored in Amazon S3.
            data_wrangler_recipe_source (str): The source of the DataWrangler recipe
                which will be used for the DataWrangler job. If a local path is
                provided, it will automatically be uploaded to S3 under:
                "s3://<default-bucket-name>/<job-name>/input/<input-name>".
            instance_count (int): The number of instances to run
                a processing job with.
            instance_type (str): The type of EC2 instance to use for
                processing, for example, 'ml.c4.xlarge'.
            volume_size_in_gb (int): Size in GB of the EBS volume
                to use for storing data during processing (default: 30).
            volume_kms_key (str): A KMS key for the processing
                volume (default: None).
            output_kms_key (str): The KMS key ID for processing job outputs (default: None).
            max_runtime_in_seconds (int): Timeout in seconds (default: None).
                After this amount of time, Amazon SageMaker terminates the job,
                regardless of its current status. If `max_runtime_in_seconds` is not
                specified, the default value is 24 hours.
            base_job_name (str): Prefix for processing job name. If not specified,
                the processor generates a default job name, based on the
                processing image name and current timestamp.
            sagemaker_session (:class:`~sagemaker.session.Session`):
                Session object which manages interactions with Amazon SageMaker and
                any other AWS services needed. If not specified, the processor creates
                one using the default AWS configuration chain.
            tags (list[dict]): List of tags to be passed to the processing job
                (default: None). For more, see
                https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
            network_config (:class:`~sagemaker.network.NetworkConfig`):
                A :class:`~sagemaker.network.NetworkConfig`
                object that configures network isolation, encryption of
                inter-container traffic, security group IDs, and subnets.
        """
        self.data_wrangler_recipe_source = data_wrangler_recipe_source
        self.sagemaker_session = sagemaker_session or Session()
        # The Data Wrangler container image is region-specific.
        image_uri = image_uris.retrieve(
            "data-wrangler", region=self.sagemaker_session.boto_region_name
        )
        super().__init__(
            role,
            image_uri,
            instance_count,
            instance_type,
            volume_size_in_gb=volume_size_in_gb,
            volume_kms_key=volume_kms_key,
            output_kms_key=output_kms_key,
            max_runtime_in_seconds=max_runtime_in_seconds,
            base_job_name=base_job_name,
            sagemaker_session=sagemaker_session,
            tags=tags,
            network_config=network_config,
        )

    def _normalize_args(
        self,
        job_name=None,
        arguments=None,
        inputs=None,
        outputs=None,
        code=None,
        kms_key=None,
    ):
        """Normalizes the arguments so that they can be passed to the job run.

        Args:
            job_name (str): Name of the processing job to be created. If not specified, one
                is generated, using the base name given to the constructor, if applicable
                (default: None).
            arguments (list[str]): A list of string arguments to be passed to a
                processing job (default: None).
            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
                the processing job. These must be provided as
                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
                the processing job. These can be specified as either path strings or
                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
            code (str): This can be an S3 URI or a local path to a file with the framework
                script to run (default: None). A no-op in the base class.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).
        """
        inputs = inputs or []
        # Always attach the Data Wrangler recipe as an additional "flow" input.
        inputs.append(self._get_recipe_input())
        return super()._normalize_args(job_name, arguments, inputs, outputs, code, kms_key)

    def _get_recipe_input(self):
        """Creates a ProcessingInput with the Data Wrangler recipe URI."""
        return ProcessingInput(
            source=self.data_wrangler_recipe_source,
            destination="/opt/ml/processing/flow",
            input_name="flow",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
            s3_data_distribution_type="FullyReplicated",
        )
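For orientation, a minimal usage sketch of the new processor outside of a pipeline (not part of this commit). The role ARN, recipe path, bucket, instance type, and output node name below are placeholders; `Processor.run` is the generic SageMaker Processing entry point inherited by this class.

    from sagemaker.processing import ProcessingInput, ProcessingOutput
    from sagemaker.workflow.processing import DataWranglerProcessor

    # Hypothetical values: substitute your own role, recipe, bucket, and instance type.
    processor = DataWranglerProcessor(
        role="arn:aws:iam::111122223333:role/SageMakerRole",
        data_wrangler_recipe_source="my_recipe.flow",  # local path; uploaded to S3 automatically
        instance_count=1,
        instance_type="ml.m5.4xlarge",
    )

    # _normalize_args injects the recipe as the "flow" input, so only the raw data
    # input and the desired output mapping are supplied here.
    processor.run(
        inputs=[
            ProcessingInput(
                input_name="job_data",
                source="s3://my-bucket/raw/data.csv",
                destination="/opt/ml/processing",
            )
        ],
        outputs=[
            ProcessingOutput(
                output_name="node-id.default",  # placeholder; matches an output node in the .flow recipe
                source="/opt/ml/processing/output",
                destination="s3://my-bucket/wrangled/",
            )
        ],
    )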

tests/data/workflow/dummy_data.csv

Whitespace-only changes.

tests/data/workflow/dummy_recipe.flow

Whitespace-only changes.

tests/integ/test_workflow.py

Lines changed: 108 additions & 31 deletions
@@ -28,6 +28,7 @@
     rule_configs,
 )
 from datetime import datetime
+from sagemaker import image_uris
 from sagemaker.inputs import CreateModelInput, TrainingInput
 from sagemaker.model import Model
 from sagemaker.processing import ProcessingInput, ProcessingOutput
@@ -39,6 +40,7 @@
 from sagemaker.spark.processing import PySparkProcessor, SparkJarProcessor
 from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
 from sagemaker.workflow.condition_step import ConditionStep
+from sagemaker.workflow.processing import DataWranglerProcessor
 from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
 from sagemaker.workflow.execution_variables import ExecutionVariables
 from sagemaker.workflow.functions import Join
@@ -84,7 +86,7 @@ def script_dir():
 
 @pytest.fixture
 def pipeline_name():
-    return f"my-pipeline-{int(time.time() * 10**7)}"
+    return f"my-pipeline-{int(time.time() * 10 ** 7)}"
 
 
 @pytest.fixture
@@ -228,12 +230,12 @@ def build_jar(): (whitespace-only re-indentation of the test_three_step_definition signature)
@@ -385,13 +387,13 @@ def test_three_step_definition( (whitespace-only re-indentation of the test_one_step_sklearn_processing_pipeline signature)
@@ -478,11 +480,11 @@ def test_one_step_sklearn_processing_pipeline( (whitespace-only re-indentation of the test_one_step_pyspark_processing_pipeline signature)
@@ -580,7 +582,7 @@ def test_one_step_pyspark_processing_pipeline( (whitespace-only re-indentation of the test_one_step_sparkjar_processing_pipeline signature)
@@ -677,11 +679,11 @@ def test_one_step_sparkjar_processing_pipeline( (whitespace-only re-indentation of the test_conditional_pytorch_training_model_registration signature)
@@ -777,11 +779,11 @@ def test_conditional_pytorch_training_model_registration( (whitespace-only re-indentation of the test_training_job_with_debugger_and_profiler signature)
@@ -858,7 +860,7 @@ def test_training_job_with_debugger_and_profiler( (whitespace-only re-indentation of an assert continuation line)
@@ -869,3 +871,78 @@ def test_training_job_with_debugger_and_profiler(
             pipeline.delete()
         except Exception:
             pass
+
+
+def test_one_step_data_wrangler_processing_pipeline(
+    sagemaker_session, role, cpu_instance_type, pipeline_name, region_name
+):
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
+
+    recipe_file_path = os.path.join(DATA_DIR, "workflow", "dummy_recipe.flow")
+    input_file_path = os.path.join(DATA_DIR, "workflow", "dummy_data.csv")
+
+    output_name = "1bd0aaad-9c93-41b2-8d42-58e214f0843f.default"
+    output_content_type = "CSV"
+    output_config = {output_name: {"content_type": output_content_type}}
+    job_argument = [f"--output-config '{json.dumps(output_config)}'"]
+
+    inputs = [ProcessingInput(input_name="job_data", source=input_file_path, destination="/opt/ml/processing")]
+
+    output_s3_uri = f"s3://{sagemaker_session.default_bucket()}/output"
+    outputs = [
+        ProcessingOutput(
+            output_name=output_name,
+            source="/opt/ml/processing/output",
+            destination=output_s3_uri,
+            s3_upload_mode="EndOfJob",
+        )
+    ]
+
+    data_wrangler_processor = DataWranglerProcessor(
+        role=role,
+        data_wrangler_recipe_source=recipe_file_path,
+        instance_count=instance_count,
+        instance_type=cpu_instance_type,
+        max_runtime_in_seconds=86400,
+    )
+
+    data_wrangler_step = ProcessingStep(
+        name="data-wrangler-step",
+        processor=data_wrangler_processor,
+        inputs=inputs,
+        outputs=outputs,
+        job_arguments=job_argument,
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[instance_count],
+        steps=[data_wrangler_step],
+        sagemaker_session=sagemaker_session,
+    )
+
+    definition = json.loads(pipeline.definition())
+    expected_image_uri = image_uris.retrieve("data-wrangler", region=sagemaker_session.boto_region_name)
+    assert len(definition["Steps"]) == 1
+    assert definition["Steps"][0]["Arguments"]["AppSpecification"]["ImageUri"] is not None
+    assert definition["Steps"][0]["Arguments"]["AppSpecification"]["ImageUri"] == expected_image_uri
+
+    assert definition["Steps"][0]["Arguments"]["ProcessingInputs"] is not None
+    processing_inputs = definition["Steps"][0]["Arguments"]["ProcessingInputs"]
+    assert len(processing_inputs) == 2
+    for processing_input in processing_inputs:
+        if processing_input["InputName"] == "flow":
+            assert processing_input["S3Input"]["S3Uri"].endswith(".flow")
+            assert processing_input["S3Input"]["LocalPath"] == "/opt/ml/processing/flow"
+        elif processing_input["InputName"] == "job_data":
+            assert processing_input["S3Input"]["S3Uri"].endswith(".csv")
+            assert processing_input["S3Input"]["LocalPath"] == "/opt/ml/processing"
+        else:
+            raise AssertionError("Unknown input name")
+
+    assert definition["Steps"][0]["Arguments"]["ProcessingOutputConfig"] is not None
+    processing_outputs = definition["Steps"][0]["Arguments"]["ProcessingOutputConfig"]["Outputs"]
+    assert len(processing_outputs) == 1
+    assert processing_outputs[0]["OutputName"] == output_name
+    assert processing_outputs[0]["S3Output"] is not None
+    assert processing_outputs[0]["S3Output"]["LocalPath"] == "/opt/ml/processing/output"
+    assert processing_outputs[0]["S3Output"]["S3Uri"] == output_s3_uri
