Skip to content

Commit b6abe7c

Browse files
committed
feature: Apache Airflow integration for SageMaker Processing Jobs
1 parent 18af12b commit b6abe7c

File tree

3 files changed

+266
-1
lines changed

3 files changed

+266
-1
lines changed

src/sagemaker/processing.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,66 @@ def stop(self):
739739
"""Stops the processing job."""
740740
self.sagemaker_session.stop_processing_job(self.name)
741741

742+
@staticmethod
743+
def _prepare_app_specification(container_arguments, container_entrypoint, image_uri):
744+
"""
745+
Args:
746+
container_arguments:
747+
container_entrypoint:
748+
image_uri:
749+
"""
750+
config = {}
751+
if container_arguments is not None:
752+
config["ContainerArguments"] = container_arguments
753+
if container_entrypoint is not None:
754+
config["ContainerEntrypoint"] = container_entrypoint
755+
config["ImageUri"] = image_uri
756+
return config
757+
758+
@staticmethod
759+
def _prepare_output_config(kms_key_id, outputs):
760+
"""
761+
Args:
762+
kms_key_id:
763+
outputs:
764+
"""
765+
config = {}
766+
if kms_key_id is not None:
767+
config["KmsKeyId"] = kms_key_id
768+
config["Outputs"] = outputs
769+
return config
770+
771+
@staticmethod
772+
def _prepare_processing_resources(
773+
instance_count, instance_type, volume_kms_key_id, volume_size_in_gb
774+
):
775+
"""
776+
Args:
777+
instance_count:
778+
instance_type:
779+
volume_kms_key_id:
780+
volume_size_in_gb:
781+
"""
782+
processing_resources = {}
783+
cluster_config = {}
784+
if volume_kms_key_id is not None:
785+
cluster_config["VolumeKmsKeyId"] = volume_kms_key_id
786+
cluster_config["InstanceCount"] = instance_count
787+
cluster_config["InstanceType"] = instance_type
788+
cluster_config["VolumeSizeInGB"] = volume_size_in_gb
789+
processing_resources["ClusterConfig"] = cluster_config
790+
return processing_resources
791+
792+
@staticmethod
793+
def _prepare_stopping_condition(max_runtime_in_seconds):
794+
"""
795+
Args:
796+
max_runtime_in_seconds
797+
"""
798+
stopping_condition = {}
799+
stopping_condition["MaxRuntimeInSeconds"] = max_runtime_in_seconds
800+
return stopping_condition
801+
742802

743803
class ProcessingInput(object):
744804
"""Accepts parameters that specify an Amazon S3 input for a processing job and

src/sagemaker/workflow/airflow.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,3 +1065,99 @@ def deploy_config_from_estimator(
10651065
model.name = model_name
10661066
config = deploy_config(model, initial_instance_count, instance_type, endpoint_name, tags)
10671067
return config
1068+
1069+
1070+
def processing_config(
1071+
processor,
1072+
inputs=None,
1073+
outputs=None,
1074+
job_name=None,
1075+
experiment_config=None,
1076+
container_arguments=None,
1077+
container_entrypoint=None,
1078+
kms_key_id=None,
1079+
):
1080+
"""Export Airflow processing config from a SageMaker processor
1081+
1082+
Args:
1083+
processor (sagemaker.processor.Processor): The SageMaker
1084+
processor to export Airflow config from.
1085+
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
1086+
the processing job. These must be provided as
1087+
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
1088+
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
1089+
the processing job. These can be specified as either path strings or
1090+
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
1091+
job_name (str): Processing job name. If not specified, the processor generates
1092+
a default job name, based on the base job name and current timestamp.
1093+
experiment_config (dict[str, str]): Experiment management configuration.
1094+
Dictionary contains three optional keys:
1095+
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
1096+
container_arguments ([str]): The arguments for a container used to run a processing job.
1097+
container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
1098+
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
1099+
uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key,
1100+
ARN of a KMS key, alias of a KMS key, or alias of a KMS key.
1101+
The KmsKeyId is applied to all outputs.
1102+
1103+
Returns:
1104+
dict: Processing config that can be directly used by
1105+
SageMakerProcessingOperator in Airflow.
1106+
"""
1107+
if job_name is not None:
1108+
processor._current_job_name = job_name
1109+
else:
1110+
base_name = processor.base_job_name
1111+
processor._current_job_name = (
1112+
utils.name_from_base(base_name)
1113+
if base_name is not None
1114+
else utils.base_name_from_image(processor.image_uri)
1115+
)
1116+
1117+
input_dict = []
1118+
for i in inputs:
1119+
input_dict.append(i._to_request_dict())
1120+
config = {"ProcessingJobName": processor._current_job_name, "ProcessingInputs": input_dict}
1121+
1122+
output_dict = []
1123+
for output in outputs:
1124+
output_dict.append(output._to_request_dict())
1125+
processing_output_config = sagemaker.processing.ProcessingJob._prepare_output_config(
1126+
kms_key_id, output_dict
1127+
)
1128+
1129+
config["ProcessingOutputConfig"] = processing_output_config
1130+
1131+
if experiment_config is not None:
1132+
config["ExperimentConfig"] = experiment_config
1133+
1134+
app_specification = sagemaker.processing.ProcessingJob._prepare_app_specification(
1135+
container_arguments, container_entrypoint, processor.image_uri
1136+
)
1137+
config["AppSpecification"] = app_specification
1138+
1139+
config["RoleArn"] = processor.role
1140+
1141+
if processor.env is not None:
1142+
config["Environment"] = processor.env
1143+
1144+
if processor.network_config is not None:
1145+
config["NetworkConfig"] = processor.network_config
1146+
1147+
processing_resources = sagemaker.processing.ProcessingJob._prepare_processing_resources(
1148+
instance_count=processor.instance_count,
1149+
instance_type=processor.instance_type,
1150+
volume_kms_key_id=processor.volume_kms_key,
1151+
volume_size_in_gb=processor.volume_size_in_gb,
1152+
)
1153+
config["ProcessingResources"] = processing_resources
1154+
1155+
stopping_condition = sagemaker.processing.ProcessingJob._prepare_stopping_condition(
1156+
processor.max_runtime_in_seconds
1157+
)
1158+
config["StoppingCondition"] = stopping_condition
1159+
1160+
if processor.tags is not None:
1161+
config["Tags"] = processor.tags
1162+
1163+
return config

tests/unit/test_airflow.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import pytest
1616
from mock import Mock, MagicMock, patch
1717

18-
from sagemaker import chainer, estimator, model, mxnet, tensorflow, transformer, tuner
18+
from sagemaker import chainer, estimator, model, mxnet, tensorflow, transformer, tuner, processing
19+
from sagemaker.processing import ProcessingInput, ProcessingOutput
1920
from sagemaker.workflow import airflow
2021
from sagemaker.amazon import amazon_estimator
2122
from sagemaker.amazon import knn, linear_learner, ntm, pca
@@ -1590,3 +1591,111 @@ def test_deploy_config_from_amazon_alg_estimator(sagemaker_session):
15901591
}
15911592

15921593
assert config == expected_config
1594+
1595+
1596+
@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP))
1597+
def test_processing_config(sagemaker_session):
1598+
1599+
processor = processing.Processor(
1600+
role="arn:aws:iam::0122345678910:role/SageMakerPowerUser",
1601+
image_uri="{{ image_uri }}",
1602+
instance_count=2,
1603+
instance_type="ml.p2.xlarge",
1604+
entrypoint="{{ entrypoint }}",
1605+
volume_size_in_gb=30,
1606+
volume_kms_key="{{ kms_key }}",
1607+
output_kms_key="{{ kms_key }}",
1608+
max_runtime_in_seconds=3600,
1609+
base_job_name="processing_base_job_name",
1610+
sagemaker_session=sagemaker_session,
1611+
tags=[{"{{ key }}": "{{ value }}"}],
1612+
env={"{{ key }}": "{{ value }}"},
1613+
)
1614+
1615+
outputs = [
1616+
ProcessingOutput(
1617+
output_name="AnalyticsOutputName",
1618+
source="{{ Local Path }}",
1619+
destination="{{ S3Uri }}",
1620+
s3_upload_mode="EndOfJob",
1621+
)
1622+
]
1623+
inputs = [
1624+
ProcessingInput(
1625+
input_name="AnalyticsInputName",
1626+
source="{{ S3Uri }}",
1627+
destination="{{ Local Path }}",
1628+
s3_data_type="S3Prefix",
1629+
s3_input_mode="File",
1630+
s3_data_distribution_type="FullyReplicated",
1631+
s3_compression_type="None",
1632+
)
1633+
]
1634+
1635+
experiment_config = {}
1636+
experiment_config["ExperimentName"] = "ExperimentName"
1637+
experiment_config["TrialName"] = "TrialName"
1638+
experiment_config["TrialComponentDisplayName"] = "TrialComponentDisplayName"
1639+
1640+
config = airflow.processing_config(
1641+
processor,
1642+
inputs=inputs,
1643+
outputs=outputs,
1644+
job_name="ProcessingJobName",
1645+
container_arguments=["container_arg"],
1646+
container_entrypoint=["container_entrypoint"],
1647+
kms_key_id="KmsKeyID",
1648+
experiment_config=experiment_config,
1649+
)
1650+
expected_config = {
1651+
"AppSpecification": {
1652+
"ContainerArguments": ["container_arg"],
1653+
"ContainerEntrypoint": ["container_entrypoint"],
1654+
"ImageUri": "{{ image_uri }}",
1655+
},
1656+
"Environment": {"{{ key }}": "{{ value }}"},
1657+
"ExperimentConfig": {
1658+
"ExperimentName": "ExperimentName",
1659+
"TrialComponentDisplayName": "TrialComponentDisplayName",
1660+
"TrialName": "TrialName",
1661+
},
1662+
"ProcessingInputs": [
1663+
{
1664+
"InputName": "AnalyticsInputName",
1665+
"S3Input": {
1666+
"LocalPath": "{{ Local Path }}",
1667+
"S3CompressionType": "None",
1668+
"S3DataDistributionType": "FullyReplicated",
1669+
"S3DataType": "S3Prefix",
1670+
"S3InputMode": "File",
1671+
"S3Uri": "{{ S3Uri }}",
1672+
},
1673+
}
1674+
],
1675+
"ProcessingJobName": "ProcessingJobName",
1676+
"ProcessingOutputConfig": {
1677+
"KmsKeyId": "KmsKeyID",
1678+
"Outputs": [
1679+
{
1680+
"OutputName": "AnalyticsOutputName",
1681+
"S3Output": {
1682+
"LocalPath": "{{ Local Path }}",
1683+
"S3UploadMode": "EndOfJob",
1684+
"S3Uri": "{{ S3Uri }}",
1685+
},
1686+
}
1687+
],
1688+
},
1689+
"ProcessingResources": {
1690+
"ClusterConfig": {
1691+
"InstanceCount": 2,
1692+
"InstanceType": "ml.p2.xlarge",
1693+
"VolumeSizeInGB": 30,
1694+
"VolumeKmsKeyId": "{{ kms_key }}",
1695+
}
1696+
},
1697+
"RoleArn": "arn:aws:iam::0122345678910:role/SageMakerPowerUser",
1698+
"StoppingCondition": {"MaxRuntimeInSeconds": 3600},
1699+
"Tags": [{"{{ key }}": "{{ value }}"}],
1700+
}
1701+
assert config == expected_config

0 commit comments

Comments
 (0)