@@ -24,15 +24,7 @@
 from sagemaker.job import _Job
 from sagemaker.utils import base_name_from_image, name_from_base
 from sagemaker.session import Session
-from sagemaker.s3 import (
-    S3CompressionType,
-    S3DataDistributionType,
-    S3DataType,
-    S3DownloadMode,
-    S3InputMode,
-    S3UploadMode,
-    S3Uploader,
-)
+from sagemaker.s3 import S3Uploader
 from sagemaker.network import NetworkConfig  # noqa: F401 # pylint: disable=unused-import
 
 
@@ -46,7 +38,6 @@ def __init__(
         instance_count,
         instance_type,
         entrypoint=None,
-        arguments=None,
         volume_size_in_gb=30,
         volume_kms_key=None,
         max_runtime_in_seconds=24 * 60 * 60,
@@ -72,8 +63,6 @@ def __init__(
             instance_type (str): Type of EC2 instance to use for
                 processing, for example, 'ml.c4.xlarge'.
             entrypoint (str): The entrypoint for the processing job.
-            arguments ([str]): A list of string arguments to be passed to a
-                processing job.
             volume_size_in_gb (int): Size in GB of the EBS volume
                 to use for storing data during processing (default: 30).
             volume_kms_key (str): A KMS key for the processing
@@ -99,7 +88,6 @@ def __init__(
         self.instance_count = instance_count
         self.instance_type = instance_type
         self.entrypoint = entrypoint
-        self.arguments = arguments
         self.volume_size_in_gb = volume_size_in_gb
         self.volume_kms_key = volume_kms_key
         self.max_runtime_in_seconds = max_runtime_in_seconds
@@ -112,8 +100,9 @@ def __init__(
         self.jobs = []
         self.latest_job = None
         self._current_job_name = None
+        self.arguments = None
 
-    def run(self, inputs=None, outputs=None, wait=True, logs=True, job_name=None):
+    def run(self, inputs=None, outputs=None, arguments=None, wait=True, logs=True, job_name=None):
         """Run a processing job.
 
         Args:
@@ -122,6 +111,8 @@ def run(self, inputs=None, outputs=None, wait=True, logs=True, job_name=None):
             outputs ([sagemaker.processor.ProcessingOutput]): Outputs for the processing
                 job. These can be specified as either a path string or a ProcessingOutput
                 object.
+            arguments ([str]): A list of string arguments to be passed to a
+                processing job.
             wait (bool): Whether the call should wait until the job completes (default: True).
             logs (bool): Whether to show the logs produced by the job.
                 Only meaningful when wait is True (default: True).
@@ -138,6 +129,7 @@ def run(self, inputs=None, outputs=None, wait=True, logs=True, job_name=None):
 
         normalized_inputs = self._normalize_inputs(inputs)
         normalized_outputs = self._normalize_outputs(outputs)
+        self.arguments = arguments
 
         self.latest_job = ProcessingJob.start_new(self, normalized_inputs, normalized_outputs)
         self.jobs.append(self.latest_job)
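
A rough usage sketch of the revised Processor API after the hunks above: `arguments` is no longer fixed in the constructor and is instead supplied per run() call, then stored on the instance just before ProcessingJob.start_new() is invoked. The module path in the import, the role ARN, and the image URI are placeholders/assumptions, not values from the diff.

from sagemaker.processing import Processor  # module path assumed for this sketch

processor = Processor(
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role ARN
    image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-processor:latest",  # placeholder image
    instance_count=1,
    instance_type="ml.c4.xlarge",
)

# Each call can now carry its own job arguments.
processor.run(arguments=["--log-level", "DEBUG"], wait=True)
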
@@ -243,7 +235,7 @@ def _normalize_outputs(self, outputs=None):
         return normalized_outputs
 
 
-class ScriptModeProcessor(Processor):
+class ScriptProcessor(Processor):
     """Handles Amazon SageMaker processing tasks for jobs using a machine learning framework."""
 
     def __init__(
@@ -252,8 +244,6 @@ def __init__(
         image_uri,
         instance_count,
         instance_type,
-        py_version="py3",
-        arguments=None,
         volume_size_in_gb=30,
         volume_kms_key=None,
         max_runtime_in_seconds=24 * 60 * 60,
@@ -263,7 +253,7 @@ def __init__(
         tags=None,
         network_config=None,
     ):
-        """Initialize a ``ScriptModeProcessor`` instance. The ScriptModeProcessor
+        """Initialize a ``ScriptProcessor`` instance. The ScriptProcessor
         handles Amazon SageMaker processing tasks for jobs using script mode.
 
         Args:
@@ -279,8 +269,6 @@ def __init__(
             instance_type (str): Type of EC2 instance to use for
                 processing, for example, 'ml.c4.xlarge'.
             py_version (str): The python version to use, for example, 'py3'.
-            arguments ([str]): A list of string arguments to be passed to a
-                processing job.
             volume_size_in_gb (int): Size in GB of the EBS volume
                 to use for storing data during processing (default: 30).
             volume_kms_key (str): A KMS key for the processing
@@ -301,16 +289,14 @@ def __init__(
                 object that configures network isolation, encryption of
                 inter-container traffic, security group IDs, and subnets.
         """
-        self.py_version = py_version
         self._CODE_CONTAINER_BASE_PATH = "/input/"
         self._CODE_CONTAINER_INPUT_NAME = "code"
 
-        super(ScriptModeProcessor, self).__init__(
+        super(ScriptProcessor, self).__init__(
            role=role,
            image_uri=image_uri,
            instance_count=instance_count,
            instance_type=instance_type,
-            arguments=arguments,
            volume_size_in_gb=volume_size_in_gb,
            volume_kms_key=volume_kms_key,
            max_runtime_in_seconds=max_runtime_in_seconds,
@@ -322,11 +308,22 @@ def __init__(
         )
 
     def run(
-        self, code, script_name=None, inputs=None, outputs=None, wait=True, logs=True, job_name=None
+        self,
+        command,
+        code,
+        script_name=None,
+        inputs=None,
+        outputs=None,
+        arguments=None,
+        wait=True,
+        logs=True,
+        job_name=None,
     ):
         """Run a processing job with Script Mode.
 
         Args:
+            command([str]): This is a list of strings that includes the executable, along
+                with any command-line flags. For example: ["python3", "-v"]
             code (str): This can be an S3 uri or a local path to either
                 a directory or a file with the user's script to run.
             script_name (str): If the user provides a directory for source,
@@ -337,6 +334,8 @@ def run(
             outputs ([str or sagemaker.processor.ProcessingOutput]): Outputs for the processing
                 job. These can be specified as either a path string or a ProcessingOutput
                 object.
+            arguments ([str]): A list of string arguments to be passed to a
+                processing job.
             wait (bool): Whether the call should wait until the job completes (default: True).
             logs (bool): Whether to show the logs produced by the job.
                 Only meaningful when wait is True (default: True).
@@ -349,10 +348,15 @@ def run(
         customer_code_s3_uri = self._upload_code(code)
         inputs_with_code = self._convert_code_and_add_to_inputs(inputs, customer_code_s3_uri)
 
-        self._set_entrypoint(customer_script_name)
+        self._set_entrypoint(command, customer_script_name)
 
-        super(ScriptModeProcessor, self).run(
-            inputs=inputs_with_code, outputs=outputs, wait=wait, logs=logs, job_name=job_name
+        super(ScriptProcessor, self).run(
+            inputs=inputs_with_code,
+            outputs=outputs,
+            arguments=arguments,
+            wait=wait,
+            logs=logs,
+            job_name=job_name,
         )
 
     def _get_customer_script_name(self, code, script_name):
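
A hedged usage sketch of the renamed class and its new run() signature as defined above: the caller now names the executable explicitly via `command` instead of having it inferred from the script's file extension. The role ARN, image URI, and script filename are placeholders.

script_processor = ScriptProcessor(
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role ARN
    image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",  # placeholder image
    instance_count=1,
    instance_type="ml.c4.xlarge",
)

script_processor.run(
    command=["python3"],                            # executable plus any command-line flags
    code="preprocessing.py",                        # placeholder local script (or an S3 URI)
    arguments=["--train-test-split-ratio", "0.2"],  # forwarded to the processing job
)
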
@@ -418,43 +422,16 @@ def _convert_code_and_add_to_inputs(self, inputs, s3_uri):
             the ProcessingInput object created from s3_uri appended to the list.
 
         """
-        input_list = inputs
         code_file_input = ProcessingInput(
             source=s3_uri,
             destination=os.path.join(
                 self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME
             ),
             input_name=self._CODE_CONTAINER_INPUT_NAME,
         )
-        input_list.append(code_file_input)
-        return input_list
-
-    def _get_execution_program(self, script_name):
-        """Determine which executable to run the user's script with
-        based on the file extension.
+        return inputs + [code_file_input]
 
-        Args:
-            script_name (str): A filename with an extension.
-
-        Returns:
-            str: A name of an executable to run the user's script with.
-        """
-        file_extension = os.path.splitext(script_name)[1]
-        if file_extension == ".py":
-            if self.py_version == "py3":
-                return "python3"
-            if self.py_version == "py2":
-                return "python2"
-            return "python"
-        if file_extension == ".sh":
-            return "bash"
-        raise ValueError(
-            """Script Mode supports Python or Bash scripts.
-            To use a custom entrypoint, please use Processor.
-            """
-        )
-
-    def _set_entrypoint(self, customer_script_name):
+    def _set_entrypoint(self, command, customer_script_name):
         """Sets the entrypoint based on the customer's script and corresponding executable.
 
         Args:
@@ -463,8 +440,7 @@ def _set_entrypoint(self, customer_script_name):
         customer_script_location = os.path.join(
             self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, customer_script_name
         )
-        execution_program = self._get_execution_program(customer_script_name)
-        self.entrypoint = [execution_program, customer_script_location]
+        self.entrypoint = command + [customer_script_location]
 
 
 class ProcessingJob(_Job):
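
A brief illustration (not part of the diff) of what the new _set_entrypoint() produces, using the container paths set in __init__ above:

# With command=["python3"] and a script named "preprocessing.py", the uploaded
# script lands under "/input/code", so the entrypoint becomes:
#     ["python3", "/input/code/preprocessing.py"]
# i.e. the user-supplied command list plus the in-container script location,
# rather than an interpreter guessed from the file extension.
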
@@ -564,11 +540,11 @@ def __init__(
         source,
         destination,
         input_name=None,
-        s3_data_type=S3DataType.MANIFEST_FILE,
-        s3_input_mode=S3InputMode.FILE,
-        s3_download_mode=S3DownloadMode.CONTINUOUS,
-        s3_data_distribution_type=S3DataDistributionType.FULLY_REPLICATED,
-        s3_compression_type=S3CompressionType.NONE,
+        s3_data_type="ManifestFile",
+        s3_input_mode="File",
+        s3_download_mode="Continuous",
+        s3_data_distribution_type="FullyReplicated",
+        s3_compression_type="None",
     ):
         """Initialize a ``ProcessingInput`` instance. ProcessingInput accepts parameters
         that specify an S3 input for a processing job and provides a method
@@ -579,11 +555,12 @@ def __init__(
             destination (str): The destination of the input.
             input_name (str): The user-provided name for the input. If a name
                 is not provided, one will be generated.
-            s3_data_type (sagemaker.s3.S3DataType):
-            s3_input_mode (sagemaker.s3.S3InputMode):
-            s3_download_mode (sagemaker.s3.S3DownloadMode):
-            s3_data_distribution_type (sagemaker.s3.S3DataDistributionType):
-            s3_compression_type (sagemaker.s3.S3CompressionType):
+            s3_data_type (str): Valid options are "ManifestFile" or "S3Prefix".
+            s3_input_mode (str): Valid options are "Pipe" or "File".
+            s3_download_mode (str): Valid options are "StartOfJob" or "Continuous".
+            s3_data_distribution_type (str): Valid options are "FullyReplicated"
+                or "ShardedByS3Key".
+            s3_compression_type (str): Valid options are "None" or "Gzip".
         """
         self.source = source
         self.destination = destination
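
A sketch of constructing a ProcessingInput with the plain-string S3 settings documented above; the bucket name and container path are placeholders.

processing_input = ProcessingInput(
    source="s3://my-bucket/processing/raw",    # placeholder S3 URI
    destination="/opt/ml/processing/input",    # placeholder in-container path
    input_name="raw-data",
    s3_data_type="S3Prefix",
    s3_input_mode="File",
    s3_download_mode="StartOfJob",
    s3_data_distribution_type="ShardedByS3Key",
)
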
@@ -602,21 +579,18 @@ def to_request_dict(self):
             "S3Input": {
                 "S3Uri": self.source,
                 "LocalPath": self.destination,
-                "S3DataType": self.s3_data_type.value,
-                "S3InputMode": self.s3_input_mode.value,
-                "S3DownloadMode": self.s3_download_mode.value,
-                "S3DataDistributionType": self.s3_data_distribution_type.value,
+                "S3DataType": self.s3_data_type,
+                "S3InputMode": self.s3_input_mode,
+                "S3DownloadMode": self.s3_download_mode,
+                "S3DataDistributionType": self.s3_data_distribution_type,
             },
         }
 
         # Check the compression type, then add it to the dictionary.
-        if (
-            self.s3_compression_type == S3CompressionType.GZIP
-            and self.s3_input_mode != S3InputMode.PIPE
-        ):
+        if self.s3_compression_type == "Gzip" and self.s3_input_mode != "Pipe":
             raise ValueError("Data can only be gzipped when the input mode is Pipe.")
         if self.s3_compression_type is not None:
-            s3_input_request["S3Input"]["S3CompressionType"] = self.s3_compression_type.value
+            s3_input_request["S3Input"]["S3CompressionType"] = self.s3_compression_type
 
         # Return the request dictionary.
         return s3_input_request
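
Illustrative note on the simplified request construction above, continuing the ProcessingInput sketch: the string values are copied into the request as-is, with no enum lookup.

request = processing_input.to_request_dict()
# request["S3Input"]["S3DataType"]  -> "S3Prefix"
# request["S3Input"]["S3InputMode"] -> "File"
# Gzip compression is still only accepted together with s3_input_mode="Pipe";
# any other combination raises the ValueError shown in the hunk.
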
@@ -627,12 +601,7 @@ class ProcessingOutput(object):
     a method to turn those parameters into a dictionary."""
 
     def __init__(
-        self,
-        source,
-        destination,
-        output_name=None,
-        kms_key_id=None,
-        s3_upload_mode=S3UploadMode.CONTINUOUS,
+        self, source, destination, output_name=None, kms_key_id=None, s3_upload_mode="Continuous"
     ):
         """Initialize a ``ProcessingOutput`` instance. ProcessingOutput accepts parameters that
         specify an S3 output for a processing job and provides a method to turn
@@ -643,7 +612,7 @@ def __init__(
             destination (str): The destination of the output.
             output_name (str): The name of the output.
             kms_key_id (str): The KMS key id for the output.
-            s3_upload_mode (sagemaker.s3.S3UploadMode):
+            s3_upload_mode (str): Valid options are "EndOfJob" or "Continuous".
         """
         self.source = source
         self.destination = destination
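
A matching sketch for ProcessingOutput with the string-valued upload mode; the container path and bucket name are again placeholders.

processing_output = ProcessingOutput(
    source="/opt/ml/processing/output",               # placeholder in-container path
    destination="s3://my-bucket/processing/results",  # placeholder S3 URI
    output_name="results",
    s3_upload_mode="EndOfJob",
)
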
@@ -659,7 +628,7 @@ def to_request_dict(self):
             "S3Output": {
                 "S3Uri": self.destination,
                 "LocalPath": self.source,
-                "S3UploadMode": self.s3_upload_mode.value,
+                "S3UploadMode": self.s3_upload_mode,
             },
         }
 