15
15
import logging
16
16
import json
17
17
import os
18
+
18
19
import pytest
19
20
from mock import Mock , patch
20
21
39
40
# Fixed region used throughout these tests; generated training-job names
# follow the "<image>-<timestamp>" convention.
REGION = 'us-west-2'
JOB_NAME = '{0}-{1}'.format(IMAGE_NAME, TIMESTAMP)
41
42
42
# Baseline keyword arguments shared by the training-job tests; individual
# tests copy and override these as needed.
COMMON_TRAIN_ARGS = dict(
    volume_size=30,
    hyperparameters={
        'sagemaker_program': 'dummy_script.py',
        'sagemaker_enable_cloudwatch_metrics': False,
        'sagemaker_container_log_level': logging.INFO,
    },
    input_mode='File',
    instance_type='c4.4xlarge',
    inputs='s3://mybucket/train',
    instance_count=1,
    role='DummyRole',
    kms_key_id=None,
    max_run=24,
    wait=True,
)
56
59
57
60
DESCRIBE_TRAINING_JOB_RESULT = {
58
61
'ModelArtifacts' : {
@@ -275,19 +278,6 @@ def test_attach_framework(sagemaker_session):
275
278
assert framework_estimator .entry_point == 'iris-dnn-classifier.py'
276
279
277
280
278
- def test_fit_then_fit_again (sagemaker_session ):
279
- fw = DummyFramework (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
280
- train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE ,
281
- enable_cloudwatch_metrics = True )
282
- fw .fit (inputs = s3_input ('s3://mybucket/train' ))
283
- first_job_name = fw .latest_training_job .name
284
-
285
- fw .fit (inputs = s3_input ('s3://mybucket/train2' ))
286
- second_job_name = fw .latest_training_job .name
287
-
288
- assert first_job_name != second_job_name
289
-
290
-
291
281
@patch ('time.strftime' , return_value = TIMESTAMP )
292
282
def test_fit_verify_job_name (strftime , sagemaker_session ):
293
283
fw = DummyFramework (entry_point = SCRIPT_PATH , role = 'DummyRole' , sagemaker_session = sagemaker_session ,
@@ -304,42 +294,55 @@ def test_fit_verify_job_name(strftime, sagemaker_session):
304
294
assert fw .latest_training_job .name == JOB_NAME
305
295
306
296
307
def test_prepare_for_training_unique_job_name_generation(sagemaker_session):
    """Each prepare_for_training() call must generate a fresh, distinct job name."""
    fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                        train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                        enable_cloudwatch_metrics=True)

    fw.prepare_for_training()
    job_name_one = fw._current_job_name
    fw.prepare_for_training()
    job_name_two = fw._current_job_name

    # A second preparation must not reuse the previously generated name.
    assert job_name_one != job_name_two
310
def test_prepare_for_training_force_name(sagemaker_session):
    """An explicitly supplied job_name must win over base_job_name generation."""
    fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                        train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                        base_job_name='some', enable_cloudwatch_metrics=True)
    fw.prepare_for_training(job_name='use_it')
    assert fw._current_job_name == 'use_it'
313
316
314
317
315
318
@patch('time.strftime', return_value=TIMESTAMP)
def test_prepare_for_training_force_name_generation(strftime, sagemaker_session):
    """With base_job_name cleared, the name falls back to image + timestamp."""
    fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
                        train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                        base_job_name='some', enable_cloudwatch_metrics=True)
    # Clear the forced base name so the default generation path is exercised.
    fw.base_job_name = None
    fw.prepare_for_training()
    assert fw._current_job_name == JOB_NAME
323
326
324
327
325
328
@patch('time.strftime', return_value=TIMESTAMP)
def test_init_with_source_dir_s3(strftime, sagemaker_session):
    """When source_dir is already an S3 URI, it is used verbatim as the submit
    directory (no local packaging/upload), and the framework hyperparameters
    reflect that.
    """
    fw = DummyFramework(entry_point=SCRIPT_PATH, source_dir='s3://location', role=ROLE,
                        sagemaker_session=sagemaker_session,
                        train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
                        enable_cloudwatch_metrics=False)
    fw.prepare_for_training()

    # Bug fix: the original literal defined 'sagemaker_submit_directory' twice;
    # the first value ('s3://mybucket/{}/source/sourcedir.tar.gz') was dead code
    # silently overwritten by the later 's3://location' entry. Only the
    # effective value is kept here — the resulting dict is unchanged.
    expected_hyperparameters = {
        'sagemaker_program': SCRIPT_NAME,
        'sagemaker_submit_directory': 's3://location',
        'sagemaker_job_name': JOB_NAME,
        'sagemaker_enable_cloudwatch_metrics': False,
        'sagemaker_container_log_level': logging.INFO,
        'sagemaker_region': 'us-west-2',
    }
    assert fw._hyperparameters == expected_hyperparameters
343
346
344
347
345
348
# _TrainingJob 'utils'
0 commit comments