Skip to content

feature: Apache Airflow integration for SageMaker Processing Jobs #1620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,93 @@ def stop(self):
"""Stops the processing job."""
self.sagemaker_session.stop_processing_job(self.name)

@staticmethod
def prepare_app_specification(container_arguments, container_entrypoint, image_uri):
"""Prepares a dict that represents a ProcessingJob's AppSpecification.

Args:
container_arguments (list[str]): The arguments for a container
used to run a processing job.
container_entrypoint (list[str]): The entrypoint for a container
used to run a processing job.
image_uri (str): The container image to be run by the processing job.

Returns:
dict: Represents AppSpecification which configures the
processing job to run a specified Docker container image.
"""
config = {"ImageUri": image_uri}
if container_arguments is not None:
config["ContainerArguments"] = container_arguments
if container_entrypoint is not None:
config["ContainerEntrypoint"] = container_entrypoint
return config

@staticmethod
def prepare_output_config(kms_key_id, outputs):
"""Prepares a dict that represents a ProcessingOutputConfig.

Args:
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that
Amazon SageMaker uses to encrypt the processing job output.
KmsKeyId can be an ID of a KMS key, ARN of a KMS key, alias of a KMS key,
or alias of a KMS key. The KmsKeyId is applied to all outputs.
outputs (list[dict]): Output configuration information for a processing job.

Returns:
dict: Represents output configuration for the processing job.
"""
config = {"Outputs": outputs}
if kms_key_id is not None:
config["KmsKeyId"] = kms_key_id
return config

@staticmethod
def prepare_processing_resources(
instance_count, instance_type, volume_kms_key_id, volume_size_in_gb
):
"""Prepares a dict that represents the ProcessingResources.

Args:
instance_count (int): The number of ML compute instances
to use in the processing job. For distributed processing jobs,
specify a value greater than 1. The default value is 1.
instance_type (str): The ML compute instance type for the processing job.
volume_kms_key_id (str): The AWS Key Management Service (AWS KMS) key
that Amazon SageMaker uses to encrypt data on the storage
volume attached to the ML compute instance(s) that run the processing job.
volume_size_in_gb (int): The size of the ML storage volume in gigabytes
that you want to provision. You must specify sufficient
ML storage for your scenario.

Returns:
dict: Represents ProcessingResources which identifies the resources,
ML compute instances, and ML storage volumes to deploy
for a processing job.
"""
processing_resources = {}
cluster_config = {
"InstanceCount": instance_count,
"InstanceType": instance_type,
"VolumeSizeInGB": volume_size_in_gb,
}
if volume_kms_key_id is not None:
cluster_config["VolumeKmsKeyId"] = volume_kms_key_id
processing_resources["ClusterConfig"] = cluster_config
return processing_resources

@staticmethod
def prepare_stopping_condition(max_runtime_in_seconds):
"""Prepares a dict that represents the job's StoppingCondition.

Args:
max_runtime_in_seconds (int): Specifies the maximum runtime in seconds.

Returns:
dict
"""
return {"MaxRuntimeInSeconds": max_runtime_in_seconds}


class ProcessingInput(object):
"""Accepts parameters that specify an Amazon S3 input for a processing job and
Expand Down
105 changes: 105 additions & 0 deletions src/sagemaker/workflow/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,3 +1070,108 @@ def deploy_config_from_estimator(
model.name = model_name
config = deploy_config(model, initial_instance_count, instance_type, endpoint_name, tags)
return config


def processing_config(
    processor,
    inputs=None,
    outputs=None,
    job_name=None,
    experiment_config=None,
    container_arguments=None,
    container_entrypoint=None,
    kms_key_id=None,
):
    """Export Airflow processing config from a SageMaker processor.

    Args:
        processor (sagemaker.processor.Processor): The SageMaker
            processor to export Airflow config from.
        inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
            the processing job. These must be provided as
            :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
        outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
            the processing job. These must be provided as
            :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
        job_name (str): Processing job name. If not specified, the processor generates
            a default job name, based on the base job name and current timestamp.
        experiment_config (dict[str, str]): Experiment management configuration.
            Dictionary contains three optional keys:
            'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
        container_arguments ([str]): The arguments for a container used to run a processing job.
        container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
        kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
            uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key,
            ARN of a KMS key, alias of a KMS key, or ARN of a KMS key alias.
            The KmsKeyId is applied to all outputs.

    Returns:
        dict: Processing config that can be directly used by
            SageMakerProcessingOperator in Airflow.
    """
    # Mirror the convention of the other *_config helpers in this module:
    # stash the resolved job name on the processor so later calls can see it.
    if job_name is not None:
        processor._current_job_name = job_name
    else:
        base_name = processor.base_job_name
        processor._current_job_name = (
            utils.name_from_base(base_name)
            if base_name is not None
            else utils.base_name_from_image(processor.image_uri)
        )

    # Guard against None: input_output_list_converter iterates its argument,
    # so the default inputs=None / outputs=None would raise a TypeError.
    config = {
        "ProcessingJobName": processor._current_job_name,
        "ProcessingInputs": input_output_list_converter(inputs or []),
    }

    processing_output_config = sagemaker.processing.ProcessingJob.prepare_output_config(
        kms_key_id, input_output_list_converter(outputs or [])
    )

    config["ProcessingOutputConfig"] = processing_output_config

    if experiment_config is not None:
        config["ExperimentConfig"] = experiment_config

    app_specification = sagemaker.processing.ProcessingJob.prepare_app_specification(
        container_arguments, container_entrypoint, processor.image_uri
    )
    config["AppSpecification"] = app_specification

    config["RoleArn"] = processor.role

    if processor.env is not None:
        config["Environment"] = processor.env

    if processor.network_config is not None:
        config["NetworkConfig"] = processor.network_config

    processing_resources = sagemaker.processing.ProcessingJob.prepare_processing_resources(
        instance_count=processor.instance_count,
        instance_type=processor.instance_type,
        volume_kms_key_id=processor.volume_kms_key,
        volume_size_in_gb=processor.volume_size_in_gb,
    )
    config["ProcessingResources"] = processing_resources

    stopping_condition = sagemaker.processing.ProcessingJob.prepare_stopping_condition(
        processor.max_runtime_in_seconds
    )
    config["StoppingCondition"] = stopping_condition

    if processor.tags is not None:
        config["Tags"] = processor.tags

    return config


def input_output_list_converter(object_list):
    """Converts a list of ProcessingInput or ProcessingOutput objects to a list of dicts.

    Args:
        object_list (list[ProcessingInput or ProcessingOutput]): Objects to
            convert into request dicts. May be None or empty.

    Returns:
        list[dict]: The request dict for each object, in order; an empty
            list when ``object_list`` is None or empty.
    """
    # Previously a None argument (the callers' default for inputs/outputs)
    # raised TypeError when iterated; treat falsy input as "no objects".
    if not object_list:
        return []
    return [obj._to_request_dict() for obj in object_list]
111 changes: 110 additions & 1 deletion tests/unit/test_airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import pytest
from mock import Mock, MagicMock, patch

from sagemaker import chainer, estimator, model, mxnet, tensorflow, transformer, tuner
from sagemaker import chainer, estimator, model, mxnet, tensorflow, transformer, tuner, processing
from sagemaker.processing import ProcessingInput, ProcessingOutput
from sagemaker.workflow import airflow
from sagemaker.amazon import amazon_estimator
from sagemaker.amazon import knn, linear_learner, ntm, pca
Expand Down Expand Up @@ -1592,3 +1593,111 @@ def test_deploy_config_from_amazon_alg_estimator(sagemaker_session):
}

assert config == expected_config


@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP))
def test_processing_config(sagemaker_session):
    # Verifies that airflow.processing_config exports a complete
    # SageMakerProcessingOperator-compatible request dict from a Processor,
    # covering every optional field (tags, env, KMS keys, experiment config,
    # container overrides). Jinja-style "{{ ... }}" placeholders stand in for
    # Airflow-templated values and are expected to pass through unchanged.
    processor = processing.Processor(
        role="arn:aws:iam::0122345678910:role/SageMakerPowerUser",
        image_uri="{{ image_uri }}",
        instance_count=2,
        instance_type="ml.p2.xlarge",
        entrypoint="{{ entrypoint }}",
        volume_size_in_gb=30,
        volume_kms_key="{{ kms_key }}",
        output_kms_key="{{ kms_key }}",
        max_runtime_in_seconds=3600,
        base_job_name="processing_base_job_name",
        sagemaker_session=sagemaker_session,
        tags=[{"{{ key }}": "{{ value }}"}],
        env={"{{ key }}": "{{ value }}"},
    )

    # One output and one input exercise the ProcessingInput/ProcessingOutput
    # to-request-dict conversion path.
    outputs = [
        ProcessingOutput(
            output_name="AnalyticsOutputName",
            source="{{ Local Path }}",
            destination="{{ S3Uri }}",
            s3_upload_mode="EndOfJob",
        )
    ]
    inputs = [
        ProcessingInput(
            input_name="AnalyticsInputName",
            source="{{ S3Uri }}",
            destination="{{ Local Path }}",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
            s3_data_distribution_type="FullyReplicated",
            s3_compression_type="None",
        )
    ]

    # All three optional experiment-config keys supplied; expected to be
    # forwarded verbatim under "ExperimentConfig".
    experiment_config = {}
    experiment_config["ExperimentName"] = "ExperimentName"
    experiment_config["TrialName"] = "TrialName"
    experiment_config["TrialComponentDisplayName"] = "TrialComponentDisplayName"

    config = airflow.processing_config(
        processor,
        inputs=inputs,
        outputs=outputs,
        job_name="ProcessingJobName",
        container_arguments=["container_arg"],
        container_entrypoint=["container_entrypoint"],
        kms_key_id="KmsKeyID",
        experiment_config=experiment_config,
    )
    # Explicit job_name means no timestamp-based name generation, so the
    # config must use "ProcessingJobName" exactly as passed.
    expected_config = {
        "AppSpecification": {
            "ContainerArguments": ["container_arg"],
            "ContainerEntrypoint": ["container_entrypoint"],
            "ImageUri": "{{ image_uri }}",
        },
        "Environment": {"{{ key }}": "{{ value }}"},
        "ExperimentConfig": {
            "ExperimentName": "ExperimentName",
            "TrialComponentDisplayName": "TrialComponentDisplayName",
            "TrialName": "TrialName",
        },
        "ProcessingInputs": [
            {
                "InputName": "AnalyticsInputName",
                "S3Input": {
                    "LocalPath": "{{ Local Path }}",
                    "S3CompressionType": "None",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                    "S3Uri": "{{ S3Uri }}",
                },
            }
        ],
        "ProcessingJobName": "ProcessingJobName",
        "ProcessingOutputConfig": {
            "KmsKeyId": "KmsKeyID",
            "Outputs": [
                {
                    "OutputName": "AnalyticsOutputName",
                    "S3Output": {
                        "LocalPath": "{{ Local Path }}",
                        "S3UploadMode": "EndOfJob",
                        "S3Uri": "{{ S3Uri }}",
                    },
                }
            ],
        },
        "ProcessingResources": {
            "ClusterConfig": {
                "InstanceCount": 2,
                "InstanceType": "ml.p2.xlarge",
                "VolumeSizeInGB": 30,
                "VolumeKmsKeyId": "{{ kms_key }}",
            }
        },
        "RoleArn": "arn:aws:iam::0122345678910:role/SageMakerPowerUser",
        "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
        "Tags": [{"{{ key }}": "{{ value }}"}],
    }
    assert config == expected_config