Skip to content

Commit 642c958

Browse files
add data wrangler processor
1 parent 7b1e5c1 commit 642c958

File tree

2 files changed

+271
-0
lines changed

2 files changed

+271
-0
lines changed

src/sagemaker/workflow/processing.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# #
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
# #
7+
# http://aws.amazon.com/apache2.0/
8+
# #
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The process definitions for workflow."""
14+
15+
from __future__ import absolute_import
16+
17+
from sagemaker.processing import (
18+
ProcessingInput,
19+
Processor,
20+
)
21+
from sagemaker import image_uris
22+
from sagemaker.session import Session
23+
24+
25+
class DataWranglerProcessor(Processor):
    """Handles Amazon SageMaker DataWrangler tasks.

    A ``Processor`` whose container image is looked up under the
    ``"data-wrangler"`` key for the session's region, and whose job inputs
    always include the Data Wrangler recipe given to the constructor.
    """

    def __init__(
        self,
        role,
        data_wrangler_recipe_uri,
        instance_count,
        instance_type,
        volume_size_in_gb=30,
        volume_kms_key=None,
        output_kms_key=None,
        max_runtime_in_seconds=None,
        base_job_name=None,
        sagemaker_session=None,
        tags=None,
        network_config=None,
    ):
        """Initializes a ``DataWranglerProcessor`` instance.

        The ``DataWranglerProcessor`` handles Amazon SageMaker DataWrangler tasks.

        Args:
            role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
                uses this role to access AWS resources, such as
                data stored in Amazon S3.
            data_wrangler_recipe_uri (str): The S3 URI to DataWrangler recipe which will be
                used for the DataWrangler job
            instance_count (int): The number of instances to run
                a processing job with.
            instance_type (str): The type of EC2 instance to use for
                processing, for example, 'ml.c4.xlarge'.
            volume_size_in_gb (int): Size in GB of the EBS volume
                to use for storing data during processing (default: 30).
            volume_kms_key (str): A KMS key for the processing
                volume (default: None).
            output_kms_key (str): The KMS key ID for processing job outputs (default: None).
            max_runtime_in_seconds (int): Timeout in seconds (default: None).
                After this amount of time, Amazon SageMaker terminates the job,
                regardless of its current status. If `max_runtime_in_seconds` is not
                specified, the default value is 24 hours.
            base_job_name (str): Prefix for processing job name. If not specified,
                the processor generates a default job name, based on the
                processing image name and current timestamp.
            sagemaker_session (:class:`~sagemaker.session.Session`):
                Session object which manages interactions with Amazon SageMaker and
                any other AWS services needed. If not specified, the processor creates
                one using the default AWS configuration chain.
            tags (list[dict]): List of tags to be passed to the processing job
                (default: None). For more, see
                https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
            network_config (:class:`~sagemaker.network.NetworkConfig`):
                A :class:`~sagemaker.network.NetworkConfig`
                object that configures network isolation, encryption of
                inter-container traffic, security group IDs, and subnets.
        """
        self.data_wrangler_recipe_uri = data_wrangler_recipe_uri
        # The session is resolved up front because its region is needed to
        # look up the container image before the parent initializer runs.
        self.sagemaker_session = sagemaker_session or Session()
        image_uri = image_uris.retrieve(
            "data-wrangler", region=self.sagemaker_session.boto_region_name
        )
        super().__init__(
            role,
            image_uri,
            instance_count,
            instance_type,
            volume_size_in_gb=volume_size_in_gb,
            volume_kms_key=volume_kms_key,
            output_kms_key=output_kms_key,
            max_runtime_in_seconds=max_runtime_in_seconds,
            base_job_name=base_job_name,
            # Hand over the session resolved above rather than the raw argument,
            # so a ``None`` argument does not lead to a second session being
            # created and the image region always matches the job's session.
            sagemaker_session=self.sagemaker_session,
            tags=tags,
            network_config=network_config,
        )

    def _normalize_args(
        self,
        job_name=None,
        arguments=None,
        inputs=None,
        outputs=None,
        code=None,
        kms_key=None,
    ):
        """Normalizes the arguments so that they can be passed to the job run

        The Data Wrangler recipe input is added after any caller-supplied
        inputs before delegating to the parent implementation.

        Args:
            job_name (str): Name of the processing job to be created. If not specified, one
                is generated, using the base name given to the constructor, if applicable
                (default: None).
            arguments (list[str]): A list of string arguments to be passed to a
                processing job (default: None).
            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
                the processing job. These must be provided as
                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
                The list passed in is not modified.
            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
                the processing job. These can be specified as either path strings or
                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
            code (str): This can be an S3 URI or a local path to a file with the framework
                script to run (default: None). A no op in the base class.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).

        Returns:
            The result of the parent class's ``_normalize_args`` for the
            augmented list of inputs.
        """
        # Work on a copy: appending to the caller's own list would leak the
        # recipe input back to the caller and duplicate it on every reuse.
        inputs = list(inputs) if inputs else []
        inputs.append(self._get_recipe_input())
        return super()._normalize_args(job_name, arguments, inputs, outputs, code, kms_key)

    def _get_recipe_input(self):
        """Creates a ProcessingInput for the Data Wrangler recipe uri.

        Returns:
            :class:`~sagemaker.processing.ProcessingInput`: An input named ``"flow"``
            whose source is ``self.data_wrangler_recipe_uri`` and whose destination
            is ``/opt/ml/processing/flow``.
        """
        return ProcessingInput(
            source=self.data_wrangler_recipe_uri,
            destination="/opt/ml/processing/flow",
            input_name="flow",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
            s3_data_distribution_type="FullyReplicated",
        )
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# #
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
# #
7+
# http://aws.amazon.com/apache2.0/
8+
# #
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
import pytest
15+
from mock import Mock, MagicMock
16+
17+
from sagemaker.workflow.processing import DataWranglerProcessor
18+
from sagemaker.processing import ProcessingInput
19+
20+
# Shared literals used by the fixture, the tests, and the expected-argument builder.
ROLE = "arn:aws:iam::012345678901:role/SageMakerRole"
REGION = "us-west-2"

# S3 location of the recipe (flow) file handed to the processor under test.
DATA_WRANGLER_RECIPE_URI = "s3://data_wrangler_flows/flow-26-18-43-16-0b48ac2e.flow"

# Image URI the tests expect in the job's app specification.
DATA_WRANGLER_CONTAINER_URI = (
    "174368400705.dkr.ecr.us-west-2.amazonaws.com"
    "/sagemaker-data-wrangler-container:1.x"
)

# Source of the extra, caller-supplied input used by the mock-input test.
MOCK_S3_URI = "s3://mock_data/mock.csv"
27+
28+
29+
@pytest.fixture()
def sagemaker_session():
    """Provide a mocked SageMaker session for the processor tests.

    The mock reports ``REGION`` as its region (both directly and through its
    boto session), has no config, is not in local mode, and resolves any role
    to ``ROLE`` via ``expand_role``.
    """
    fake_session = MagicMock(
        name="sagemaker_session",
        boto_session=Mock(name="boto_session", region_name=REGION),
        boto_region_name=REGION,
        config=None,
        local_mode=False,
    )
    fake_session.expand_role.return_value = ROLE
    return fake_session
41+
42+
43+
def test_data_wrangler_processor_with_required_parameters(sagemaker_session):
    """Run with only the required constructor arguments and no extra inputs.

    The session's ``process`` call is expected to carry just the recipe
    ("flow") input plus the default resource settings.
    """
    required_kwargs = {
        "role": ROLE,
        "data_wrangler_recipe_uri": DATA_WRANGLER_RECIPE_URI,
        "instance_count": 1,
        "instance_type": "ml.m4.xlarge",
        "sagemaker_session": sagemaker_session,
    }
    processor = DataWranglerProcessor(**required_kwargs)

    processor.run()

    # The job name is generated during run(), so it is read back afterwards.
    sagemaker_session.process.assert_called_with(
        **_get_expected_args(processor._current_job_name)
    )
55+
56+
57+
def test_data_wrangler_processor_with_mock_input(sagemaker_session):
    """Run with one caller-supplied input in addition to the recipe.

    The session's ``process`` call is expected to list the caller's input
    first, followed by the recipe ("flow") input.
    """
    extra_input = ProcessingInput(
        source=MOCK_S3_URI,
        destination="/opt/ml/processing/mock_input",
        input_name="mock_input",
        s3_data_type="S3Prefix",
        s3_input_mode="File",
        s3_data_distribution_type="FullyReplicated",
    )
    processor = DataWranglerProcessor(
        role=ROLE,
        data_wrangler_recipe_uri=DATA_WRANGLER_RECIPE_URI,
        instance_count=1,
        instance_type="ml.m4.xlarge",
        sagemaker_session=sagemaker_session,
    )

    processor.run(inputs=[extra_input])

    # The job name is generated during run(), so it is read back afterwards.
    expected_call = _get_expected_args(processor._current_job_name, add_mock_input=True)
    sagemaker_session.process.assert_called_with(**expected_call)
77+
78+
79+
def _get_expected_args(job_name, add_mock_input=False):
    """Build the keyword arguments expected in the session's ``process`` call.

    Args:
        job_name (str): Job name to place under the ``"job_name"`` key.
        add_mock_input (bool): When True, the expected inputs start with the
            ``"mock_input"`` entry ahead of the ``"flow"`` entry (default: False).

    Returns:
        dict: Expected keyword arguments for ``sagemaker_session.process``.
    """

    def _s3_file_input(input_name, s3_uri, local_path):
        # Both expected inputs differ only in name, source URI and local path.
        return {
            "InputName": input_name,
            "AppManaged": False,
            "S3Input": {
                "S3Uri": s3_uri,
                "LocalPath": local_path,
                "S3DataType": "S3Prefix",
                "S3InputMode": "File",
                "S3DataDistributionType": "FullyReplicated",
                "S3CompressionType": "None",
            },
        }

    expected_inputs = []
    if add_mock_input:
        expected_inputs.append(
            _s3_file_input("mock_input", MOCK_S3_URI, "/opt/ml/processing/mock_input")
        )
    expected_inputs.append(
        _s3_file_input("flow", DATA_WRANGLER_RECIPE_URI, "/opt/ml/processing/flow")
    )

    cluster_config = {
        "InstanceType": "ml.m4.xlarge",
        "InstanceCount": 1,
        "VolumeSizeInGB": 30,
    }
    return {
        "inputs": expected_inputs,
        "output_config": {"Outputs": []},
        "job_name": job_name,
        "resources": {"ClusterConfig": cluster_config},
        "stopping_condition": None,
        "app_specification": {
            "ImageUri": DATA_WRANGLER_CONTAINER_URI,
        },
        "environment": None,
        "network_config": None,
        "role_arn": ROLE,
        "tags": None,
        "experiment_config": None,
    }

0 commit comments

Comments
 (0)