Skip to content

Commit 18e8d9c

Browse files
ajaykarpurknakad
authored andcommitted
Migrate to updated Processing Jobs API (#244)
* Update internal boto models * Rename Analytics to Processing in API calls and update docstrings * Make ProcessingOutput.destination optional * Default max_runtime_in_seconds to None * Default ProcessingOutput s3_upload_mode to EndOfJob * Default ProcessingInput s3_data_type to S3Prefix * Remove S3DownloadMode * Move output KMS key ID to ProcessingOutputConfig * Remove kms_key_id docstring from ProcessingOutput * Add unit test for sklearn with no additional inputs * Add unit test for ScriptProcessor * Change the method used to generate the image URI in SKLearnProcessor * Uncomment params in test_sklearn integ test
1 parent 065872a commit 18e8d9c

File tree

13 files changed

+3783
-960
lines changed

13 files changed

+3783
-960
lines changed

src/sagemaker/processing.py

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def __init__(
4040
entrypoint=None,
4141
volume_size_in_gb=30,
4242
volume_kms_key=None,
43-
max_runtime_in_seconds=24 * 60 * 60,
43+
output_kms_key=None,
44+
max_runtime_in_seconds=None,
4445
base_job_name=None,
4546
sagemaker_session=None,
4647
env=None,
@@ -67,9 +68,10 @@ def __init__(
6768
to use for storing data during processing (default: 30).
6869
volume_kms_key (str): A KMS key for the processing
6970
volume.
71+
output_kms_key (str): The KMS key id for all ProcessingOutputs.
7072
max_runtime_in_seconds (int): Timeout in seconds
71-
(default: 24 * 60 * 60). After this amount of time Amazon
72-
SageMaker terminates the job regardless of its current status.
73+
After this amount of time Amazon SageMaker terminates the job
74+
regardless of its current status.
7375
base_job_name (str): Prefix for processing name. If not specified,
7476
the processor generates a default job name, based on the
7577
training image name and current timestamp.
@@ -90,6 +92,7 @@ def __init__(
9092
self.entrypoint = entrypoint
9193
self.volume_size_in_gb = volume_size_in_gb
9294
self.volume_kms_key = volume_kms_key
95+
self.output_kms_key = output_kms_key
9396
self.max_runtime_in_seconds = max_runtime_in_seconds
9497
self.base_job_name = base_job_name
9598
self.sagemaker_session = sagemaker_session or Session()
@@ -106,9 +109,9 @@ def run(self, inputs=None, outputs=None, arguments=None, wait=True, logs=True, j
106109
"""Run a processing job.
107110
108111
Args:
109-
inputs ([sagemaker.processor.ProcessingInput]): Input files for the processing
112+
inputs ([sagemaker.processing.ProcessingInput]): Input files for the processing
110113
job. These must be provided as ProcessingInput objects.
111-
outputs ([sagemaker.processor.ProcessingOutput]): Outputs for the processing
114+
outputs ([sagemaker.processing.ProcessingOutput]): Outputs for the processing
112115
job. These can be specified as either a path string or a ProcessingOutput
113116
object.
114117
arguments ([str]): A list of string arguments to be passed to a
@@ -161,11 +164,11 @@ def _normalize_inputs(self, inputs=None):
161164
"""Ensure that all the ProcessingInput objects have names and S3 uris.
162165
163166
Args:
164-
inputs ([sagemaker.processor.ProcessingInput]): A list of ProcessingInput
167+
inputs ([sagemaker.processing.ProcessingInput]): A list of ProcessingInput
165168
objects to be normalized.
166169
167170
Returns:
168-
[sagemaker.processor.ProcessingInput]: The list of normalized
171+
[sagemaker.processing.ProcessingInput]: The list of normalized
169172
ProcessingInput objects.
170173
"""
171174
# Initialize a list of normalized ProcessingInput objects.
@@ -203,12 +206,12 @@ def _normalize_outputs(self, outputs=None):
203206
names and S3 uris.
204207
205208
Args:
206-
outputs ([sagemaker.processor.ProcessingOutput]): A list
209+
outputs ([sagemaker.processing.ProcessingOutput]): A list
207210
of outputs to be normalized. Can be either strings or
208211
ProcessingOutput objects.
209212
210213
Returns:
211-
[sagemaker.processor.ProcessingOutput]: The list of normalized
214+
[sagemaker.processing.ProcessingOutput]: The list of normalized
212215
ProcessingOutput objects.
213216
"""
214217
# Initialize a list of normalized ProcessingOutput objects.
@@ -246,7 +249,8 @@ def __init__(
246249
instance_type,
247250
volume_size_in_gb=30,
248251
volume_kms_key=None,
249-
max_runtime_in_seconds=24 * 60 * 60,
252+
output_kms_key=None,
253+
max_runtime_in_seconds=None,
250254
base_job_name=None,
251255
sagemaker_session=None,
252256
env=None,
@@ -273,9 +277,10 @@ def __init__(
273277
to use for storing data during processing (default: 30).
274278
volume_kms_key (str): A KMS key for the processing
275279
volume.
276-
max_runtime_in_seconds (int): Timeout in seconds
277-
(default: 24 * 60 * 60). After this amount of time Amazon
278-
SageMaker terminates the job regardless of its current status.
280+
output_kms_key (str): The KMS key id for all ProcessingOutputs.
281+
max_runtime_in_seconds (int): Timeout in seconds.
282+
After this amount of time Amazon SageMaker terminates the job
283+
regardless of its current status.
279284
base_job_name (str): Prefix for processing name. If not specified,
280285
the processor generates a default job name, based on the
281286
training image name and current timestamp.
@@ -299,6 +304,7 @@ def __init__(
299304
instance_type=instance_type,
300305
volume_size_in_gb=volume_size_in_gb,
301306
volume_kms_key=volume_kms_key,
307+
output_kms_key=output_kms_key,
302308
max_runtime_in_seconds=max_runtime_in_seconds,
303309
base_job_name=base_job_name,
304310
sagemaker_session=sagemaker_session,
@@ -329,9 +335,9 @@ def run(
329335
script_name (str): If the user provides a directory for source,
330336
they must specify script_name as the file within that
331337
directory to use.
332-
inputs ([sagemaker.processor.ProcessingInput]): Input files for the processing
338+
inputs ([sagemaker.processing.ProcessingInput]): Input files for the processing
333339
job. These must be provided as ProcessingInput objects.
334-
outputs ([str or sagemaker.processor.ProcessingOutput]): Outputs for the processing
340+
outputs ([str or sagemaker.processing.ProcessingOutput]): Outputs for the processing
335341
job. These can be specified as either a path string or a ProcessingOutput
336342
object.
337343
arguments ([str]): A list of string arguments to be passed to a
@@ -414,11 +420,11 @@ def _convert_code_and_add_to_inputs(self, inputs, s3_uri):
414420
"""Creates a ProcessingInput object from an S3 uri and adds it to the list of inputs.
415421
416422
Args:
417-
inputs ([sagemaker.processor.ProcessingInput]): List of ProcessingInput objects.
423+
inputs ([sagemaker.processing.ProcessingInput]): List of ProcessingInput objects.
418424
s3_uri (str): S3 uri of the input to be added to inputs.
419425
420426
Returns:
421-
[sagemaker.processor.ProcessingInput]: A new list of ProcessingInput objects, with
427+
[sagemaker.processing.ProcessingInput]: A new list of ProcessingInput objects, with
422428
the ProcessingInput object created from s3_uri appended to the list.
423429
424430
"""
@@ -429,7 +435,7 @@ def _convert_code_and_add_to_inputs(self, inputs, s3_uri):
429435
),
430436
input_name=self._CODE_CONTAINER_INPUT_NAME,
431437
)
432-
return inputs + [code_file_input]
438+
return (inputs or []) + [code_file_input]
433439

434440
def _set_entrypoint(self, command, customer_script_name):
435441
"""Sets the entrypoint based on the customer's script and corresponding executable.
@@ -458,8 +464,8 @@ def start_new(cls, processor, inputs, outputs):
458464
Args:
459465
processor (sagemaker.processing.Processor): The Processor instance
460466
that started the job.
461-
inputs ([sagemaker.processor.ProcessingInput]): A list of ProcessingInput objects.
462-
outputs ([sagemaker.processor.ProcessingOutput]): A list of ProcessingOutput objects.
467+
inputs ([sagemaker.processing.ProcessingInput]): A list of ProcessingInput objects.
468+
outputs ([sagemaker.processing.ProcessingOutput]): A list of ProcessingOutput objects.
463469
464470
Returns:
465471
sagemaker.processing.ProcessingJob: The instance of ProcessingJob created
@@ -471,35 +477,51 @@ def start_new(cls, processor, inputs, outputs):
471477

472478
# Add arguments to the dictionary.
473479
process_request_args["inputs"] = [input.to_request_dict() for input in inputs]
474-
process_request_args["outputs"] = [output.to_request_dict() for output in outputs]
480+
481+
process_request_args["output_config"] = {
482+
"Outputs": [output.to_request_dict() for output in outputs]
483+
}
484+
if processor.output_kms_key is not None:
485+
process_request_args["output_config"]["KmsKeyId"] = processor.output_kms_key
486+
475487
process_request_args["job_name"] = processor._current_job_name
488+
476489
process_request_args["resources"] = {
477490
"ClusterConfig": {
478491
"InstanceType": processor.instance_type,
479492
"InstanceCount": processor.instance_count,
480493
"VolumeSizeInGB": processor.volume_size_in_gb,
481494
}
482495
}
483-
process_request_args["stopping_condition"] = {
484-
"MaxRuntimeInSeconds": processor.max_runtime_in_seconds
485-
}
496+
497+
if processor.max_runtime_in_seconds is not None:
498+
process_request_args["stopping_condition"] = {
499+
"MaxRuntimeInSeconds": processor.max_runtime_in_seconds
500+
}
501+
else:
502+
process_request_args["stopping_condition"] = None
503+
486504
process_request_args["app_specification"] = {"ImageUri": processor.image_uri}
487505
if processor.arguments is not None:
488506
process_request_args["app_specification"]["ContainerArguments"] = processor.arguments
489507
if processor.entrypoint is not None:
490508
process_request_args["app_specification"]["ContainerEntrypoint"] = processor.entrypoint
509+
491510
process_request_args["environment"] = processor.env
511+
492512
if processor.network_config is not None:
493513
process_request_args["network_config"] = processor.network_config.to_request_dict()
494514
else:
495515
process_request_args["network_config"] = None
516+
496517
process_request_args["role_arn"] = processor.role
518+
497519
process_request_args["tags"] = processor.tags
498520

499521
# Print the job name and the user's inputs and outputs as lists of dictionaries.
500522
print("Job Name: ", process_request_args["job_name"])
501523
print("Inputs: ", process_request_args["inputs"])
502-
print("Outputs: ", process_request_args["outputs"])
524+
print("Outputs: ", process_request_args["output_config"]["Outputs"])
503525

504526
# Call sagemaker_session.process using the arguments dictionary.
505527
processor.sagemaker_session.process(**process_request_args)
@@ -521,7 +543,7 @@ def wait(self, logs=True):
521543

522544
def describe(self, print_response=True):
523545
"""Prints out a response from the DescribeProcessingJob API call."""
524-
describe_response = self.sagemaker_session.describe_analytics_job(self.job_name)
546+
describe_response = self.sagemaker_session.describe_processing_job(self.job_name)
525547
if print_response:
526548
print(describe_response)
527549
return describe_response
@@ -540,9 +562,8 @@ def __init__(
540562
source,
541563
destination,
542564
input_name=None,
543-
s3_data_type="ManifestFile",
565+
s3_data_type="S3Prefix",
544566
s3_input_mode="File",
545-
s3_download_mode="Continuous",
546567
s3_data_distribution_type="FullyReplicated",
547568
s3_compression_type="None",
548569
):
@@ -557,7 +578,6 @@ def __init__(
557578
is not provided, one will be generated.
558579
s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix".
559580
s3_input_mode (str): Valid options are "Pipe" or "File".
560-
s3_download_mode (str): Valid options are "StartOfJob" or "Continuous".
561581
s3_data_distribution_type (str): Valid options are "FullyReplicated"
562582
or "ShardedByS3Key".
563583
s3_compression_type (str): Valid options are "None" or "Gzip".
@@ -567,7 +587,6 @@ def __init__(
567587
self.input_name = input_name
568588
self.s3_data_type = s3_data_type
569589
self.s3_input_mode = s3_input_mode
570-
self.s3_download_mode = s3_download_mode
571590
self.s3_data_distribution_type = s3_data_distribution_type
572591
self.s3_compression_type = s3_compression_type
573592

@@ -581,7 +600,6 @@ def to_request_dict(self):
581600
"LocalPath": self.destination,
582601
"S3DataType": self.s3_data_type,
583602
"S3InputMode": self.s3_input_mode,
584-
"S3DownloadMode": self.s3_download_mode,
585603
"S3DataDistributionType": self.s3_data_distribution_type,
586604
},
587605
}
@@ -600,9 +618,7 @@ class ProcessingOutput(object):
600618
"""Accepts parameters that specify an S3 output for a processing job and provides
601619
a method to turn those parameters into a dictionary."""
602620

603-
def __init__(
604-
self, source, destination, output_name=None, kms_key_id=None, s3_upload_mode="Continuous"
605-
):
621+
def __init__(self, source, destination=None, output_name=None, s3_upload_mode="EndOfJob"):
606622
"""Initialize a ``ProcessingOutput`` instance. ProcessingOutput accepts parameters that
607623
specify an S3 output for a processing job and provides a method to turn
608624
those parameters into a dictionary.
@@ -611,13 +627,11 @@ def __init__(
611627
source (str): The source for the output.
612628
destination (str): The destination of the output.
613629
output_name (str): The name of the output.
614-
kms_key_id (str): The KMS key id for the output.
615630
s3_upload_mode (str): Valid options are "EndOfJob" or "Continuous".
616631
"""
617632
self.source = source
618633
self.destination = destination
619634
self.output_name = output_name
620-
self.kms_key_id = kms_key_id
621635
self.s3_upload_mode = s3_upload_mode
622636

623637
def to_request_dict(self):
@@ -632,9 +646,5 @@ def to_request_dict(self):
632646
},
633647
}
634648

635-
# Check the KMS key ID, then add it to the dictionary.
636-
if self.kms_key_id is not None:
637-
s3_output_request["S3Output"]["KmsKeyId"] = self.kms_key_id
638-
639649
# Return the request dictionary.
640650
return s3_output_request

0 commit comments

Comments
 (0)