
Commit 065872a

ajaykarpur authored and knakad committed
Processing Jobs revision round 2 (#243)
* Rename ScriptModeProcessor to ScriptProcessor
* Move arguments param down to run method in Processor classes
* Remove py_version from ScriptProcessor and add command as a param to ScriptProcessor.run
* Remove Enums in favor of string literals
1 parent: 77d41e9 · commit: 065872a

7 files changed: 76 additions, 161 deletions


src/sagemaker/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@
 from sagemaker.model import Model, ModelPackage  # noqa: F401
 from sagemaker.pipeline import PipelineModel  # noqa: F401
 from sagemaker.predictor import RealTimePredictor  # noqa: F401
-from sagemaker.processing import Processor, ScriptModeProcessor  # noqa: F401
+from sagemaker.processing import Processor, ScriptProcessor  # noqa: F401
 from sagemaker.session import Session  # noqa: F401
 from sagemaker.session import container_def, pipeline_container_def  # noqa: F401
 from sagemaker.session import production_variant  # noqa: F401
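Since both classes are re-exported from the package root, code that imported the old name would, under this commit, switch to something like the following (a minimal sketch; only the rename itself comes from the diff):

from sagemaker import Processor, ScriptProcessor  # ScriptModeProcessor no longer exists after the rename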

src/sagemaker/processing.py

Lines changed: 54 additions & 85 deletions
@@ -24,15 +24,7 @@
 from sagemaker.job import _Job
 from sagemaker.utils import base_name_from_image, name_from_base
 from sagemaker.session import Session
-from sagemaker.s3 import (
-    S3CompressionType,
-    S3DataDistributionType,
-    S3DataType,
-    S3DownloadMode,
-    S3InputMode,
-    S3UploadMode,
-    S3Uploader,
-)
+from sagemaker.s3 import S3Uploader
 from sagemaker.network import NetworkConfig  # noqa: F401  # pylint: disable=unused-import


@@ -46,7 +38,6 @@ def __init__(
         instance_count,
         instance_type,
         entrypoint=None,
-        arguments=None,
         volume_size_in_gb=30,
         volume_kms_key=None,
         max_runtime_in_seconds=24 * 60 * 60,
@@ -72,8 +63,6 @@ def __init__(
             instance_type (str): Type of EC2 instance to use for
                 processing, for example, 'ml.c4.xlarge'.
             entrypoint (str): The entrypoint for the processing job.
-            arguments ([str]): A list of string arguments to be passed to a
-                processing job.
             volume_size_in_gb (int): Size in GB of the EBS volume
                 to use for storing data during processing (default: 30).
             volume_kms_key (str): A KMS key for the processing
@@ -99,7 +88,6 @@ def __init__(
         self.instance_count = instance_count
         self.instance_type = instance_type
         self.entrypoint = entrypoint
-        self.arguments = arguments
         self.volume_size_in_gb = volume_size_in_gb
         self.volume_kms_key = volume_kms_key
         self.max_runtime_in_seconds = max_runtime_in_seconds
@@ -112,8 +100,9 @@ def __init__(
         self.jobs = []
         self.latest_job = None
         self._current_job_name = None
+        self.arguments = None

-    def run(self, inputs=None, outputs=None, wait=True, logs=True, job_name=None):
+    def run(self, inputs=None, outputs=None, arguments=None, wait=True, logs=True, job_name=None):
         """Run a processing job.

         Args:
@@ -122,6 +111,8 @@ def run(self, inputs=None, outputs=None, wait=True, logs=True, job_name=None):
             outputs ([sagemaker.processor.ProcessingOutput]): Outputs for the processing
                 job. These can be specified as either a path string or a ProcessingOutput
                 object.
+            arguments ([str]): A list of string arguments to be passed to a
+                processing job.
             wait (bool): Whether the call should wait until the job completes (default: True).
             logs (bool): Whether to show the logs produced by the job.
                 Only meaningful when wait is True (default: True).
@@ -138,6 +129,7 @@ def run(self, inputs=None, outputs=None, wait=True, logs=True, job_name=None):

         normalized_inputs = self._normalize_inputs(inputs)
         normalized_outputs = self._normalize_outputs(outputs)
+        self.arguments = arguments

         self.latest_job = ProcessingJob.start_new(self, normalized_inputs, normalized_outputs)
         self.jobs.append(self.latest_job)
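For orientation, a minimal sketch of the new calling pattern for Processor, with arguments supplied per run rather than at construction time. The role ARN, image URI, and S3/container paths below are placeholders, not values from this commit:

from sagemaker.processing import Processor, ProcessingInput, ProcessingOutput

processor = Processor(
    role="arn:aws:iam::111122223333:role/MyProcessingRole",  # hypothetical role
    image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",  # hypothetical image
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# arguments now travels with each run() call instead of living on the instance from __init__
processor.run(
    inputs=[ProcessingInput(source="s3://my-bucket/input", destination="/opt/ml/processing/input")],
    outputs=[ProcessingOutput(source="/opt/ml/processing/output", destination="s3://my-bucket/output")],
    arguments=["--log-level", "debug"],
)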
@@ -243,7 +235,7 @@ def _normalize_outputs(self, outputs=None):
         return normalized_outputs


-class ScriptModeProcessor(Processor):
+class ScriptProcessor(Processor):
     """Handles Amazon SageMaker processing tasks for jobs using a machine learning framework."""

     def __init__(
@@ -252,8 +244,6 @@ def __init__(
         image_uri,
         instance_count,
         instance_type,
-        py_version="py3",
-        arguments=None,
         volume_size_in_gb=30,
         volume_kms_key=None,
         max_runtime_in_seconds=24 * 60 * 60,
@@ -263,7 +253,7 @@ def __init__(
         tags=None,
         network_config=None,
     ):
-        """Initialize a ``ScriptModeProcessor`` instance. The ScriptModeProcessor
+        """Initialize a ``ScriptProcessor`` instance. The ScriptProcessor
         handles Amazon SageMaker processing tasks for jobs using script mode.

         Args:
@@ -279,8 +269,6 @@ def __init__(
             instance_type (str): Type of EC2 instance to use for
                 processing, for example, 'ml.c4.xlarge'.
             py_version (str): The python version to use, for example, 'py3'.
-            arguments ([str]): A list of string arguments to be passed to a
-                processing job.
             volume_size_in_gb (int): Size in GB of the EBS volume
                 to use for storing data during processing (default: 30).
             volume_kms_key (str): A KMS key for the processing
@@ -301,16 +289,14 @@ def __init__(
                 object that configures network isolation, encryption of
                 inter-container traffic, security group IDs, and subnets.
        """
-        self.py_version = py_version
         self._CODE_CONTAINER_BASE_PATH = "/input/"
         self._CODE_CONTAINER_INPUT_NAME = "code"

-        super(ScriptModeProcessor, self).__init__(
+        super(ScriptProcessor, self).__init__(
             role=role,
             image_uri=image_uri,
             instance_count=instance_count,
             instance_type=instance_type,
-            arguments=arguments,
             volume_size_in_gb=volume_size_in_gb,
             volume_kms_key=volume_kms_key,
             max_runtime_in_seconds=max_runtime_in_seconds,
@@ -322,11 +308,22 @@ def __init__(
         )

     def run(
-        self, code, script_name=None, inputs=None, outputs=None, wait=True, logs=True, job_name=None
+        self,
+        command,
+        code,
+        script_name=None,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+        wait=True,
+        logs=True,
+        job_name=None,
     ):
         """Run a processing job with Script Mode.

         Args:
+            command([str]): This is a list of strings that includes the executable, along
+                with any command-line flags. For example: ["python3", "-v"]
             code (str): This can be an S3 uri or a local path to either
                 a directory or a file with the user's script to run.
             script_name (str): If the user provides a directory for source,
@@ -337,6 +334,8 @@ def run(
             outputs ([str or sagemaker.processor.ProcessingOutput]): Outputs for the processing
                 job. These can be specified as either a path string or a ProcessingOutput
                 object.
+            arguments ([str]): A list of string arguments to be passed to a
+                processing job.
             wait (bool): Whether the call should wait until the job completes (default: True).
             logs (bool): Whether to show the logs produced by the job.
                 Only meaningful when wait is True (default: True).
@@ -349,10 +348,15 @@ def run(
         customer_code_s3_uri = self._upload_code(code)
         inputs_with_code = self._convert_code_and_add_to_inputs(inputs, customer_code_s3_uri)

-        self._set_entrypoint(customer_script_name)
+        self._set_entrypoint(command, customer_script_name)

-        super(ScriptModeProcessor, self).run(
-            inputs=inputs_with_code, outputs=outputs, wait=wait, logs=logs, job_name=job_name
+        super(ScriptProcessor, self).run(
+            inputs=inputs_with_code,
+            outputs=outputs,
+            arguments=arguments,
+            wait=wait,
+            logs=logs,
+            job_name=job_name,
         )

     def _get_customer_script_name(self, code, script_name):
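A sketch of the revised ScriptProcessor.run call, assuming placeholder role, image, and S3 values; the explicit command list replaces the removed py_version parameter and the extension-based executable lookup:

from sagemaker.processing import ScriptProcessor

script_processor = ScriptProcessor(
    role="arn:aws:iam::111122223333:role/MyProcessingRole",  # hypothetical role
    image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-script-image:latest",  # hypothetical image
    instance_count=1,
    instance_type="ml.m5.xlarge",
)

# The executable and its flags are now passed explicitly on every run
script_processor.run(
    command=["python3", "-v"],
    code="s3://my-bucket/code/preprocess.py",  # hypothetical script location
    arguments=["--train-test-split", "0.2"],
)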
@@ -418,43 +422,16 @@ def _convert_code_and_add_to_inputs(self, inputs, s3_uri):
            the ProcessingInput object created from s3_uri appended to the list.

        """
-        input_list = inputs
        code_file_input = ProcessingInput(
            source=s3_uri,
            destination=os.path.join(
                self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME
            ),
            input_name=self._CODE_CONTAINER_INPUT_NAME,
        )
-        input_list.append(code_file_input)
-        return input_list
-
-    def _get_execution_program(self, script_name):
-        """Determine which executable to run the user's script with
-        based on the file extension.
+        return inputs + [code_file_input]

-        Args:
-            script_name (str): A filename with an extension.
-
-        Returns:
-            str: A name of an executable to run the user's script with.
-        """
-        file_extension = os.path.splitext(script_name)[1]
-        if file_extension == ".py":
-            if self.py_version == "py3":
-                return "python3"
-            if self.py_version == "py2":
-                return "python2"
-            return "python"
-        if file_extension == ".sh":
-            return "bash"
-        raise ValueError(
-            """Script Mode supports Python or Bash scripts.
-            To use a custom entrypoint, please use Processor.
-            """
-        )
-
-    def _set_entrypoint(self, customer_script_name):
+    def _set_entrypoint(self, command, customer_script_name):
        """Sets the entrypoint based on the customer's script and corresponding executable.

        Args:
@@ -463,8 +440,7 @@ def _set_entrypoint(self, customer_script_name):
        customer_script_location = os.path.join(
            self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, customer_script_name
        )
-        execution_program = self._get_execution_program(customer_script_name)
-        self.entrypoint = [execution_program, customer_script_location]
+        self.entrypoint = command + [customer_script_location]


 class ProcessingJob(_Job):
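Given the container path constants defined on ScriptProcessor ("/input/" and "code"), the composed entrypoint now looks roughly like this; the script name is illustrative:

command = ["python3", "-v"]
customer_script_location = "/input/code/preprocess.py"  # os.path.join of the two constants and the script name
entrypoint = command + [customer_script_location]
# -> ["python3", "-v", "/input/code/preprocess.py"]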
@@ -564,11 +540,11 @@ def __init__(
        source,
        destination,
        input_name=None,
-        s3_data_type=S3DataType.MANIFEST_FILE,
-        s3_input_mode=S3InputMode.FILE,
-        s3_download_mode=S3DownloadMode.CONTINUOUS,
-        s3_data_distribution_type=S3DataDistributionType.FULLY_REPLICATED,
-        s3_compression_type=S3CompressionType.NONE,
+        s3_data_type="ManifestFile",
+        s3_input_mode="File",
+        s3_download_mode="Continuous",
+        s3_data_distribution_type="FullyReplicated",
+        s3_compression_type="None",
    ):
        """Initialize a ``ProcessingInput`` instance. ProcessingInput accepts parameters
        that specify an S3 input for a processing job and provides a method
@@ -579,11 +555,12 @@ def __init__(
            destination (str): The destination of the input.
            input_name (str): The user-provided name for the input. If a name
                is not provided, one will be generated.
-            s3_data_type (sagemaker.s3.S3DataType):
-            s3_input_mode (sagemaker.s3.S3InputMode):
-            s3_download_mode (sagemaker.s3.S3DownloadMode):
-            s3_data_distribution_type (sagemaker.s3.S3DataDistributionType):
-            s3_compression_type (sagemaker.s3.S3CompressionType):
+            s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix".
+            s3_input_mode (str): Valid options are "Pipe" or "File".
+            s3_download_mode (str): Valid options are "StartOfJob" or "Continuous".
+            s3_data_distribution_type (str): Valid options are "FullyReplicated"
+                or "ShardedByS3Key".
+            s3_compression_type (str): Valid options are "None" or "Gzip".
        """
        self.source = source
        self.destination = destination
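With the Enums gone, ProcessingInput takes the string literals directly; a sketch using placeholder S3 and container paths:

from sagemaker.processing import ProcessingInput

raw_data = ProcessingInput(
    source="s3://my-bucket/raw",  # hypothetical S3 prefix
    destination="/opt/ml/processing/raw",  # hypothetical container path
    s3_data_type="S3Prefix",
    s3_input_mode="File",
    s3_data_distribution_type="ShardedByS3Key",
)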
@@ -602,21 +579,18 @@ def to_request_dict(self):
            "S3Input": {
                "S3Uri": self.source,
                "LocalPath": self.destination,
-                "S3DataType": self.s3_data_type.value,
-                "S3InputMode": self.s3_input_mode.value,
-                "S3DownloadMode": self.s3_download_mode.value,
-                "S3DataDistributionType": self.s3_data_distribution_type.value,
+                "S3DataType": self.s3_data_type,
+                "S3InputMode": self.s3_input_mode,
+                "S3DownloadMode": self.s3_download_mode,
+                "S3DataDistributionType": self.s3_data_distribution_type,
            },
        }

        # Check the compression type, then add it to the dictionary.
-        if (
-            self.s3_compression_type == S3CompressionType.GZIP
-            and self.s3_input_mode != S3InputMode.PIPE
-        ):
+        if self.s3_compression_type == "Gzip" and self.s3_input_mode != "Pipe":
            raise ValueError("Data can only be gzipped when the input mode is Pipe.")
        if self.s3_compression_type is not None:
-            s3_input_request["S3Input"]["S3CompressionType"] = self.s3_compression_type.value
+            s3_input_request["S3Input"]["S3CompressionType"] = self.s3_compression_type

        # Return the request dictionary.
        return s3_input_request
@@ -627,12 +601,7 @@ class ProcessingOutput(object):
    a method to turn those parameters into a dictionary."""

    def __init__(
-        self,
-        source,
-        destination,
-        output_name=None,
-        kms_key_id=None,
-        s3_upload_mode=S3UploadMode.CONTINUOUS,
+        self, source, destination, output_name=None, kms_key_id=None, s3_upload_mode="Continuous"
    ):
        """Initialize a ``ProcessingOutput`` instance. ProcessingOutput accepts parameters that
        specify an S3 output for a processing job and provides a method to turn
@@ -643,7 +612,7 @@ def __init__(
            destination (str): The destination of the output.
            output_name (str): The name of the output.
            kms_key_id (str): The KMS key id for the output.
-            s3_upload_mode (sagemaker.s3.S3UploadMode):
+            s3_upload_mode (str): Valid options are "EndOfJob" or "Continuous".
        """
        self.source = source
        self.destination = destination
@@ -659,7 +628,7 @@ def to_request_dict(self):
            "S3Output": {
                "S3Uri": self.destination,
                "LocalPath": self.source,
-                "S3UploadMode": self.s3_upload_mode.value,
+                "S3UploadMode": self.s3_upload_mode,
            },
        }

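Likewise for outputs, the upload mode is now a plain string; a sketch with placeholder paths:

from sagemaker.processing import ProcessingOutput

train_output = ProcessingOutput(
    source="/opt/ml/processing/output/train",  # hypothetical container path
    destination="s3://my-bucket/processed/train",  # hypothetical S3 destination
    output_name="train",
    s3_upload_mode="EndOfJob",  # was S3UploadMode.END_OF_JOB before this commit
)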

src/sagemaker/s3.py

Lines changed: 0 additions & 44 deletions
@@ -13,54 +13,10 @@
 """This module contains Enums and helper methods related to S3."""
 from __future__ import print_function, absolute_import

-from enum import Enum
-
 from six.moves.urllib.parse import urlparse
 from sagemaker.session import Session


-class S3DataType(Enum):
-    """Provides enumerated values of S3 data types."""
-
-    MANIFEST_FILE = "ManifestFile"
-    S3_PREFIX = "S3Prefix"
-
-
-class S3InputMode(Enum):
-    """Provides enumerated values of S3 input modes."""
-
-    PIPE = "Pipe"
-    FILE = "File"
-
-
-class S3DownloadMode(Enum):
-    """Provides enumerated values of S3 download modes."""
-
-    START_OF_JOB = "StartOfJob"
-    CONTINUOUS = "Continuous"
-
-
-class S3DataDistributionType(Enum):
-    """Provides enumerated values of S3 data distribution types."""
-
-    FULLY_REPLICATED = "FullyReplicated"
-    SHARDED_BY_S3_KEY = "ShardedByS3Key"
-
-
-class S3CompressionType(Enum):
-    """Provides enumerated values of S3 compression types."""
-
-    NONE = "None"
-    GZIP = "Gzip"
-
-
-class S3UploadMode(Enum):
-    """Provides enumerated values of S3 upload modes."""
-
-    END_OF_JOB = "EndOfJob"
-    CONTINUOUS = "Continuous"
-
-
 def parse_s3_url(url):
     """Returns an (s3 bucket, key name/prefix) tuple from a url with an s3
     scheme.
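The module's non-Enum helpers are untouched by this change; a quick sketch of what still imports after the removal (the S3 URL is a placeholder):

from sagemaker.s3 import S3Uploader, parse_s3_url  # both remain after the Enum removal

bucket, key_prefix = parse_s3_url("s3://my-bucket/code/preprocess.py")
# S3Uploader is still imported by sagemaker.processing, presumably to stage user scripts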
