add data wrangler processor #2306

Merged
153 changes: 153 additions & 0 deletions src/sagemaker/wrangler/processing.py
@@ -0,0 +1,153 @@
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The process definitions for data wrangler."""

from __future__ import absolute_import

from typing import Dict, List

from sagemaker.network import NetworkConfig
from sagemaker.processing import (
    ProcessingInput,
    Processor,
)
from sagemaker import image_uris
from sagemaker.session import Session


class DataWranglerProcessor(Processor):
    """Handles Amazon SageMaker DataWrangler tasks."""

    def __init__(
        self,
        role: str,
        data_wrangler_flow_source: str,
        instance_count: int,
        instance_type: str,
        volume_size_in_gb: int = 30,
        volume_kms_key: str = None,
        output_kms_key: str = None,
        max_runtime_in_seconds: int = None,
        base_job_name: str = None,
        sagemaker_session: Session = None,
        env: Dict[str, str] = None,
        tags: List[dict] = None,
        network_config: NetworkConfig = None,
    ):
        """Initializes a ``DataWranglerProcessor`` instance.

        The processor handles Amazon SageMaker Processing tasks for Data Wrangler flows.

        Args:
            role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
                uses this role to access AWS resources, such as
                data stored in Amazon S3.
            data_wrangler_flow_source (str): The source of the Data Wrangler flow which will be
                used for the Data Wrangler job. If a local path is provided, it will automatically
                be uploaded to S3 under:
                "s3://<default-bucket-name>/<job-name>/input/<input-name>".
Comment on lines +55 to +58

Member:
should there be any client side validation on the flow file contents?

Contributor (Author):
job will fail if an invalid flow file is passed. We should do some sanity check, but we don't have the capability ATM; perhaps that's something DW team can add in the future.

            instance_count (int): The number of instances to run
                a processing job with.
            instance_type (str): The type of EC2 instance to use for
                processing, for example, 'ml.c4.xlarge'.
            volume_size_in_gb (int): Size in GB of the EBS volume
                to use for storing data during processing (default: 30).
            volume_kms_key (str): A KMS key for the processing
                volume (default: None).
            output_kms_key (str): The KMS key ID for processing job outputs (default: None).
            max_runtime_in_seconds (int): Timeout in seconds (default: None).
                After this amount of time, Amazon SageMaker terminates the job,
                regardless of its current status. If `max_runtime_in_seconds` is not
                specified, the default value is 24 hours.
            base_job_name (str): Prefix for processing job name. If not specified,
                the processor generates a default job name, based on the
                processing image name and current timestamp.
            sagemaker_session (:class:`~sagemaker.session.Session`):
                Session object which manages interactions with Amazon SageMaker and
                any other AWS services needed. If not specified, the processor creates
                one using the default AWS configuration chain.
            env (dict[str, str]): Environment variables to be passed to
                the processing jobs (default: None).
            tags (list[dict]): List of tags to be passed to the processing job
                (default: None). For more, see
                https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
            network_config (:class:`~sagemaker.network.NetworkConfig`):
                A :class:`~sagemaker.network.NetworkConfig`
                object that configures network isolation, encryption of
                inter-container traffic, security group IDs, and subnets.
        """
        self.data_wrangler_flow_source = data_wrangler_flow_source
        self.sagemaker_session = sagemaker_session or Session()
        image_uri = image_uris.retrieve(
            "data-wrangler", region=self.sagemaker_session.boto_region_name
        )
        super().__init__(
            role,
            image_uri,
            instance_count,
            instance_type,
            volume_size_in_gb=volume_size_in_gb,
            volume_kms_key=volume_kms_key,
            output_kms_key=output_kms_key,
            max_runtime_in_seconds=max_runtime_in_seconds,
            base_job_name=base_job_name,
            sagemaker_session=sagemaker_session,
            env=env,
            tags=tags,
            network_config=network_config,
        )

    def _normalize_args(
        self,
        job_name=None,
        arguments=None,
        inputs=None,
        outputs=None,
        code=None,
        kms_key=None,
    ):
        """Normalizes the arguments so that they can be passed to the job run.

        Args:
            job_name (str): Name of the processing job to be created. If not specified, one
                is generated, using the base name given to the constructor, if applicable
                (default: None).
            arguments (list[str]): A list of string arguments to be passed to a
                processing job (default: None).
            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
                the processing job. These must be provided as
                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
                the processing job. These can be specified as either path strings or
                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
            code (str): This can be an S3 URI or a local path to a file with the framework
                script to run (default: None). A no-op in this class.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).
        """
        inputs = inputs or []
        found = any(element.input_name == "flow" for element in inputs)
        if not found:
            inputs.append(self._get_recipe_input())
        return super()._normalize_args(job_name, arguments, inputs, outputs, code, kms_key)

    def _get_recipe_input(self):
        """Creates a ProcessingInput with the Data Wrangler recipe (flow) URI."""
        return ProcessingInput(
            source=self.data_wrangler_flow_source,
            destination="/opt/ml/processing/flow",
            input_name="flow",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
            s3_data_distribution_type="FullyReplicated",
        )
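
For context, a minimal usage sketch of the new processor (not part of this diff): the role ARN, bucket, and flow path below are illustrative placeholders, and `run()` is inherited from the base `Processor` class.

from sagemaker.wrangler.processing import DataWranglerProcessor

# Hypothetical role ARN and flow location, for illustration only.
processor = DataWranglerProcessor(
    role="arn:aws:iam::111122223333:role/MySageMakerRole",
    data_wrangler_flow_source="s3://my-bucket/my-recipe.flow",
    instance_count=1,
    instance_type="ml.m5.4xlarge",
)

# run() comes from Processor; the overridden _normalize_args() injects the flow
# file as a ProcessingInput named "flow" mounted at /opt/ml/processing/flow.
processor.run()
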
7 changes: 7 additions & 0 deletions tests/data/workflow/dummy_data.csv
@@ -0,0 +1,7 @@
Class,Age,Sex,SurvivalStatus
1st,"Quantity[29., ""Years""]",female,survived
1st,"Quantity[0.9167, ""Years""]",male,survived
2nd,"Quantity[30., ""Years""]",male,died
2nd,"Quantity[28., ""Years""]",female,survived
3rd,"Quantity[16., ""Years""]",male,died
3rd,"Quantity[35., ""Years""]",female,survived
62 changes: 62 additions & 0 deletions tests/data/workflow/dummy_recipe.flow
@@ -0,0 +1,62 @@
{
  "metadata": {
    "version": 1
  },
  "nodes": [
    {
      "node_id": "3f74973c-fd1e-4845-89f8-0dd400031be9",
      "type": "SOURCE",
      "operator": "sagemaker.s3_source_0.1",
      "parameters": {
        "dataset_definition": {
          "__typename": "S3CreateDatasetDefinitionOutput",
          "datasetSourceType": "S3",
          "name": "dummy_data.csv",
          "description": null,
          "s3ExecutionContext": {
            "__typename": "S3ExecutionContext",
            "s3Uri": "s3://bucket/dummy_data.csv",
            "s3ContentType": "csv",
            "s3HasHeader": true
          }
        }
      },
      "inputs": [],
      "outputs": [
        {
          "name": "default",
          "sampling": {
            "sampling_method": "sample_by_limit",
            "limit_rows": 50000
          }
        }
      ]
    },
    {
      "node_id": "67c18cb1-0192-445a-86f4-31e4c3553c60",
      "type": "TRANSFORM",
      "operator": "sagemaker.spark.infer_and_cast_type_0.1",
      "parameters": {},
      "trained_parameters": {
        "schema": {
          "Class": "string",
          "Age": "string",
          "Sex": "string",
          "SurvivalStatus": "string"
        }
      },
      "inputs": [
        {
          "name": "default",
          "node_id": "3f74973c-fd1e-4845-89f8-0dd400031be9",
          "output_name": "default"
        }
      ],
      "outputs": [
        {
          "name": "default"
        }
      ]
    }
  ]
}
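
The flow file above is plain JSON, so the client-side sanity check raised in the review thread could in principle look roughly like this sketch (not part of this PR; it only checks the top-level keys visible in dummy_recipe.flow):

import json

def looks_like_data_wrangler_flow(path):
    """Best-effort check: the file parses as JSON and has the expected top-level keys."""
    try:
        with open(path) as f:
            flow = json.load(f)
    except (OSError, ValueError):
        return False
    return isinstance(flow.get("nodes"), list) and "metadata" in flow

# Example against the fixture added in this PR:
print(looks_like_data_wrangler_flow("tests/data/workflow/dummy_recipe.flow"))  # True
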
112 changes: 112 additions & 0 deletions tests/integ/test_workflow.py
@@ -28,6 +28,7 @@
    rule_configs,
)
from datetime import datetime
from sagemaker import image_uris
from sagemaker.inputs import CreateModelInput, TrainingInput
from sagemaker.model import Model
from sagemaker.processing import ProcessingInput, ProcessingOutput
@@ -39,6 +40,7 @@
from sagemaker.spark.processing import PySparkProcessor, SparkJarProcessor
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.wrangler.processing import DataWranglerProcessor
from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.functions import Join
@@ -1076,3 +1078,113 @@ def test_two_processing_job_depends_on(
        pipeline.delete()
    except Exception:
        pass


def test_one_step_data_wrangler_processing_pipeline(
    sagemaker_session,
    role,
    pipeline_name,
    region_name,
):
    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.4xlarge")

    recipe_file_path = os.path.join(DATA_DIR, "workflow", "dummy_recipe.flow")
    input_file_path = os.path.join(DATA_DIR, "workflow", "dummy_data.csv")

    output_name = "3f74973c-fd1e-4845-89f8-0dd400031be9.default"
    output_content_type = "CSV"
    output_config = {output_name: {"content_type": output_content_type}}
    job_argument = [f"--output-config '{json.dumps(output_config)}'"]

    inputs = [
        ProcessingInput(
            input_name="dummy_data.csv",
            source=input_file_path,
            destination="/opt/ml/processing/dummy_data.csv",
        )
    ]

    output_s3_uri = f"s3://{sagemaker_session.default_bucket()}/output"
    outputs = [
        ProcessingOutput(
            output_name=output_name,
            source="/opt/ml/processing/output",
            destination=output_s3_uri,
            s3_upload_mode="EndOfJob",
        )
    ]

    data_wrangler_processor = DataWranglerProcessor(
        role=role,
        data_wrangler_flow_source=recipe_file_path,
        instance_count=instance_count,
        instance_type=instance_type,
        sagemaker_session=sagemaker_session,
        max_runtime_in_seconds=86400,
    )

    data_wrangler_step = ProcessingStep(
        name="data-wrangler-step",
        processor=data_wrangler_processor,
        inputs=inputs,
        outputs=outputs,
        job_arguments=job_argument,
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count, instance_type],
        steps=[data_wrangler_step],
        sagemaker_session=sagemaker_session,
    )

    definition = json.loads(pipeline.definition())
Contributor:
should we execute the pipeline and make sure the step actually works as well?

Contributor (Author):
They are basically processing jobs so running it doesn't really add value... the new processor only injects DW flow and container into the processing job.

Contributor (Author):
test updated, now it actually runs the pipeline

    expected_image_uri = image_uris.retrieve(
        "data-wrangler", region=sagemaker_session.boto_region_name
    )
    assert len(definition["Steps"]) == 1
    assert definition["Steps"][0]["Arguments"]["AppSpecification"]["ImageUri"] is not None
    assert definition["Steps"][0]["Arguments"]["AppSpecification"]["ImageUri"] == expected_image_uri

    assert definition["Steps"][0]["Arguments"]["ProcessingInputs"] is not None
    processing_inputs = definition["Steps"][0]["Arguments"]["ProcessingInputs"]
    assert len(processing_inputs) == 2
    for processing_input in processing_inputs:
        if processing_input["InputName"] == "flow":
            assert processing_input["S3Input"]["S3Uri"].endswith(".flow")
            assert processing_input["S3Input"]["LocalPath"] == "/opt/ml/processing/flow"
        elif processing_input["InputName"] == "dummy_data.csv":
            assert processing_input["S3Input"]["S3Uri"].endswith(".csv")
            assert processing_input["S3Input"]["LocalPath"] == "/opt/ml/processing/dummy_data.csv"
        else:
            raise AssertionError("Unknown input name")
    assert definition["Steps"][0]["Arguments"]["ProcessingOutputConfig"] is not None
    processing_outputs = definition["Steps"][0]["Arguments"]["ProcessingOutputConfig"]["Outputs"]
    assert len(processing_outputs) == 1
    assert processing_outputs[0]["OutputName"] == output_name
    assert processing_outputs[0]["S3Output"] is not None
    assert processing_outputs[0]["S3Output"]["LocalPath"] == "/opt/ml/processing/output"
    assert processing_outputs[0]["S3Output"]["S3Uri"] == output_s3_uri

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]

        execution = pipeline.start()
        response = execution.describe()
        assert response["PipelineArn"] == create_arn

        try:
            execution.wait(delay=60, max_attempts=10)
        except WaiterError:
            pass

        execution_steps = execution.list_steps()
        assert len(execution_steps) == 1
        assert execution_steps[0]["StepName"] == "data-wrangler-step"
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
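
For completeness, the same job could be launched outside of Pipelines by calling run() directly on the processor; this is a sketch, not part of the PR, and it assumes the variables (role, recipe_file_path, inputs, outputs, job_argument, sagemaker_session) defined in the test above.

processor = DataWranglerProcessor(
    role=role,
    data_wrangler_flow_source=recipe_file_path,
    instance_count=1,
    instance_type="ml.m5.4xlarge",
    sagemaker_session=sagemaker_session,
)

# Same inputs/outputs/arguments as the pipeline step; the "flow" input is added automatically.
processor.run(
    inputs=inputs,
    outputs=outputs,
    arguments=job_argument,
    wait=True,
)
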