Skip to content

Commit d03c125

Browse files
imujjwal96 authored and jesterhazy committed
fix: emit training jobs tags to estimator (#803)
1 parent 7d9affd commit d03c125

File tree

9 files changed

+67
-0
lines changed

9 files changed

+67
-0
lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -320,6 +320,8 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
320320

321321
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
322322
init_params = cls._prepare_init_params_from_job_description(job_details, model_channel_name)
323+
tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=job_details['TrainingJobArn'])['Tags']
324+
init_params.update(tags=tags)
323325

324326
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
325327
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,

tests/unit/test_chainer.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,10 @@
5454
{'ModelName': 'model-2'}]
5555
}
5656

57+
LIST_TAGS_RESULT = {
58+
'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
59+
}
60+
5761

5862
@pytest.fixture()
5963
def sagemaker_session():
@@ -65,6 +69,7 @@ def sagemaker_session():
6569
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
6670
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
6771
session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
72+
session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT)
6873
session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
6974
session.expand_role = Mock(name="expand_role", return_value=ROLE)
7075
return session
@@ -222,6 +227,7 @@ def test_attach_with_additional_hyperparameters(sagemaker_session, chainer_versi
222227
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
223228
'TrainingJobName': 'neo',
224229
'TrainingJobStatus': 'Completed',
230+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
225231
'OutputDataConfig': {'KmsKeyId': '',
226232
'S3OutputPath': 's3://place/output/neo'},
227233
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
@@ -404,6 +410,7 @@ def test_attach(sagemaker_session, chainer_version):
404410
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
405411
'TrainingJobName': 'neo',
406412
'TrainingJobStatus': 'Completed',
413+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
407414
'OutputDataConfig': {'KmsKeyId': '',
408415
'S3OutputPath': 's3://place/output/neo'},
409416
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
@@ -446,6 +453,7 @@ def test_attach_wrong_framework(sagemaker_session):
446453
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
447454
'TrainingJobName': 'neo',
448455
'TrainingJobStatus': 'Completed',
456+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
449457
'OutputDataConfig': {'KmsKeyId': '',
450458
'S3OutputPath': 's3://place/output/neo'},
451459
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
@@ -478,6 +486,7 @@ def test_attach_custom_image(sagemaker_session):
478486
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
479487
'TrainingJobName': 'neo',
480488
'TrainingJobStatus': 'Completed',
489+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
481490
'OutputDataConfig': {'KmsKeyId': '',
482491
'S3OutputPath': 's3://place/output/neo'},
483492
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}

tests/unit/test_estimator.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,7 @@
8080
},
8181
'TrainingJobName': 'neo',
8282
'TrainingJobStatus': 'Completed',
83+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
8384
'OutputDataConfig': {
8485
'KmsKeyId': '',
8586
'S3OutputPath': 's3://place/output/neo'
@@ -111,6 +112,10 @@
111112
{'ModelName': 'model-2'}]
112113
}
113114

115+
LIST_TAGS_RESULT = {
116+
'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
117+
}
118+
114119

115120
class DummyFramework(Framework):
116121
__framework_name__ = 'dummy'
@@ -157,6 +162,7 @@ def sagemaker_session():
157162
return_value=DESCRIBE_TRAINING_JOB_RESULT)
158163
sms.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
159164
sms.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
165+
sms.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT)
160166
return sms
161167

162168

@@ -460,6 +466,7 @@ def test_attach_framework(sagemaker_session):
460466
assert framework_estimator.subnets == ['foo']
461467
assert framework_estimator.security_group_ids == ['bar']
462468
assert framework_estimator.encrypt_inter_container_traffic is False
469+
assert framework_estimator.tags == LIST_TAGS_RESULT['Tags']
463470

464471

465472
def test_attach_without_hyperparameters(sagemaker_session):

tests/unit/test_mxnet.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,10 @@
5454
{'ModelName': 'model-2'}]
5555
}
5656

57+
LIST_TAGS_RESULT = {
58+
'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
59+
}
60+
5761

5862
@pytest.fixture()
5963
def sagemaker_session():
@@ -66,6 +70,7 @@ def sagemaker_session():
6670
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
6771
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
6872
session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
73+
session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT)
6974
session.wait_for_compilation_job = Mock(return_value=describe_compilation)
7075
session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
7176
session.expand_role = Mock(name="expand_role", return_value=ROLE)
@@ -346,6 +351,7 @@ def test_attach(sagemaker_session, mxnet_version):
346351
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
347352
'TrainingJobName': 'neo',
348353
'TrainingJobStatus': 'Completed',
354+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
349355
'OutputDataConfig': {
350356
'KmsKeyId': '',
351357
'S3OutputPath': 's3://place/output/neo'
@@ -369,6 +375,7 @@ def test_attach(sagemaker_session, mxnet_version):
369375
assert estimator.hyperparameters()['training_steps'] == '100'
370376
assert estimator.source_dir == 's3://some/sourcedir.tar.gz'
371377
assert estimator.entry_point == 'iris-dnn-classifier.py'
378+
assert estimator.tags == LIST_TAGS_RESULT['Tags']
372379

373380

374381
def test_attach_old_container(sagemaker_session):
@@ -392,6 +399,7 @@ def test_attach_old_container(sagemaker_session):
392399
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
393400
'TrainingJobName': 'neo',
394401
'TrainingJobStatus': 'Completed',
402+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
395403
'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'},
396404
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
397405
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
@@ -434,6 +442,7 @@ def test_attach_wrong_framework(sagemaker_session):
434442
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
435443
'TrainingJobName': 'neo',
436444
'TrainingJobStatus': 'Completed',
445+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
437446
'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'},
438447
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
439448
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd)
@@ -465,6 +474,7 @@ def test_attach_custom_image(sagemaker_session):
465474
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
466475
'TrainingJobName': 'neo',
467476
'TrainingJobStatus': 'Completed',
477+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
468478
'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'},
469479
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
470480
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job',

tests/unit/test_pytorch.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,10 @@
5252
{'ModelName': 'model-2'}]
5353
}
5454

55+
LIST_TAGS_RESULT = {
56+
'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
57+
}
58+
5559

5660
@pytest.fixture(name='sagemaker_session')
5761
def fixture_sagemaker_session():
@@ -63,6 +67,7 @@ def fixture_sagemaker_session():
6367
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
6468
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
6569
session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
70+
session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT)
6671
session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
6772
session.expand_role = Mock(name="expand_role", return_value=ROLE)
6873
return session
@@ -306,6 +311,7 @@ def test_attach(sagemaker_session, pytorch_version):
306311
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
307312
'TrainingJobName': 'neo',
308313
'TrainingJobStatus': 'Completed',
314+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
309315
'OutputDataConfig': {'KmsKeyId': '',
310316
'S3OutputPath': 's3://place/output/neo'},
311317
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
@@ -348,6 +354,7 @@ def test_attach_wrong_framework(sagemaker_session):
348354
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
349355
'TrainingJobName': 'neo',
350356
'TrainingJobStatus': 'Completed',
357+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
351358
'OutputDataConfig': {'KmsKeyId': '',
352359
'S3OutputPath': 's3://place/output/neo'},
353360
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
@@ -380,6 +387,7 @@ def test_attach_custom_image(sagemaker_session):
380387
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
381388
'TrainingJobName': 'neo',
382389
'TrainingJobStatus': 'Completed',
390+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
383391
'OutputDataConfig': {'KmsKeyId': '',
384392
'S3OutputPath': 's3://place/output/neo'},
385393
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}

tests/unit/test_rl.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,10 @@
4949
{'ModelName': 'model-2'}]
5050
}
5151

52+
LIST_TAGS_RESULT = {
53+
'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
54+
}
55+
5256

5357
@pytest.fixture(name='sagemaker_session')
5458
def fixture_sagemaker_session():
@@ -60,6 +64,7 @@ def fixture_sagemaker_session():
6064
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
6165
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
6266
session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
67+
session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT)
6368
session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
6469
session.expand_role = Mock(name="expand_role", return_value=ROLE)
6570
return session
@@ -360,6 +365,7 @@ def test_attach(sagemaker_session, rl_coach_mxnet_version):
360365
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
361366
'TrainingJobName': 'neo',
362367
'TrainingJobStatus': 'Completed',
368+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
363369
'OutputDataConfig': {'KmsKeyId': '',
364370
'S3OutputPath': 's3://place/output/neo'},
365371
'TrainingJobOutput': {
@@ -406,6 +412,7 @@ def test_attach_wrong_framework(sagemaker_session):
406412
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
407413
'TrainingJobName': 'neo',
408414
'TrainingJobStatus': 'Completed',
415+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
409416
'OutputDataConfig': {'KmsKeyId': '',
410417
'S3OutputPath': 's3://place/output/neo'},
411418
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
@@ -438,6 +445,7 @@ def test_attach_custom_image(sagemaker_session):
438445
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
439446
'TrainingJobName': 'neo',
440447
'TrainingJobStatus': 'Completed',
448+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
441449
'OutputDataConfig':
442450
{'KmsKeyId': '',
443451
'S3OutputPath': 's3://place/output/neo'},

tests/unit/test_sklearn.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,10 @@
5252
{'ModelName': 'model-2'}]
5353
}
5454

55+
LIST_TAGS_RESULT = {
56+
'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
57+
}
58+
5559

5660
@pytest.fixture()
5761
def sagemaker_session():
@@ -63,6 +67,7 @@ def sagemaker_session():
6367
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
6468
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
6569
session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
70+
session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT)
6671
session.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
6772
session.expand_role = Mock(name="expand_role", return_value=ROLE)
6873
return session
@@ -311,6 +316,7 @@ def test_attach(sagemaker_session, sklearn_version):
311316
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
312317
'TrainingJobName': 'neo',
313318
'TrainingJobStatus': 'Completed',
319+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
314320
'OutputDataConfig': {'KmsKeyId': '',
315321
'S3OutputPath': 's3://place/output/neo'},
316322
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
@@ -353,6 +359,7 @@ def test_attach_wrong_framework(sagemaker_session):
353359
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
354360
'TrainingJobName': 'neo',
355361
'TrainingJobStatus': 'Completed',
362+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
356363
'OutputDataConfig': {'KmsKeyId': '',
357364
'S3OutputPath': 's3://place/output/neo'},
358365
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
@@ -385,6 +392,7 @@ def test_attach_custom_image(sagemaker_session):
385392
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
386393
'TrainingJobName': 'neo',
387394
'TrainingJobStatus': 'Completed',
395+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
388396
'OutputDataConfig': {'KmsKeyId': '',
389397
'S3OutputPath': 's3://place/output/neo'},
390398
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}

tests/unit/test_tf_estimator.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,10 @@
5757
{'ModelName': 'model-2'}]
5858
}
5959

60+
LIST_TAGS_RESULT = {
61+
'Tags': [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
62+
}
63+
6064

6165
@pytest.fixture()
6266
def sagemaker_session():
@@ -69,6 +73,7 @@ def sagemaker_session():
6973
session.sagemaker_client.describe_training_job = Mock(return_value=describe)
7074
session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
7175
session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
76+
session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT)
7277
return session
7378

7479

@@ -496,6 +501,7 @@ def test_attach(sagemaker_session, tf_version):
496501
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
497502
'TrainingJobName': 'neo',
498503
'TrainingJobStatus': 'Completed',
504+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
499505
'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'},
500506
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
501507
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd)
@@ -544,6 +550,7 @@ def test_attach_new_repo_name(sagemaker_session, tf_version):
544550
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
545551
'TrainingJobName': 'neo',
546552
'TrainingJobStatus': 'Completed',
553+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
547554
'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'},
548555
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
549556
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd)
@@ -593,6 +600,7 @@ def test_attach_old_container(sagemaker_session):
593600
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
594601
'TrainingJobName': 'neo',
595602
'TrainingJobStatus': 'Completed',
603+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
596604
'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'},
597605
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
598606
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd)
@@ -642,6 +650,7 @@ def test_attach_wrong_framework(sagemaker_session):
642650
},
643651
'TrainingJobName': 'neo',
644652
'TrainingJobStatus': 'Completed',
653+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
645654
'OutputDataConfig': {
646655
'KmsKeyId': '',
647656
'S3OutputPath': 's3://place/output/neo'
@@ -681,6 +690,7 @@ def test_attach_custom_image(sagemaker_session):
681690
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
682691
'TrainingJobName': 'neo',
683692
'TrainingJobStatus': 'Completed',
693+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
684694
'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'},
685695
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
686696
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd)
@@ -856,6 +866,7 @@ def test_tf_script_mode_attach(sagemaker_session, tf_version):
856866
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
857867
'TrainingJobName': 'neo',
858868
'TrainingJobStatus': 'Completed',
869+
'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/neo',
859870
'OutputDataConfig': {'KmsKeyId': '', 'S3OutputPath': 's3://place/output/neo'},
860871
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
861872
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd)

0 commit comments

Comments
 (0)