-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: Add DataProcessing Fields for Batch Transform #827
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
516df28
2d61291
cdcde78
349d30b
59357d5
06e7d9e
5a7f08c
3e52c2e
0cb74c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -79,7 +79,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass | |
self.sagemaker_session = sagemaker_session or Session() | ||
|
||
def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, | ||
job_name=None): | ||
job_name=None, input_filter=None, output_filter=None, join_source=None): | ||
"""Start a new transform job. | ||
|
||
Args: | ||
|
@@ -97,6 +97,15 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t | |
split_type (str): The record delimiter for the input object (default: 'None'). | ||
Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'. | ||
job_name (str): job name (default: None). If not specified, one will be generated. | ||
input_filter (str): A JSONPath to select a portion of the input to pass to the algorithm container for | ||
inference. If you omit the field, it gets the value '$', representing the entire input. | ||
Some examples: "$[1:]", "$.features" (default: None). | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please add a link here for where to find additional information. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. I just linked to the top-level API for now, as the docs are not public yet. |
||
output_filter (str): A JSONPath to select a portion of the joined/original output to return as the output. | ||
Some examples: "$[1:]", "$.prediction" (default: None). | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
join_source (str): The source of data to be joined to the transform output. It can be set to 'Input' | ||
meaning the entire input record will be joined to the inference result. | ||
You can use OutputFilter to select the useful portion before uploading to S3. (default: None). | ||
Valid values: Input, None. | ||
""" | ||
local_mode = self.sagemaker_session.local_mode | ||
if not local_mode and not data.startswith('s3://'): | ||
|
@@ -116,7 +125,7 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t | |
self.output_path = 's3://{}/{}'.format(self.sagemaker_session.default_bucket(), self._current_job_name) | ||
|
||
self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type, | ||
split_type) | ||
split_type, input_filter, output_filter, join_source) | ||
|
||
def delete_model(self): | ||
"""Delete the corresponding SageMaker model for this Transformer. | ||
|
@@ -214,16 +223,19 @@ def _prepare_init_params_from_job_description(cls, job_details): | |
|
||
class _TransformJob(_Job): | ||
@classmethod | ||
def start_new(cls, transformer, data, data_type, content_type, compression_type, split_type): | ||
def start_new(cls, transformer, data, data_type, content_type, compression_type, | ||
split_type, input_filter, output_filter, join_source): | ||
config = _TransformJob._load_config(data, data_type, content_type, compression_type, split_type, transformer) | ||
data_processing = _TransformJob._prepare_data_processing(input_filter, output_filter, join_source) | ||
|
||
transformer.sagemaker_session.transform(job_name=transformer._current_job_name, | ||
model_name=transformer.model_name, strategy=transformer.strategy, | ||
max_concurrent_transforms=transformer.max_concurrent_transforms, | ||
max_payload=transformer.max_payload, env=transformer.env, | ||
input_config=config['input_config'], | ||
output_config=config['output_config'], | ||
resource_config=config['resource_config'], tags=transformer.tags) | ||
resource_config=config['resource_config'], | ||
tags=transformer.tags, data_processing=data_processing) | ||
|
||
return cls(transformer.sagemaker_session, transformer._current_job_name) | ||
|
||
|
@@ -287,3 +299,21 @@ def _prepare_resource_config(instance_count, instance_type, volume_kms_key): | |
config['VolumeKmsKeyId'] = volume_kms_key | ||
|
||
return config | ||
|
||
@staticmethod | ||
def _prepare_data_processing(input_filter, output_filter, join_source): | ||
config = {} | ||
|
||
if input_filter is not None: | ||
config['InputFilter'] = input_filter | ||
|
||
if output_filter is not None: | ||
config['OutputFilter'] = output_filter | ||
|
||
if join_source is not None: | ||
config['JoinSource'] = join_source | ||
|
||
if len(config) == 0: | ||
return None | ||
|
||
return config |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,7 +68,7 @@ def _create_kms_key(kms_client, | |
role_arn=role_arn, | ||
sagemaker_role=sagemaker_role) | ||
else: | ||
principal = "{account_id}".format(account_id=account_id) | ||
principal = '"{account_id}"'.format(account_id=account_id) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there any special reason to include this change in this PR? |
||
|
||
response = kms_client.create_key( | ||
Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal, sagemaker_role=sagemaker_role), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,15 +54,19 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version): | |
key_prefix=transform_input_key_prefix) | ||
|
||
kms_key_arn = get_or_create_kms_key(sagemaker_session) | ||
output_filter = "$" | ||
|
||
transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn) | ||
transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn, | ||
input_filter=None, output_filter=output_filter, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we add an `input_filter`? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added input_filter. I am setting join_source to None, as that lets me leave the output of the test unchanged. |
||
join_source=None) | ||
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session, | ||
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES): | ||
transformer.wait() | ||
|
||
job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job( | ||
TransformJobName=transformer.latest_transform_job.name) | ||
assert kms_key_arn == job_desc['TransformResources']['VolumeKmsKeyId'] | ||
assert output_filter == job_desc['DataProcessing']['OutputFilter'] | ||
|
||
|
||
@pytest.mark.canary_quick | ||
|
@@ -232,7 +236,9 @@ def test_transform_byo_estimator(sagemaker_session): | |
assert tags == model_tags | ||
|
||
|
||
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None): | ||
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None, | ||
input_filter=None, output_filter=None, join_source=None): | ||
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key) | ||
transformer.transform(transform_input, content_type='text/csv') | ||
transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter, | ||
output_filter=output_filter, join_source=join_source) | ||
return transformer |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -588,7 +588,7 @@ def test_transform_pack_to_request(sagemaker_session): | |
|
||
sagemaker_session.transform(job_name=JOB_NAME, model_name=model_name, strategy=None, max_concurrent_transforms=None, | ||
max_payload=None, env=None, input_config=in_config, output_config=out_config, | ||
resource_config=resource_config, tags=None) | ||
resource_config=resource_config, tags=None, data_processing=None) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You need to unit test the `data_processing` argument. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nice catch. Done. There are two tests, one with only the required parameters and the other with the optional parameters, to cover the different use cases. I used them accordingly in #864. |
||
|
||
_, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0] | ||
assert actual_args == expected_args | ||
|
@@ -603,7 +603,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session): | |
sagemaker_session.transform(job_name=JOB_NAME, model_name='my-model', strategy=strategy, | ||
max_concurrent_transforms=max_concurrent_transforms, | ||
env=env, max_payload=max_payload, input_config={}, output_config={}, | ||
resource_config={}, tags=TAGS) | ||
resource_config={}, tags=TAGS, data_processing=None) | ||
|
||
_, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0] | ||
assert actual_args['BatchStrategy'] == strategy | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The line
For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
is from the previous argument. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch; addressed in #864.