Skip to content

Commit 17d14d1

Browse files
author
Ignacio Quintero
committed
Cleanup tests
1 parent e3f6ab5 commit 17d14d1

File tree

3 files changed

+3
-54
lines changed

3 files changed

+3
-54
lines changed

tests/unit/test_chainer.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,6 @@ def test_create_model(sagemaker_session, chainer_version):
235235
job_name = 'new_name'
236236
chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
237237
model = chainer.create_model()
238-
chainer.container_log_level
239238

240239
assert model.sagemaker_session == sagemaker_session
241240
assert model.framework_version == chainer_version
@@ -259,19 +258,10 @@ def test_create_model_with_custom_image(sagemaker_session):
259258
py_version=PYTHON_VERSION, base_job_name='job', source_dir=source_dir,
260259
enable_cloudwatch_metrics=enable_cloudwatch_metrics)
261260

262-
job_name = 'new_name'
263261
chainer.fit(inputs='s3://mybucket/train', job_name='new_name')
264262
model = chainer.create_model()
265-
chainer.container_log_level
266263

267-
assert model.sagemaker_session == sagemaker_session
268264
assert model.image == custom_image
269-
assert model.entry_point == SCRIPT_PATH
270-
assert model.role == ROLE
271-
assert model.name == job_name
272-
assert model.container_log_level == container_log_level
273-
assert model.source_dir == source_dir
274-
assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics
275265

276266

277267
@patch('time.strftime', return_value=TIMESTAMP)
@@ -450,15 +440,5 @@ def test_attach_custom_image(sagemaker_session):
450440
return_value=returned_job_description)
451441

452442
estimator = Chainer.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
453-
assert estimator.latest_training_job.job_name == 'neo'
454-
assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole'
455-
assert estimator.train_instance_count == 1
456-
assert estimator.train_max_run == 24 * 60 * 60
457-
assert estimator.input_mode == 'File'
458-
assert estimator.base_job_name == 'neo'
459-
assert estimator.output_path == 's3://place/output/neo'
460-
assert estimator.output_kms_key == ''
461-
assert estimator.hyperparameters()['training_steps'] == '100'
462-
assert estimator.source_dir == 's3://some/sourcedir.tar.gz'
463-
assert estimator.entry_point == 'iris-dnn-classifier.py'
443+
assert estimator.image_name == training_image
464444
assert estimator.train_image() == training_image

tests/unit/test_mxnet.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -350,15 +350,5 @@ def test_attach_custom_image(sagemaker_session):
350350
return_value=returned_job_description)
351351

352352
estimator = MXNet.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
353-
assert estimator.latest_training_job.job_name == 'neo'
354353
assert estimator.image_name == training_image
355-
assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole'
356-
assert estimator.train_instance_count == 1
357-
assert estimator.train_max_run == 24 * 60 * 60
358-
assert estimator.input_mode == 'File'
359-
assert estimator.base_job_name == 'neo'
360-
assert estimator.output_path == 's3://place/output/neo'
361-
assert estimator.output_kms_key == ''
362-
assert estimator.hyperparameters()['training_steps'] == '100'
363-
assert estimator.source_dir == 's3://some/sourcedir.tar.gz'
364-
assert estimator.entry_point == 'iris-dnn-classifier.py'
354+
assert estimator.train_image() == training_image

tests/unit/test_tf_estimator.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -220,14 +220,7 @@ def test_create_model_with_custom_image(sagemaker_session):
220220
tf.fit(inputs='s3://mybucket/train', job_name=job_name)
221221
model = tf.create_model()
222222

223-
assert model.sagemaker_session == sagemaker_session
224223
assert model.image == custom_image
225-
assert model.entry_point == SCRIPT_PATH
226-
assert model.role == ROLE
227-
assert model.name == job_name
228-
assert model.container_log_level == container_log_level
229-
assert model.source_dir == source_dir
230-
assert model.enable_cloudwatch_metrics == enable_cloudwatch_metrics
231224

232225

233226
@patch('time.strftime', return_value=TIMESTAMP)
@@ -612,19 +605,5 @@ def test_attach_custom_image(sagemaker_session):
612605
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd)
613606

614607
estimator = TensorFlow.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
615-
assert estimator.latest_training_job.job_name == 'neo'
616-
assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole'
617-
assert estimator.train_instance_count == 1
618-
assert estimator.train_max_run == 24 * 60 * 60
619-
assert estimator.input_mode == 'File'
620-
assert estimator.training_steps == 100
621-
assert estimator.evaluation_steps == 10
622-
assert estimator.input_mode == 'File'
623-
assert estimator.base_job_name == 'neo'
624-
assert estimator.output_path == 's3://place/output/neo'
625-
assert estimator.output_kms_key == ''
626-
assert estimator.hyperparameters()['training_steps'] == '100'
627-
assert estimator.source_dir == 's3://some/sourcedir.tar.gz'
628-
assert estimator.entry_point == 'iris-dnn-classifier.py'
629-
assert estimator.checkpoint_path == 's3://other/1508872349'
608+
assert estimator.image_name == training_image
630609
assert estimator.train_image() == training_image

0 commit comments

Comments
 (0)