Skip to content

feature: Add DataProcessing Fields for Batch Transform #827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def stop_tuning_job(self, name):
raise

def transform(self, job_name, model_name, strategy, max_concurrent_transforms, max_payload, env,
input_config, output_config, resource_config, tags):
input_config, output_config, resource_config, tags, data_processing):
"""Create an Amazon SageMaker transform job.

Args:
Expand All @@ -514,8 +514,9 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m
input_config (dict): A dictionary describing the input data (and its location) for the job.
output_config (dict): A dictionary describing the output location for the job.
resource_config (dict): A dictionary describing the resources to complete the job.
tags (list[dict]): List of tags for labeling a training job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
tags (list[dict]): List of tags for labeling a transform job.
data_processing(dict): A dictionary describing config for combining the input data and transformed data.
For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The line "For more, see https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html." is from the previous argument.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch addressed in #864

"""
transform_request = {
'TransformJobName': job_name,
Expand All @@ -540,6 +541,9 @@ def transform(self, job_name, model_name, strategy, max_concurrent_transforms, m
if tags is not None:
transform_request['Tags'] = tags

if data_processing is not None:
transform_request['DataProcessing'] = data_processing

LOGGER.info('Creating transform job with name: {}'.format(job_name))
LOGGER.debug('Transform request: {}'.format(json.dumps(transform_request, indent=4)))
self.sagemaker_client.create_transform_job(**transform_request)
Expand Down
38 changes: 34 additions & 4 deletions src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
self.sagemaker_session = sagemaker_session or Session()

def transform(self, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None,
job_name=None):
job_name=None, input_filter=None, output_filter=None, join_source=None):
"""Start a new transform job.

Args:
Expand All @@ -97,6 +97,15 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
split_type (str): The record delimiter for the input object (default: 'None').
Valid values: 'None', 'Line', 'RecordIO', and 'TFRecord'.
job_name (str): job name (default: None). If not specified, one will be generated.
input_filter (str): A JSONPath to select a portion of the input to pass to the algorithm container for
inference. If you omit the field, it gets the value '$', representing the entire input.
Some examples: "$[1:]", "$.features"(default: None).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a link here for where to find additional information.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I just linked to Top level API for now as docs are not public yet

output_filter (str): A JSONPath to select a portion of the joined/original output to return as the output.
Some examples: "$[1:]", "$.prediction" (default: None).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

join_source (str): The source of data to be joined to the transform output. It can be set to 'Input'
meaning the entire input record will be joined to the inference result.
You can use OutputFilter to select the useful portion before uploading to S3. (default: None).
Valid values: Input, None.
"""
local_mode = self.sagemaker_session.local_mode
if not local_mode and not data.startswith('s3://'):
Expand All @@ -116,7 +125,7 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
self.output_path = 's3://{}/{}'.format(self.sagemaker_session.default_bucket(), self._current_job_name)

self.latest_transform_job = _TransformJob.start_new(self, data, data_type, content_type, compression_type,
split_type)
split_type, input_filter, output_filter, join_source)

def delete_model(self):
"""Delete the corresponding SageMaker model for this Transformer.
Expand Down Expand Up @@ -214,16 +223,19 @@ def _prepare_init_params_from_job_description(cls, job_details):

class _TransformJob(_Job):
@classmethod
def start_new(cls, transformer, data, data_type, content_type, compression_type,
              split_type, input_filter, output_filter, join_source):
    """Create and start a new SageMaker transform job.

    Args:
        transformer (Transformer): the transformer holding job settings (model name,
            strategy, resources, tags, session).
        data (str): location of the input data.
        data_type (str): what the data location points to, e.g. 'S3Prefix'.
        content_type (str): MIME type of the input data (or None).
        compression_type (str): compression of the input data (or None).
        split_type (str): record delimiter for the input object (or None).
        input_filter (str): JSONPath selecting the portion of input passed to the model (or None).
        output_filter (str): JSONPath selecting the portion of the joined output to keep (or None).
        join_source (str): source of data joined to the transform output, e.g. 'Input' (or None).

    Returns:
        _TransformJob: handle for the started job.
    """
    config = _TransformJob._load_config(data, data_type, content_type, compression_type, split_type, transformer)
    # DataProcessing is optional; this is None when no filter/join options were given,
    # so the request field is omitted entirely in that case.
    data_processing = _TransformJob._prepare_data_processing(input_filter, output_filter, join_source)

    transformer.sagemaker_session.transform(job_name=transformer._current_job_name,
                                            model_name=transformer.model_name, strategy=transformer.strategy,
                                            max_concurrent_transforms=transformer.max_concurrent_transforms,
                                            max_payload=transformer.max_payload, env=transformer.env,
                                            input_config=config['input_config'],
                                            output_config=config['output_config'],
                                            resource_config=config['resource_config'],
                                            tags=transformer.tags, data_processing=data_processing)

    return cls(transformer.sagemaker_session, transformer._current_job_name)

Expand Down Expand Up @@ -287,3 +299,21 @@ def _prepare_resource_config(instance_count, instance_type, volume_kms_key):
config['VolumeKmsKeyId'] = volume_kms_key

return config

@staticmethod
def _prepare_data_processing(input_filter, output_filter, join_source):
config = {}

if input_filter is not None:
config['InputFilter'] = input_filter

if output_filter is not None:
config['OutputFilter'] = output_filter

if join_source is not None:
config['JoinSource'] = join_source

if len(config) == 0:
return None

return config
2 changes: 1 addition & 1 deletion tests/integ/kms_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def _create_kms_key(kms_client,
role_arn=role_arn,
sagemaker_role=sagemaker_role)
else:
principal = "{account_id}".format(account_id=account_id)
principal = '"{account_id}"'.format(account_id=account_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any special reason to include this change on this PR?


response = kms_client.create_key(
Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal, sagemaker_role=sagemaker_role),
Expand Down
12 changes: 9 additions & 3 deletions tests/integ/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,19 @@ def test_transform_mxnet(sagemaker_session, mxnet_full_version):
key_prefix=transform_input_key_prefix)

kms_key_arn = get_or_create_kms_key(sagemaker_session)
output_filter = "$"

transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn)
transformer = _create_transformer_and_transform_job(mx, transform_input, kms_key_arn,
input_filter=None, output_filter=output_filter,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add an input_filter and join_source on this test as well?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added input_filter. I am setting join_source to None as it allows me to not change the output of the test.

join_source=None)
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
transformer.wait()

job_desc = transformer.sagemaker_session.sagemaker_client.describe_transform_job(
TransformJobName=transformer.latest_transform_job.name)
assert kms_key_arn == job_desc['TransformResources']['VolumeKmsKeyId']
assert output_filter == job_desc['DataProcessing']['OutputFilter']


@pytest.mark.canary_quick
Expand Down Expand Up @@ -232,7 +236,9 @@ def test_transform_byo_estimator(sagemaker_session):
assert tags == model_tags


def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None,
input_filter=None, output_filter=None, join_source=None):
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
transformer.transform(transform_input, content_type='text/csv')
transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter,
output_filter=output_filter, join_source=join_source)
return transformer
4 changes: 2 additions & 2 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def test_transform_pack_to_request(sagemaker_session):

sagemaker_session.transform(job_name=JOB_NAME, model_name=model_name, strategy=None, max_concurrent_transforms=None,
max_payload=None, env=None, input_config=in_config, output_config=out_config,
resource_config=resource_config, tags=None)
resource_config=resource_config, tags=None, data_processing=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to unit test the transform function with different data_processing use cases.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch — done. There are two tests: one with the required parameter and the other with the optional parameter, to cover the different use cases. I used them accordingly in #864.


_, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0]
assert actual_args == expected_args
Expand All @@ -603,7 +603,7 @@ def test_transform_pack_to_request_with_optional_params(sagemaker_session):
sagemaker_session.transform(job_name=JOB_NAME, model_name='my-model', strategy=strategy,
max_concurrent_transforms=max_concurrent_transforms,
env=env, max_payload=max_payload, input_config={}, output_config={},
resource_config={}, tags=TAGS)
resource_config={}, tags=TAGS, data_processing=None)

_, _, actual_args = sagemaker_session.sagemaker_client.method_calls[0]
assert actual_args['BatchStrategy'] == strategy
Expand Down
29 changes: 26 additions & 3 deletions tests/unit/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,18 @@ def test_transform_with_all_params(start_new_job, transformer):
content_type = 'text/csv'
compression = 'Gzip'
split = 'Line'
input_filter = "$.feature"
output_filter = "$['sagemaker_output', 'id']"
join_source = "Input"

transformer.transform(DATA, S3_DATA_TYPE, content_type=content_type, compression_type=compression, split_type=split,
job_name=JOB_NAME)
job_name=JOB_NAME, input_filter=input_filter, output_filter=output_filter,
join_source=join_source)

assert transformer._current_job_name == JOB_NAME
assert transformer.output_path == OUTPUT_PATH
start_new_job.assert_called_once_with(transformer, DATA, S3_DATA_TYPE, content_type, compression, split)
start_new_job.assert_called_once_with(transformer, DATA, S3_DATA_TYPE, content_type, compression,
split, input_filter, output_filter, join_source)


@patch('sagemaker.transformer.name_from_base')
Expand Down Expand Up @@ -300,7 +305,8 @@ def test_start_new(transformer, sagemaker_session):
transformer._current_job_name = JOB_NAME

job = _TransformJob(sagemaker_session, JOB_NAME)
started_job = job.start_new(transformer, DATA, S3_DATA_TYPE, None, None, None)
started_job = job.start_new(transformer, DATA, S3_DATA_TYPE, None, None, None,
None, None, None)

assert started_job.sagemaker_session == sagemaker_session
sagemaker_session.transform.assert_called_once()
Expand Down Expand Up @@ -392,6 +398,23 @@ def test_prepare_resource_config():
assert config == {'InstanceCount': INSTANCE_COUNT, 'InstanceType': INSTANCE_TYPE, 'VolumeKmsKeyId': KMS_KEY_ID}


def test_data_processing_config():
    """Each provided filter/join setting maps to its request key; all-None yields None."""
    cases = [
        (("$", None, None), {'InputFilter': "$"}),
        ((None, "$", None), {'OutputFilter': "$"}),
        ((None, None, "Input"), {'JoinSource': "Input"}),
        (("$[0]", "$[1]", "Input"),
         {'InputFilter': "$[0]", 'OutputFilter': "$[1]", 'JoinSource': "Input"}),
    ]
    for args, expected in cases:
        assert _TransformJob._prepare_data_processing(*args) == expected

    # With nothing set, the DataProcessing field should be omitted entirely.
    assert _TransformJob._prepare_data_processing(None, None, None) is None


def test_transform_job_wait(sagemaker_session):
job = _TransformJob(sagemaker_session, JOB_NAME)
job.wait()
Expand Down