Skip to content

Commit 52d0189

Browse files
authored
feature: Apache Airflow integration for SageMaker Processing Jobs (#1620)
1 parent 1ab0641 commit 52d0189

File tree

3 files changed

+302
-1
lines changed

3 files changed

+302
-1
lines changed

src/sagemaker/processing.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,93 @@ def stop(self):
740740
"""Stops the processing job."""
741741
self.sagemaker_session.stop_processing_job(self.name)
742742

743+
@staticmethod
744+
def prepare_app_specification(container_arguments, container_entrypoint, image_uri):
745+
"""Prepares a dict that represents a ProcessingJob's AppSpecification.
746+
747+
Args:
748+
container_arguments (list[str]): The arguments for a container
749+
used to run a processing job.
750+
container_entrypoint (list[str]): The entrypoint for a container
751+
used to run a processing job.
752+
image_uri (str): The container image to be run by the processing job.
753+
754+
Returns:
755+
dict: Represents AppSpecification which configures the
756+
processing job to run a specified Docker container image.
757+
"""
758+
config = {"ImageUri": image_uri}
759+
if container_arguments is not None:
760+
config["ContainerArguments"] = container_arguments
761+
if container_entrypoint is not None:
762+
config["ContainerEntrypoint"] = container_entrypoint
763+
return config
764+
765+
@staticmethod
766+
def prepare_output_config(kms_key_id, outputs):
767+
"""Prepares a dict that represents a ProcessingOutputConfig.
768+
769+
Args:
770+
kms_key_id (str): The AWS Key Management Service (AWS KMS) key that
771+
Amazon SageMaker uses to encrypt the processing job output.
772+
KmsKeyId can be an ID of a KMS key, ARN of a KMS key, alias of a KMS key,
773+
or alias of a KMS key. The KmsKeyId is applied to all outputs.
774+
outputs (list[dict]): Output configuration information for a processing job.
775+
776+
Returns:
777+
dict: Represents output configuration for the processing job.
778+
"""
779+
config = {"Outputs": outputs}
780+
if kms_key_id is not None:
781+
config["KmsKeyId"] = kms_key_id
782+
return config
783+
784+
@staticmethod
785+
def prepare_processing_resources(
786+
instance_count, instance_type, volume_kms_key_id, volume_size_in_gb
787+
):
788+
"""Prepares a dict that represents the ProcessingResources.
789+
790+
Args:
791+
instance_count (int): The number of ML compute instances
792+
to use in the processing job. For distributed processing jobs,
793+
specify a value greater than 1. The default value is 1.
794+
instance_type (str): The ML compute instance type for the processing job.
795+
volume_kms_key_id (str): The AWS Key Management Service (AWS KMS) key
796+
that Amazon SageMaker uses to encrypt data on the storage
797+
volume attached to the ML compute instance(s) that run the processing job.
798+
volume_size_in_gb (int): The size of the ML storage volume in gigabytes
799+
that you want to provision. You must specify sufficient
800+
ML storage for your scenario.
801+
802+
Returns:
803+
dict: Represents ProcessingResources which identifies the resources,
804+
ML compute instances, and ML storage volumes to deploy
805+
for a processing job.
806+
"""
807+
processing_resources = {}
808+
cluster_config = {
809+
"InstanceCount": instance_count,
810+
"InstanceType": instance_type,
811+
"VolumeSizeInGB": volume_size_in_gb,
812+
}
813+
if volume_kms_key_id is not None:
814+
cluster_config["VolumeKmsKeyId"] = volume_kms_key_id
815+
processing_resources["ClusterConfig"] = cluster_config
816+
return processing_resources
817+
818+
@staticmethod
819+
def prepare_stopping_condition(max_runtime_in_seconds):
820+
"""Prepares a dict that represents the job's StoppingCondition.
821+
822+
Args:
823+
max_runtime_in_seconds (int): Specifies the maximum runtime in seconds.
824+
825+
Returns:
826+
dict
827+
"""
828+
return {"MaxRuntimeInSeconds": max_runtime_in_seconds}
829+
743830

744831
class ProcessingInput(object):
745832
"""Accepts parameters that specify an Amazon S3 input for a processing job and

src/sagemaker/workflow/airflow.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,3 +1070,108 @@ def deploy_config_from_estimator(
10701070
model.name = model_name
10711071
config = deploy_config(model, initial_instance_count, instance_type, endpoint_name, tags)
10721072
return config
1073+
1074+
1075+
def processing_config(
    processor,
    inputs=None,
    outputs=None,
    job_name=None,
    experiment_config=None,
    container_arguments=None,
    container_entrypoint=None,
    kms_key_id=None,
):
    """Export Airflow processing config from a SageMaker processor.

    Args:
        processor (sagemaker.processor.Processor): The SageMaker
            processor to export Airflow config from.
        inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
            the processing job. These must be provided as
            :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
        outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
            the processing job. These must be provided as
            :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
        job_name (str): Processing job name. If not specified, the processor generates
            a default job name, based on the base job name and current timestamp.
        experiment_config (dict[str, str]): Experiment management configuration.
            Dictionary contains three optional keys:
            'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
        container_arguments ([str]): The arguments for a container used to run a processing job.
        container_entrypoint ([str]): The entrypoint for a container used to run a processing job.
        kms_key_id (str): The AWS Key Management Service (AWS KMS) key that Amazon SageMaker
            uses to encrypt the processing job output. KmsKeyId can be an ID of a KMS key,
            ARN of a KMS key, alias of a KMS key, or alias of a KMS key.
            The KmsKeyId is applied to all outputs.

    Returns:
        dict: Processing config that can be directly used by
            SageMakerProcessingOperator in Airflow.
    """
    if job_name is not None:
        processor._current_job_name = job_name
    else:
        base_name = processor.base_job_name
        processor._current_job_name = (
            utils.name_from_base(base_name)
            if base_name is not None
            else utils.base_name_from_image(processor.image_uri)
        )

    config = {
        "ProcessingJobName": processor._current_job_name,
        # Fix: inputs/outputs default to None (documented above), but the
        # converter iterates its argument — substitute an empty list so the
        # no-input/no-output case doesn't raise TypeError.
        "ProcessingInputs": input_output_list_converter(inputs or []),
    }

    processing_output_config = sagemaker.processing.ProcessingJob.prepare_output_config(
        kms_key_id, input_output_list_converter(outputs or [])
    )

    config["ProcessingOutputConfig"] = processing_output_config

    if experiment_config is not None:
        config["ExperimentConfig"] = experiment_config

    app_specification = sagemaker.processing.ProcessingJob.prepare_app_specification(
        container_arguments, container_entrypoint, processor.image_uri
    )
    config["AppSpecification"] = app_specification

    config["RoleArn"] = processor.role

    if processor.env is not None:
        config["Environment"] = processor.env

    if processor.network_config is not None:
        # NOTE(review): network_config is passed through as-is; confirm that
        # callers provide it in request-dict form, since no conversion is done here.
        config["NetworkConfig"] = processor.network_config

    processing_resources = sagemaker.processing.ProcessingJob.prepare_processing_resources(
        instance_count=processor.instance_count,
        instance_type=processor.instance_type,
        volume_kms_key_id=processor.volume_kms_key,
        volume_size_in_gb=processor.volume_size_in_gb,
    )
    config["ProcessingResources"] = processing_resources

    stopping_condition = sagemaker.processing.ProcessingJob.prepare_stopping_condition(
        processor.max_runtime_in_seconds
    )
    config["StoppingCondition"] = stopping_condition

    if processor.tags is not None:
        config["Tags"] = processor.tags

    return config
1167+
1168+
def input_output_list_converter(object_list):
    """Convert ProcessingInput or ProcessingOutput objects to request dicts.

    Args:
        object_list (list[ProcessingInput or ProcessingOutput]): Objects to
            convert. May be None or empty, since callers' inputs/outputs
            parameters default to None.

    Returns:
        list[dict]: One request dict per object; an empty list when
            *object_list* is None or empty.
    """
    # Fix: callers pass None when no inputs/outputs were supplied; iterating
    # None raised TypeError. An empty list keeps the request structure valid.
    if not object_list:
        return []
    return [obj._to_request_dict() for obj in object_list]

tests/unit/test_airflow.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
import pytest
1616
from mock import Mock, MagicMock, patch
1717

18-
from sagemaker import chainer, estimator, model, mxnet, tensorflow, transformer, tuner
18+
from sagemaker import chainer, estimator, model, mxnet, tensorflow, transformer, tuner, processing
19+
from sagemaker.processing import ProcessingInput, ProcessingOutput
1920
from sagemaker.workflow import airflow
2021
from sagemaker.amazon import amazon_estimator
2122
from sagemaker.amazon import knn, linear_learner, ntm, pca
@@ -1592,3 +1593,111 @@ def test_deploy_config_from_amazon_alg_estimator(sagemaker_session):
15921593
}
15931594

15941595
assert config == expected_config
1596+
1597+
1598+
@patch("sagemaker.utils.sagemaker_timestamp", MagicMock(return_value=TIME_STAMP))
def test_processing_config(sagemaker_session):
    """processing_config exports the full Airflow request dict for a Processor."""
    job_processor = processing.Processor(
        role="arn:aws:iam::0122345678910:role/SageMakerPowerUser",
        image_uri="{{ image_uri }}",
        instance_count=2,
        instance_type="ml.p2.xlarge",
        entrypoint="{{ entrypoint }}",
        volume_size_in_gb=30,
        volume_kms_key="{{ kms_key }}",
        output_kms_key="{{ kms_key }}",
        max_runtime_in_seconds=3600,
        base_job_name="processing_base_job_name",
        sagemaker_session=sagemaker_session,
        tags=[{"{{ key }}": "{{ value }}"}],
        env={"{{ key }}": "{{ value }}"},
    )

    processing_inputs = [
        ProcessingInput(
            input_name="AnalyticsInputName",
            source="{{ S3Uri }}",
            destination="{{ Local Path }}",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
            s3_data_distribution_type="FullyReplicated",
            s3_compression_type="None",
        )
    ]
    processing_outputs = [
        ProcessingOutput(
            output_name="AnalyticsOutputName",
            source="{{ Local Path }}",
            destination="{{ S3Uri }}",
            s3_upload_mode="EndOfJob",
        )
    ]

    config = airflow.processing_config(
        job_processor,
        inputs=processing_inputs,
        outputs=processing_outputs,
        job_name="ProcessingJobName",
        experiment_config={
            "ExperimentName": "ExperimentName",
            "TrialName": "TrialName",
            "TrialComponentDisplayName": "TrialComponentDisplayName",
        },
        container_arguments=["container_arg"],
        container_entrypoint=["container_entrypoint"],
        kms_key_id="KmsKeyID",
    )

    expected_config = {
        "AppSpecification": {
            "ContainerArguments": ["container_arg"],
            "ContainerEntrypoint": ["container_entrypoint"],
            "ImageUri": "{{ image_uri }}",
        },
        "Environment": {"{{ key }}": "{{ value }}"},
        "ExperimentConfig": {
            "ExperimentName": "ExperimentName",
            "TrialComponentDisplayName": "TrialComponentDisplayName",
            "TrialName": "TrialName",
        },
        "ProcessingInputs": [
            {
                "InputName": "AnalyticsInputName",
                "S3Input": {
                    "LocalPath": "{{ Local Path }}",
                    "S3CompressionType": "None",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                    "S3Uri": "{{ S3Uri }}",
                },
            }
        ],
        "ProcessingJobName": "ProcessingJobName",
        "ProcessingOutputConfig": {
            "KmsKeyId": "KmsKeyID",
            "Outputs": [
                {
                    "OutputName": "AnalyticsOutputName",
                    "S3Output": {
                        "LocalPath": "{{ Local Path }}",
                        "S3UploadMode": "EndOfJob",
                        "S3Uri": "{{ S3Uri }}",
                    },
                }
            ],
        },
        "ProcessingResources": {
            "ClusterConfig": {
                "InstanceCount": 2,
                "InstanceType": "ml.p2.xlarge",
                "VolumeSizeInGB": 30,
                "VolumeKmsKeyId": "{{ kms_key }}",
            }
        },
        "RoleArn": "arn:aws:iam::0122345678910:role/SageMakerPowerUser",
        "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
        "Tags": [{"{{ key }}": "{{ value }}"}],
    }

    assert config == expected_config

0 commit comments

Comments
 (0)