17
17
import os
18
18
from abc import ABCMeta
19
19
from abc import abstractmethod
20
- from six import with_metaclass , string_types
20
+ from six import with_metaclass
21
21
22
22
from sagemaker .fw_utils import tar_and_upload_dir , parse_s3_url , UploadedCode , validate_source_dir
23
- from sagemaker .local import LocalSession , file_input
24
-
23
+ from sagemaker .job import _Job
24
+ from sagemaker . local import LocalSession
25
25
from sagemaker .model import Model
26
26
from sagemaker .model import (SCRIPT_PARAM_NAME , DIR_PARAM_NAME , CLOUDWATCH_METRICS_PARAM_NAME ,
27
27
CONTAINER_LOG_LEVEL_PARAM_NAME , JOB_NAME_PARAM_NAME , SAGEMAKER_REGION_PARAM_NAME )
28
-
29
28
from sagemaker .predictor import RealTimePredictor
30
29
from sagemaker .session import Session
31
30
from sagemaker .session import s3_input
@@ -310,10 +309,9 @@ def delete_endpoint(self):
310
309
self .sagemaker_session .delete_endpoint (self .latest_training_job .name )
311
310
312
311
313
- class _TrainingJob (object ):
312
+ class _TrainingJob (_Job ):
314
313
def __init__ (self , sagemaker_session , training_job_name ):
315
- self .sagemaker_session = sagemaker_session
316
- self .job_name = training_job_name
314
+ super (_TrainingJob , self ).__init__ (sagemaker_session , training_job_name )
317
315
318
316
@classmethod
319
317
def start_new (cls , estimator , inputs ):
@@ -324,7 +322,8 @@ def start_new(cls, estimator, inputs):
324
322
inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
325
323
326
324
Returns:
327
- sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
325
+ sagemaker.estimator._TrainingJob: Constructed object that captures all information about the started
326
+ training job.
328
327
"""
329
328
330
329
local_mode = estimator .sagemaker_session .local_mode
@@ -334,86 +333,19 @@ def start_new(cls, estimator, inputs):
334
333
if not local_mode :
335
334
raise ValueError ('File URIs are supported in local mode only. Please use a S3 URI instead.' )
336
335
337
- input_config = _TrainingJob ._format_inputs_to_input_config (inputs )
338
- role = estimator .sagemaker_session .expand_role (estimator .role )
339
- output_config = _TrainingJob ._prepare_output_config (estimator .output_path , estimator .output_kms_key )
340
- resource_config = _TrainingJob ._prepare_resource_config (estimator .train_instance_count ,
341
- estimator .train_instance_type ,
342
- estimator .train_volume_size )
343
- stop_condition = _TrainingJob ._prepare_stopping_condition (estimator .train_max_run )
336
+ config = _Job ._load_config (inputs , estimator )
344
337
345
338
if estimator .hyperparameters () is not None :
346
339
hyperparameters = {str (k ): str (v ) for (k , v ) in estimator .hyperparameters ().items ()}
347
340
348
341
estimator .sagemaker_session .train (image = estimator .train_image (), input_mode = estimator .input_mode ,
349
- input_config = input_config , role = role , job_name = estimator ._current_job_name ,
350
- output_config = output_config , resource_config = resource_config ,
351
- hyperparameters = hyperparameters , stop_condition = stop_condition )
342
+ input_config = config ['input_config' ], role = config ['role' ],
343
+ job_name = estimator ._current_job_name , output_config = config ['output_config' ],
344
+ resource_config = config ['resource_config' ], hyperparameters = hyperparameters ,
345
+ stop_condition = config ['stop_condition' ])
352
346
353
347
return cls (estimator .sagemaker_session , estimator ._current_job_name )
354
348
355
- @staticmethod
356
- def _format_inputs_to_input_config (inputs ):
357
- input_dict = {}
358
- if isinstance (inputs , string_types ):
359
- input_dict ['training' ] = _TrainingJob ._format_string_uri_input (inputs )
360
- elif isinstance (inputs , s3_input ):
361
- input_dict ['training' ] = inputs
362
- elif isinstance (input , file_input ):
363
- input_dict ['training' ] = inputs
364
- elif isinstance (inputs , dict ):
365
- for k , v in inputs .items ():
366
- input_dict [k ] = _TrainingJob ._format_string_uri_input (v )
367
- else :
368
- raise ValueError ('Cannot format input {}. Expecting one of str, dict or s3_input' .format (inputs ))
369
-
370
- channels = []
371
- for channel_name , channel_s3_input in input_dict .items ():
372
- channel_config = channel_s3_input .config .copy ()
373
- channel_config ['ChannelName' ] = channel_name
374
- channels .append (channel_config )
375
- return channels
376
-
377
- @staticmethod
378
- def _format_string_uri_input (input ):
379
- if isinstance (input , str ):
380
- if input .startswith ('s3://' ):
381
- return s3_input (input )
382
- elif input .startswith ('file://' ):
383
- return file_input (input )
384
- else :
385
- raise ValueError ('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
386
- '"file://"' )
387
- elif isinstance (input , s3_input ):
388
- return input
389
- elif isinstance (input , file_input ):
390
- return input
391
- else :
392
- raise ValueError ('Cannot format input {}. Expecting one of str, s3_input, or file_input' .format (input ))
393
-
394
- @staticmethod
395
- def _prepare_output_config (s3_path , kms_key_id ):
396
- config = {'S3OutputPath' : s3_path }
397
- if kms_key_id is not None :
398
- config ['KmsKeyId' ] = kms_key_id
399
- return config
400
-
401
- @staticmethod
402
- def _prepare_resource_config (instance_count , instance_type , volume_size ):
403
- resource_config = {'InstanceCount' : instance_count ,
404
- 'InstanceType' : instance_type ,
405
- 'VolumeSizeInGB' : volume_size }
406
- return resource_config
407
-
408
- @staticmethod
409
- def _prepare_stopping_condition (max_run ):
410
- stop_condition = {'MaxRuntimeInSeconds' : max_run }
411
- return stop_condition
412
-
413
- @property
414
- def name (self ):
415
- return self .job_name
416
-
417
349
def wait (self , logs = True ):
418
350
if logs :
419
351
self .sagemaker_session .logs_for_job (self .job_name , wait = True )
@@ -474,8 +406,7 @@ def train_image(self):
474
406
"""
475
407
Returns the docker image to use for training.
476
408
477
- The fit() method, that does the model training, calls this method to find the image to use
478
- for model training.
409
+ The fit() method, that does the model training, calls this method to find the image to use for model training.
479
410
"""
480
411
return self .image_name
481
412
0 commit comments