1
+ # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License"). You
4
+ # may not use this file except in compliance with the License. A copy of
5
+ # the License is located at
6
+ #
7
+ # http://aws.amazon.com/apache2.0/
8
+ #
9
+ # or in the "license" file accompanying this file. This file is
10
+ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
+ # ANY KIND, either express or implied. See the License for the specific
12
+ # language governing permissions and limitations under the License.
13
+ from abc import abstractmethod
14
+ from six import string_types
15
+
16
+ from sagemaker .session import s3_input
17
+
18
+
19
+ class _Job (object ):
20
+ def __init__ (self , sagemaker_session , training_job_name ):
21
+ self .sagemaker_session = sagemaker_session
22
+ self .job_name = training_job_name
23
+
24
+ @abstractmethod
25
+ def start_new (cls , estimator , inputs ):
26
+ """Create a new Amazon SageMaker job from the estimator.
27
+
28
+ Args:
29
+ estimator (sagemaker.estimator.EstimatorBase): Estimator object created by the user.
30
+ inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
31
+
32
+ Returns:
33
+ sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
34
+ """
35
+
36
+ pass
37
+
38
+ @staticmethod
39
+ def _load_config (inputs , estimator ):
40
+ input_config = _Job ._format_inputs_to_input_config (inputs )
41
+ role = estimator .sagemaker_session .expand_role (estimator .role )
42
+ output_config = _Job ._prepare_output_config (estimator .output_path ,
43
+ estimator .output_kms_key )
44
+ resource_config = _Job ._prepare_resource_config (estimator .train_instance_count ,
45
+ estimator .train_instance_type ,
46
+ estimator .train_volume_size )
47
+ stop_condition = _Job ._prepare_stopping_condition (estimator .train_max_run )
48
+
49
+ return input_config , role , output_config , resource_config , stop_condition
50
+
51
+ @staticmethod
52
+ def _format_inputs_to_input_config (inputs ):
53
+ input_dict = {}
54
+ if isinstance (inputs , string_types ):
55
+ input_dict ['training' ] = _Job ._format_s3_uri_input (inputs )
56
+ elif isinstance (inputs , s3_input ):
57
+ input_dict ['training' ] = inputs
58
+ elif isinstance (inputs , dict ):
59
+ for k , v in inputs .items ():
60
+ input_dict [k ] = _Job ._format_s3_uri_input (v )
61
+ else :
62
+ raise ValueError ('Cannot format input {}. Expecting one of str, dict or s3_input' .format (inputs ))
63
+
64
+ channels = []
65
+ for channel_name , channel_s3_input in input_dict .items ():
66
+ channel_config = channel_s3_input .config .copy ()
67
+ channel_config ['ChannelName' ] = channel_name
68
+ channels .append (channel_config )
69
+ return channels
70
+
71
+ @staticmethod
72
+ def _format_s3_uri_input (input ):
73
+ if isinstance (input , str ):
74
+ if not input .startswith ('s3://' ):
75
+ raise ValueError ('Training input data must be a valid S3 URI and must start with "s3://"' )
76
+ return s3_input (input )
77
+ if isinstance (input , s3_input ):
78
+ return input
79
+ else :
80
+ raise ValueError ('Cannot format input {}. Expecting one of str or s3_input' .format (input ))
81
+
82
+ @staticmethod
83
+ def _prepare_output_config (s3_path , kms_key_id ):
84
+ config = {'S3OutputPath' : s3_path }
85
+ if kms_key_id is not None :
86
+ config ['KmsKeyId' ] = kms_key_id
87
+ return config
88
+
89
+ @staticmethod
90
+ def _prepare_resource_config (instance_count , instance_type , volume_size ):
91
+ resource_config = {'InstanceCount' : instance_count ,
92
+ 'InstanceType' : instance_type ,
93
+ 'VolumeSizeInGB' : volume_size }
94
+ return resource_config
95
+
96
+ @staticmethod
97
+ def _prepare_stopping_condition (max_run ):
98
+ stop_condition = {'MaxRuntimeInSeconds' : max_run }
99
+ return stop_condition
100
+
101
+ @property
102
+ def name (self ):
103
+ return self .job_name
104
+
105
+ def wait (self , logs = True ):
106
+ if logs :
107
+ self .sagemaker_session .logs_for_job (self .job_name , wait = True )
108
+ else :
109
+ self .sagemaker_session .wait_for_job (self .job_name )
0 commit comments