Skip to content

Commit 53d8c5e

Browse files
committed
feature: Apache Airflow integration for SageMaker Processing Jobs
1 parent 18af12b commit 53d8c5e

File tree

3 files changed

+302
-1
lines changed

3 files changed

+302
-1
lines changed

src/sagemaker/processing.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,94 @@ def stop(self):
739739
"""Stops the processing job."""
740740
self.sagemaker_session.stop_processing_job(self.name)
741741

742+
@staticmethod
def prepare_app_specification(container_arguments, container_entrypoint, image_uri):
    """Builds the dict that represents a ProcessingJob's AppSpecification.

    The image URI is always included. The container arguments and the
    container entrypoint are included only when they are not None.

    Args:
        container_arguments (list[str]): The arguments for a container
            used to run a processing job.
        container_entrypoint (list[str]): The entrypoint for a container
            used to run a processing job.
        image_uri (str): The container image to be run by the processing job.

    Returns:
        dict: The app specification, with the key ``ImageUri`` and, when
            supplied, ``ContainerArguments`` and ``ContainerEntrypoint``.
    """
    app_specification = {"ImageUri": image_uri}
    # Optional fields, listed in the order they are inserted into the dict.
    optional_fields = (
        ("ContainerArguments", container_arguments),
        ("ContainerEntrypoint", container_entrypoint),
    )
    for field_name, field_value in optional_fields:
        if field_value is not None:
            app_specification[field_name] = field_value
    return app_specification
763+
764+
@staticmethod
def prepare_output_config(kms_key_id, outputs):
    """Builds the dict that represents a ProcessingOutputConfig.

    Args:
        kms_key_id (str): The AWS Key Management Service (AWS KMS) key that
            Amazon SageMaker uses to encrypt the processing job output.
            KmsKeyId can be an ID of a KMS key, ARN of a KMS key,
            or alias of a KMS key. The KmsKeyId is applied to all outputs.
            Left out of the result when None.
        outputs (list[dict]): Output configuration information for a processing job.

    Returns:
        dict: The output configuration, with the key ``Outputs`` and, when a
            key id is supplied, ``KmsKeyId``.
    """
    if kms_key_id is None:
        return {"Outputs": outputs}
    return {"Outputs": outputs, "KmsKeyId": kms_key_id}
783+
784+
@staticmethod
def prepare_processing_resources(
    instance_count, instance_type, volume_kms_key_id, volume_size_in_gb
):
    """Builds the dict that represents the ProcessingResources of a job.

    The result identifies the resources, ML compute instances, and ML storage
    volumes to deploy for a processing job.

    Args:
        instance_count (int): The number of ML compute instances
            to use in the processing job. For distributed processing jobs,
            specify a value greater than 1.
        instance_type (str): The ML compute instance type for the processing job.
        volume_kms_key_id (str): The AWS Key Management Service (AWS KMS) key
            that Amazon SageMaker uses to encrypt data on the storage
            volume attached to the ML compute instance(s) that run the
            processing job. Left out of the result when None.
        volume_size_in_gb (int): The size of the ML storage volume in gigabytes
            that you want to provision. You must specify sufficient
            ML storage for your scenario.

    Returns:
        dict: A dict with the single key ``ClusterConfig``, whose value holds
            ``InstanceCount``, ``InstanceType``, ``VolumeSizeInGB`` and, when
            supplied, ``VolumeKmsKeyId``.
    """
    # The optional KMS entry goes in first so the key order stays the same.
    encryption = {} if volume_kms_key_id is None else {"VolumeKmsKeyId": volume_kms_key_id}
    return {
        "ClusterConfig": dict(
            encryption,
            InstanceCount=instance_count,
            InstanceType=instance_type,
            VolumeSizeInGB=volume_size_in_gb,
        )
    }
816+
817+
@staticmethod
def prepare_stopping_condition(max_runtime_in_seconds):
    """Builds the dict that represents the job's StoppingCondition.

    Args:
        max_runtime_in_seconds (int): Specifies the maximum runtime in seconds.

    Returns:
        dict: A dict with the single key ``MaxRuntimeInSeconds``.
    """
    stopping_condition = dict(MaxRuntimeInSeconds=max_runtime_in_seconds)
    return stopping_condition
828+
829+
742830

743831
class ProcessingInput(object):
744832
"""Accepts parameters that specify an Amazon S3 input for a processing job and

src/sagemaker/workflow/airflow.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,3 +1065,107 @@ def deploy_config_from_estimator(
10651065
model.name = model_name
10661066
config = deploy_config(model, initial_instance_count, instance_type, endpoint_name, tags)
10671067
return config
1068+
1069+
1070+
def processing_config(
    processor,
    inputs=None,
    outputs=None,
    job_name=None,
    experiment_config=None,
    container_arguments=None,
    container_entrypoint=None,
    kms_key_id=None,
):
    """Export Airflow processing config from a SageMaker processor

    As a side effect, the resolved job name is stored on the processor as
    ``processor._current_job_name``.

    Args:
        processor (sagemaker.processing.Processor): The SageMaker
            processor to export Airflow config from.
        inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
            the processing job. These must be provided as
            :class:`~sagemaker.processing.ProcessingInput` objects. None is treated
            as an empty list (default: None).
        outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
            the processing job. These must be provided as
            :class:`~sagemaker.processing.ProcessingOutput` objects, because each item
            is converted with its ``_to_request_dict()`` method. None is treated
            as an empty list (default: None).
        job_name (str): Processing job name. If not specified, a default job name is
            generated from the processor's base job name (or, when that is None,
            from a base name derived from the processor's image URI).
        experiment_config (dict[str, str]): Experiment management configuration.
            Dictionary contains three optional keys:
            'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
        container_arguments ([str]): The arguments for a container used to run a processing job.
        container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
        kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
            uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key,
            ARN of a KMS key, or alias of a KMS key.
            The KmsKeyId is applied to all outputs.

    Returns:
        dict: Processing config that can be directly used by
            SageMakerProcessingOperator in Airflow.
    """
    if job_name is None:
        # Pick the base name first, then always run it through name_from_base, so the
        # image-derived fallback gets the same generated suffix as an explicit base
        # job name instead of being reused verbatim on every export.
        base_name = processor.base_job_name
        if base_name is None:
            base_name = utils.base_name_from_image(processor.image_uri)
        job_name = utils.name_from_base(base_name)
    processor._current_job_name = job_name

    # ``inputs`` and ``outputs`` default to None. Substitute an empty list so the
    # converter, which iterates its argument, does not fail on the default call.
    config = {
        "ProcessingJobName": processor._current_job_name,
        "ProcessingInputs": input_output_list_converter(inputs or []),
    }

    config["ProcessingOutputConfig"] = sagemaker.processing.ProcessingJob.prepare_output_config(
        kms_key_id, input_output_list_converter(outputs or [])
    )

    if experiment_config is not None:
        config["ExperimentConfig"] = experiment_config

    config["AppSpecification"] = sagemaker.processing.ProcessingJob.prepare_app_specification(
        container_arguments, container_entrypoint, processor.image_uri
    )

    config["RoleArn"] = processor.role

    if processor.env is not None:
        config["Environment"] = processor.env

    if processor.network_config is not None:
        # NOTE(review): the value is passed through unchanged; confirm the Airflow
        # operator accepts whatever type ``processor.network_config`` holds.
        config["NetworkConfig"] = processor.network_config

    config["ProcessingResources"] = sagemaker.processing.ProcessingJob.prepare_processing_resources(
        instance_count=processor.instance_count,
        instance_type=processor.instance_type,
        volume_kms_key_id=processor.volume_kms_key,
        volume_size_in_gb=processor.volume_size_in_gb,
    )

    config["StoppingCondition"] = sagemaker.processing.ProcessingJob.prepare_stopping_condition(
        processor.max_runtime_in_seconds
    )

    if processor.tags is not None:
        config["Tags"] = processor.tags

    return config
1160+
1161+
1162+
def input_output_list_converter(object_list):
    """Converts a list of ProcessingInput or ProcessingOutput objects to a list of dicts

    Args:
        object_list (list[ProcessingInput or ProcessingOutput]): The objects to
            convert. Each one must provide a ``_to_request_dict()`` method.
            May be None, which is treated as an empty list.

    Returns:
        list[dict]: The result of calling ``_to_request_dict()`` on each object,
            in the original order. An empty list when ``object_list`` is None
            or empty.
    """
    # Callers default their input/output lists to None; iterating None would raise.
    if object_list is None:
        return []
    return [obj._to_request_dict() for obj in object_list]

tests/unit/test_airflow.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import pytest
1616
from mock import Mock, MagicMock, patch
1717

18-
from sagemaker import chainer, estimator, model, mxnet, tensorflow, transformer, tuner
18+
from sagemaker import chainer, estimator, model, mxnet, tensorflow, transformer, tuner, processing
19+
from sagemaker.processing import ProcessingInput, ProcessingOutput
1920
from sagemaker.workflow import airflow
2021
from sagemaker.amazon import amazon_estimator
2122
from sagemaker.amazon import knn, linear_learner, ntm, pca
@@ -1590,3 +1591,111 @@ def test_deploy_config_from_amazon_alg_estimator(sagemaker_session):
15901591
}
15911592

15921593
assert config == expected_config
1594+
1595+
1596+
@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP))
def test_processing_config(sagemaker_session):
    """Checks the Airflow config exported for a fully specified Processor."""
    role_arn = "arn:aws:iam::0122345678910:role/SageMakerPowerUser"

    processing_input = ProcessingInput(
        input_name="AnalyticsInputName",
        source="{{ S3Uri }}",
        destination="{{ Local Path }}",
        s3_data_type="S3Prefix",
        s3_input_mode="File",
        s3_data_distribution_type="FullyReplicated",
        s3_compression_type="None",
    )
    processing_output = ProcessingOutput(
        output_name="AnalyticsOutputName",
        source="{{ Local Path }}",
        destination="{{ S3Uri }}",
        s3_upload_mode="EndOfJob",
    )

    processor = processing.Processor(
        role=role_arn,
        image_uri="{{ image_uri }}",
        instance_count=2,
        instance_type="ml.p2.xlarge",
        entrypoint="{{ entrypoint }}",
        volume_size_in_gb=30,
        volume_kms_key="{{ kms_key }}",
        output_kms_key="{{ kms_key }}",
        max_runtime_in_seconds=3600,
        base_job_name="processing_base_job_name",
        sagemaker_session=sagemaker_session,
        tags=[{"{{ key }}": "{{ value }}"}],
        env={"{{ key }}": "{{ value }}"},
    )

    experiment_config = {
        "ExperimentName": "ExperimentName",
        "TrialName": "TrialName",
        "TrialComponentDisplayName": "TrialComponentDisplayName",
    }

    actual_config = airflow.processing_config(
        processor,
        inputs=[processing_input],
        outputs=[processing_output],
        job_name="ProcessingJobName",
        container_arguments=["container_arg"],
        container_entrypoint=["container_entrypoint"],
        kms_key_id="KmsKeyID",
        experiment_config=experiment_config,
    )

    # Explicit arguments (job name, entrypoint, output KMS key) are the values
    # expected in the exported config.
    expected_config = {
        "ProcessingJobName": "ProcessingJobName",
        "ProcessingInputs": [
            {
                "InputName": "AnalyticsInputName",
                "S3Input": {
                    "S3Uri": "{{ S3Uri }}",
                    "LocalPath": "{{ Local Path }}",
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3CompressionType": "None",
                },
            }
        ],
        "ProcessingOutputConfig": {
            "Outputs": [
                {
                    "OutputName": "AnalyticsOutputName",
                    "S3Output": {
                        "S3Uri": "{{ S3Uri }}",
                        "LocalPath": "{{ Local Path }}",
                        "S3UploadMode": "EndOfJob",
                    },
                }
            ],
            "KmsKeyId": "KmsKeyID",
        },
        "ExperimentConfig": {
            "ExperimentName": "ExperimentName",
            "TrialName": "TrialName",
            "TrialComponentDisplayName": "TrialComponentDisplayName",
        },
        "AppSpecification": {
            "ImageUri": "{{ image_uri }}",
            "ContainerArguments": ["container_arg"],
            "ContainerEntrypoint": ["container_entrypoint"],
        },
        "RoleArn": "arn:aws:iam::0122345678910:role/SageMakerPowerUser",
        "Environment": {"{{ key }}": "{{ value }}"},
        "ProcessingResources": {
            "ClusterConfig": {
                "VolumeKmsKeyId": "{{ kms_key }}",
                "InstanceCount": 2,
                "InstanceType": "ml.p2.xlarge",
                "VolumeSizeInGB": 30,
            }
        },
        "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
        "Tags": [{"{{ key }}": "{{ value }}"}],
    }
    assert actual_config == expected_config

0 commit comments

Comments
 (0)