Skip to content

Commit a49cec2

Browse files
authored
fix: add Batch Transform data processing options to Airflow config (#1514)
1 parent 496f979 commit a49cec2

File tree

4 files changed

+92
-3
lines changed

4 files changed

+92
-3
lines changed

src/sagemaker/transformer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def transform(
135135
136136
* 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
137137
object to use as an input for the transform job.
138+
138139
content_type (str): MIME type of the input data (default: None).
139140
compression_type (str): Compression type of the input data, if
140141
compressed (default: None). Valid values: 'Gzip', None.

src/sagemaker/workflow/airflow.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,9 @@ def transform_config(
671671
compression_type=None,
672672
split_type=None,
673673
job_name=None,
674+
input_filter=None,
675+
output_filter=None,
676+
join_source=None,
674677
):
675678
"""Export Airflow transform config from a SageMaker transformer
676679
@@ -686,13 +689,38 @@ def transform_config(
686689
687690
* 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 object
688691
to use as an input for the transform job.
692+
689693
content_type (str): MIME type of the input data (default: None).
690694
compression_type (str): Compression type of the input data, if
691695
compressed (default: None). Valid values: 'Gzip', None.
692696
split_type (str): The record delimiter for the input object (default:
693697
'None'). Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
694698
job_name (str): job name (default: None). If not specified, one will be
695699
generated.
700+
input_filter (str): A JSONPath to select a portion of the input to
701+
pass to the algorithm container for inference. If you omit the
702+
field, it gets the value '$', representing the entire input.
703+
For CSV data, each row is taken as a JSON array,
704+
so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
705+
CSV data should follow the `RFC format <https://tools.ietf.org/html/rfc4180>`_.
706+
See `Supported JSONPath Operators
707+
<https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html#data-processing-operators>`_
708+
for a table of supported JSONPath operators.
709+
For more information, see the SageMaker API documentation for
710+
`CreateTransformJob
711+
<https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
712+
Some examples: "$[1:]", "$.features" (default: None).
713+
output_filter (str): A JSONPath to select a portion of the
714+
joined/original output to return as the output.
715+
For more information, see the SageMaker API documentation for
716+
`CreateTransformJob
717+
<https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
718+
Some examples: "$[1:]", "$.prediction" (default: None).
719+
join_source (str): The source of data to be joined to the transform
720+
output. It can be set to 'Input' meaning the entire input record
721+
will be joined to the inference result. You can use OutputFilter
722+
to select the useful portion before uploading to S3. (default:
723+
None). Valid values: Input, None.
696724
697725
Returns:
698726
dict: Transform config that can be directly used by
@@ -723,6 +751,12 @@ def transform_config(
723751
"TransformResources": job_config["resource_config"],
724752
}
725753

754+
data_processing = sagemaker.transformer._TransformJob._prepare_data_processing(
755+
input_filter, output_filter, join_source
756+
)
757+
if data_processing is not None:
758+
config["DataProcessing"] = data_processing
759+
726760
if transformer.strategy is not None:
727761
config["BatchStrategy"] = transformer.strategy
728762

@@ -768,6 +802,9 @@ def transform_config_from_estimator(
768802
model_server_workers=None,
769803
image=None,
770804
vpc_config_override=None,
805+
input_filter=None,
806+
output_filter=None,
807+
join_source=None,
771808
):
772809
"""Export Airflow transform config from a SageMaker estimator
773810
@@ -836,9 +873,35 @@ def transform_config_from_estimator(
836873
image (str): A container image to use for deploying the model
837874
vpc_config_override (dict[str, list[str]]): Override for VpcConfig set on
838875
the model. Default: use subnets and security groups from this Estimator.
876+
839877
* 'Subnets' (list[str]): List of subnet ids.
840878
* 'SecurityGroupIds' (list[str]): List of security group ids.
841879
880+
input_filter (str): A JSONPath to select a portion of the input to
881+
pass to the algorithm container for inference. If you omit the
882+
field, it gets the value '$', representing the entire input.
883+
For CSV data, each row is taken as a JSON array,
884+
so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
885+
CSV data should follow the `RFC format <https://tools.ietf.org/html/rfc4180>`_.
886+
See `Supported JSONPath Operators
887+
<https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform-data-processing.html#data-processing-operators>`_
888+
for a table of supported JSONPath operators.
889+
For more information, see the SageMaker API documentation for
890+
`CreateTransformJob
891+
<https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
892+
Some examples: "$[1:]", "$.features" (default: None).
893+
output_filter (str): A JSONPath to select a portion of the
894+
joined/original output to return as the output.
895+
For more information, see the SageMaker API documentation for
896+
`CreateTransformJob
897+
<https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html>`_.
898+
Some examples: "$[1:]", "$.prediction" (default: None).
899+
join_source (str): The source of data to be joined to the transform
900+
output. It can be set to 'Input' meaning the entire input record
901+
will be joined to the inference result. You can use OutputFilter
902+
to select the useful portion before uploading to S3. (default:
903+
None). Valid values: Input, None.
904+
842905
Returns:
843906
dict: Transform config that can be directly used by
844907
SageMakerTransformOperator in Airflow.
@@ -891,7 +954,16 @@ def transform_config_from_estimator(
891954
transformer.model_name = model_base_config["ModelName"]
892955

893956
transform_base_config = transform_config(
894-
transformer, data, data_type, content_type, compression_type, split_type, job_name
957+
transformer,
958+
data,
959+
data_type,
960+
content_type,
961+
compression_type,
962+
split_type,
963+
job_name,
964+
input_filter,
965+
output_filter,
966+
join_source,
895967
)
896968

897969
config = {"Model": model_base_config, "Transform": transform_base_config}

tests/integ/test_airflow_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,8 @@ def _build_airflow_workflow(estimator, instance_type, inputs=None, mini_batch_si
682682
instance_type=estimator.train_instance_type,
683683
data=inputs,
684684
content_type="text/csv",
685+
input_filter="$",
686+
output_filter="$",
685687
)
686688

687689
default_args = {

tests/unit/test_airflow.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
1413
from __future__ import absolute_import
1514

1615
import pytest
@@ -21,7 +20,6 @@
2120
from sagemaker.amazon import amazon_estimator
2221
from sagemaker.amazon import knn, linear_learner, ntm, pca
2322

24-
2523
REGION = "us-west-2"
2624
BUCKET_NAME = "output"
2725
TIME_STAMP = "1111"
@@ -1162,6 +1160,9 @@ def test_transform_config(sagemaker_session):
11621160
content_type="{{ content_type }}",
11631161
compression_type="{{ compression_type }}",
11641162
split_type="{{ split_type }}",
1163+
input_filter="{{ input_filter }}",
1164+
output_filter="{{ output_filter }}",
1165+
join_source="{{ join_source }}",
11651166
)
11661167
expected_config = {
11671168
"TransformJobName": "tensorflow-transform-%s" % TIME_STAMP,
@@ -1190,6 +1191,11 @@ def test_transform_config(sagemaker_session):
11901191
"MaxPayloadInMB": "{{ max_payload }}",
11911192
"Environment": {"{{ key }}": "{{ value }}"},
11921193
"Tags": [{"{{ key }}": "{{ value }}"}],
1194+
"DataProcessing": {
1195+
"InputFilter": "{{ input_filter }}",
1196+
"JoinSource": "{{ join_source }}",
1197+
"OutputFilter": "{{ output_filter }}",
1198+
},
11931199
}
11941200

11951201
assert config == expected_config
@@ -1238,6 +1244,9 @@ def test_transform_config_from_framework_estimator(ecr_prefix, sagemaker_session
12381244
instance_count="{{ instance_count }}",
12391245
instance_type="ml.p2.xlarge",
12401246
data=transform_data,
1247+
input_filter="{{ input_filter }}",
1248+
output_filter="{{ output_filter }}",
1249+
join_source="{{ join_source }}",
12411250
)
12421251
expected_config = {
12431252
"Model": {
@@ -1272,6 +1281,11 @@ def test_transform_config_from_framework_estimator(ecr_prefix, sagemaker_session
12721281
"InstanceType": "ml.p2.xlarge",
12731282
},
12741283
"Environment": {},
1284+
"DataProcessing": {
1285+
"InputFilter": "{{ input_filter }}",
1286+
"JoinSource": "{{ join_source }}",
1287+
"OutputFilter": "{{ output_filter }}",
1288+
},
12751289
},
12761290
}
12771291

0 commit comments

Comments
 (0)