add data wrangler processor #2306
@@ -0,0 +1,153 @@
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The process definitions for data wrangler."""

from __future__ import absolute_import

from typing import Dict, List

from sagemaker.network import NetworkConfig
from sagemaker.processing import (
    ProcessingInput,
    Processor,
)
from sagemaker import image_uris
from sagemaker.session import Session


class DataWranglerProcessor(Processor):
    """Handles Amazon SageMaker DataWrangler tasks."""

    def __init__(
        self,
        role: str,
        data_wrangler_flow_source: str,
        instance_count: int,
        instance_type: str,
        volume_size_in_gb: int = 30,
        volume_kms_key: str = None,
        output_kms_key: str = None,
        max_runtime_in_seconds: int = None,
        base_job_name: str = None,
        sagemaker_session: Session = None,
        env: Dict[str, str] = None,
        tags: List[dict] = None,
        network_config: NetworkConfig = None,
    ):
        """Initializes a ``Processor`` instance.

        The ``Processor`` handles Amazon SageMaker Processing tasks.

        Args:
            role (str): An AWS IAM role name or ARN. Amazon SageMaker Processing
                uses this role to access AWS resources, such as
                data stored in Amazon S3.
            data_wrangler_flow_source (str): The source of the DataWrangler flow which will be
                used for the DataWrangler job. If a local path is provided, it will automatically
                be uploaded to S3 under:
                "s3://<default-bucket-name>/<job-name>/input/<input-name>".
            instance_count (int): The number of instances to run
                a processing job with.
            instance_type (str): The type of EC2 instance to use for
                processing, for example, 'ml.c4.xlarge'.
            volume_size_in_gb (int): Size in GB of the EBS volume
                to use for storing data during processing (default: 30).
            volume_kms_key (str): A KMS key for the processing
                volume (default: None).
            output_kms_key (str): The KMS key ID for processing job outputs (default: None).
            max_runtime_in_seconds (int): Timeout in seconds (default: None).
                After this amount of time, Amazon SageMaker terminates the job,
                regardless of its current status. If `max_runtime_in_seconds` is not
                specified, the default value is 24 hours.
            base_job_name (str): Prefix for the processing job name. If not specified,
                the processor generates a default job name, based on the
                processing image name and current timestamp.
            sagemaker_session (:class:`~sagemaker.session.Session`):
                Session object which manages interactions with Amazon SageMaker and
                any other AWS services needed. If not specified, the processor creates
                one using the default AWS configuration chain.
            env (dict[str, str]): Environment variables to be passed to
                the processing jobs (default: None).
            tags (list[dict]): List of tags to be passed to the processing job
                (default: None). For more, see
                https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
            network_config (:class:`~sagemaker.network.NetworkConfig`):
                A :class:`~sagemaker.network.NetworkConfig`
                object that configures network isolation, encryption of
                inter-container traffic, security group IDs, and subnets.
        """
        self.data_wrangler_flow_source = data_wrangler_flow_source
        self.sagemaker_session = sagemaker_session or Session()
        image_uri = image_uris.retrieve(
            "data-wrangler", region=self.sagemaker_session.boto_region_name
        )
        super().__init__(
            role,
            image_uri,
            instance_count,
            instance_type,
            volume_size_in_gb=volume_size_in_gb,
            volume_kms_key=volume_kms_key,
            output_kms_key=output_kms_key,
            max_runtime_in_seconds=max_runtime_in_seconds,
            base_job_name=base_job_name,
            sagemaker_session=sagemaker_session,
            env=env,
            tags=tags,
            network_config=network_config,
        )

    def _normalize_args(
        self,
        job_name=None,
        arguments=None,
        inputs=None,
        outputs=None,
        code=None,
        kms_key=None,
    ):
        """Normalizes the arguments so that they can be passed to the job run.

        Args:
            job_name (str): Name of the processing job to be created. If not specified, one
                is generated, using the base name given to the constructor, if applicable
                (default: None).
            arguments (list[str]): A list of string arguments to be passed to a
                processing job (default: None).
            inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
                the processing job. These must be provided as
                :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
            outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
                the processing job. These can be specified as either path strings or
                :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
            code (str): This can be an S3 URI or a local path to a file with the framework
                script to run (default: None). A no-op in the base class.
            kms_key (str): The ARN of the KMS key that is used to encrypt the
                user code file (default: None).
        """
        inputs = inputs or []
        found = any(element.input_name == "flow" for element in inputs)
        if not found:
            inputs.append(self._get_recipe_input())
        return super()._normalize_args(job_name, arguments, inputs, outputs, code, kms_key)

    def _get_recipe_input(self):
        """Creates a ProcessingInput with the Data Wrangler recipe URI."""
        return ProcessingInput(
            source=self.data_wrangler_flow_source,
            destination="/opt/ml/processing/flow",
            input_name="flow",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
            s3_data_distribution_type="FullyReplicated",
        )
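For context, here is a minimal usage sketch of the new processor (not part of the diff). The role ARN, flow path, and destination bucket are placeholders, and the output name follows the `<node_id>.default` convention used by the integration test further down.

import json

from sagemaker.processing import ProcessingOutput
from sagemaker.wrangler.processing import DataWranglerProcessor

# Placeholder role ARN and flow path; a local .flow path is uploaded to S3
# automatically, per the docstring above.
processor = DataWranglerProcessor(
    role="arn:aws:iam::123456789012:role/SageMakerRole",
    data_wrangler_flow_source="dummy_recipe.flow",
    instance_count=1,
    instance_type="ml.m5.4xlarge",
)

# The output name is "<node_id>.default" for the flow node whose output
# the job should materialize.
output_name = "3f74973c-fd1e-4845-89f8-0dd400031be9.default"
output_config = {output_name: {"content_type": "CSV"}}

# Note that no "flow" input is passed here: _normalize_args injects a
# ProcessingInput named "flow", mounted at /opt/ml/processing/flow.
processor.run(
    outputs=[
        ProcessingOutput(
            output_name=output_name,
            source="/opt/ml/processing/output",
            destination="s3://my-bucket/data-wrangler/output",  # placeholder bucket
        )
    ],
    arguments=[f"--output-config '{json.dumps(output_config)}'"],
)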
@@ -0,0 +1,7 @@
Class,Age,Sex,SurvivalStatus
1st,"Quantity[29., ""Years""]",female,survived
1st,"Quantity[0.9167, ""Years""]",male,survived
2nd,"Quantity[30., ""Years""]",male,died
2nd,"Quantity[28., ""Years""]",female,survived
3rd,"Quantity[16., ""Years""]",male,died
3rd,"Quantity[35., ""Years""]",female,survived
@@ -0,0 +1,62 @@
{
  "metadata": {
    "version": 1
  },
  "nodes": [
    {
      "node_id": "3f74973c-fd1e-4845-89f8-0dd400031be9",
      "type": "SOURCE",
      "operator": "sagemaker.s3_source_0.1",
      "parameters": {
        "dataset_definition": {
          "__typename": "S3CreateDatasetDefinitionOutput",
          "datasetSourceType": "S3",
          "name": "dummy_data.csv",
          "description": null,
          "s3ExecutionContext": {
            "__typename": "S3ExecutionContext",
            "s3Uri": "s3://bucket/dummy_data.csv",
            "s3ContentType": "csv",
            "s3HasHeader": true
          }
        }
      },
      "inputs": [],
      "outputs": [
        {
          "name": "default",
          "sampling": {
            "sampling_method": "sample_by_limit",
            "limit_rows": 50000
          }
        }
      ]
    },
    {
      "node_id": "67c18cb1-0192-445a-86f4-31e4c3553c60",
      "type": "TRANSFORM",
      "operator": "sagemaker.spark.infer_and_cast_type_0.1",
      "parameters": {},
      "trained_parameters": {
        "schema": {
          "Class": "string",
          "Age": "string",
          "Sex": "string",
          "SurvivalStatus": "string"
        }
      },
      "inputs": [
        {
          "name": "default",
          "node_id": "3f74973c-fd1e-4845-89f8-0dd400031be9",
          "output_name": "default"
        }
      ],
      "outputs": [
        {
          "name": "default"
        }
      ]
    }
  ]
}
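As a side note (not part of the diff), the output name used by the test below is derived from this flow file. A short sketch of the derivation, with the local file name as a placeholder:

import json

# Illustrative only: the output name is "<node_id>.default", i.e. the id of
# the flow node whose "default" output the processing job should write.
with open("dummy_recipe.flow") as f:
    flow = json.load(f)

source_node_id = flow["nodes"][0]["node_id"]  # "3f74973c-fd1e-4845-89f8-0dd400031be9"
output_name = f"{source_node_id}.default"
output_config = {output_name: {"content_type": "CSV"}}
job_argument = [f"--output-config '{json.dumps(output_config)}'"]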
@@ -28,6 +28,7 @@
    rule_configs,
)
from datetime import datetime
from sagemaker import image_uris
from sagemaker.inputs import CreateModelInput, TrainingInput
from sagemaker.model import Model
from sagemaker.processing import ProcessingInput, ProcessingOutput
@@ -39,6 +40,7 @@
from sagemaker.spark.processing import PySparkProcessor, SparkJarProcessor
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.wrangler.processing import DataWranglerProcessor
from sagemaker.dataset_definition.inputs import DatasetDefinition, AthenaDatasetDefinition
from sagemaker.workflow.execution_variables import ExecutionVariables
from sagemaker.workflow.functions import Join
@@ -1076,3 +1078,113 @@ def test_two_processing_job_depends_on(
        pipeline.delete()
    except Exception:
        pass


def test_one_step_data_wrangler_processing_pipeline(
    sagemaker_session,
    role,
    pipeline_name,
    region_name,
):
    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.4xlarge")

    recipe_file_path = os.path.join(DATA_DIR, "workflow", "dummy_recipe.flow")
    input_file_path = os.path.join(DATA_DIR, "workflow", "dummy_data.csv")

    output_name = "3f74973c-fd1e-4845-89f8-0dd400031be9.default"
    output_content_type = "CSV"
    output_config = {output_name: {"content_type": output_content_type}}
    job_argument = [f"--output-config '{json.dumps(output_config)}'"]

    inputs = [
        ProcessingInput(
            input_name="dummy_data.csv",
            source=input_file_path,
            destination="/opt/ml/processing/dummy_data.csv",
        )
    ]

    output_s3_uri = f"s3://{sagemaker_session.default_bucket()}/output"
    outputs = [
        ProcessingOutput(
            output_name=output_name,
            source="/opt/ml/processing/output",
            destination=output_s3_uri,
            s3_upload_mode="EndOfJob",
        )
    ]

    data_wrangler_processor = DataWranglerProcessor(
        role=role,
        data_wrangler_flow_source=recipe_file_path,
        instance_count=instance_count,
        instance_type=instance_type,
        sagemaker_session=sagemaker_session,
        max_runtime_in_seconds=86400,
    )

    data_wrangler_step = ProcessingStep(
        name="data-wrangler-step",
        processor=data_wrangler_processor,
        inputs=inputs,
        outputs=outputs,
        job_arguments=job_argument,
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count, instance_type],
        steps=[data_wrangler_step],
        sagemaker_session=sagemaker_session,
    )

    definition = json.loads(pipeline.definition())
[review comment] should we execute the pipeline and make sure the step actually works as well?

[reply] They are basically processing jobs, so running it doesn't really add value... the new processor only injects the DW flow and container into the processing job.

[reply] test updated, now it actually runs the pipeline
    expected_image_uri = image_uris.retrieve(
        "data-wrangler", region=sagemaker_session.boto_region_name
    )
    assert len(definition["Steps"]) == 1
    assert definition["Steps"][0]["Arguments"]["AppSpecification"]["ImageUri"] is not None
    assert definition["Steps"][0]["Arguments"]["AppSpecification"]["ImageUri"] == expected_image_uri

    assert definition["Steps"][0]["Arguments"]["ProcessingInputs"] is not None
    processing_inputs = definition["Steps"][0]["Arguments"]["ProcessingInputs"]
    assert len(processing_inputs) == 2
    for processing_input in processing_inputs:
        if processing_input["InputName"] == "flow":
            assert processing_input["S3Input"]["S3Uri"].endswith(".flow")
            assert processing_input["S3Input"]["LocalPath"] == "/opt/ml/processing/flow"
        elif processing_input["InputName"] == "dummy_data.csv":
            assert processing_input["S3Input"]["S3Uri"].endswith(".csv")
            assert processing_input["S3Input"]["LocalPath"] == "/opt/ml/processing/dummy_data.csv"
        else:
            raise AssertionError("Unknown input name")
    assert definition["Steps"][0]["Arguments"]["ProcessingOutputConfig"] is not None
    processing_outputs = definition["Steps"][0]["Arguments"]["ProcessingOutputConfig"]["Outputs"]
    assert len(processing_outputs) == 1
    assert processing_outputs[0]["OutputName"] == output_name
    assert processing_outputs[0]["S3Output"] is not None
    assert processing_outputs[0]["S3Output"]["LocalPath"] == "/opt/ml/processing/output"
    assert processing_outputs[0]["S3Output"]["S3Uri"] == output_s3_uri

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]

        execution = pipeline.start()
        response = execution.describe()
        assert response["PipelineArn"] == create_arn

        try:
            execution.wait(delay=60, max_attempts=10)
        except WaiterError:
            pass

        execution_steps = execution.list_steps()
        assert len(execution_steps) == 1
        assert execution_steps[0]["StepName"] == "data-wrangler-step"
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
[review comment] should there be any client-side validation on the flow file contents?

[reply] The job will fail if an invalid flow file is passed. We should do some sanity check, but we don't have the capability ATM; perhaps that's something the DW team can add in the future.
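To make the discussion concrete: a lightweight client-side sanity check is feasible against the layout of dummy_recipe.flow shown above. This is only an illustrative sketch, assuming flows keep that metadata/nodes structure; the helper name is hypothetical and nothing like it ships in this PR.

import json


def validate_flow_file(path):
    """Minimal sanity check for a Data Wrangler .flow file (illustrative only)."""
    with open(path) as f:
        flow = json.load(f)  # raises json.JSONDecodeError on malformed JSON

    # Top-level layout seen in dummy_recipe.flow: metadata plus a node list.
    if "metadata" not in flow or "nodes" not in flow:
        raise ValueError("flow file must contain 'metadata' and 'nodes'")
    if not flow["nodes"]:
        raise ValueError("flow file defines no nodes")

    node_ids = set()
    for node in flow["nodes"]:
        for key in ("node_id", "type", "operator"):
            if key not in node:
                raise ValueError(f"node is missing required field '{key}'")
        node_ids.add(node["node_id"])

    # Every non-source node should reference an existing upstream node.
    for node in flow["nodes"]:
        for node_input in node.get("inputs", []):
            if node_input["node_id"] not in node_ids:
                raise ValueError(f"input references unknown node {node_input['node_id']}")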