Skip to content

Commit 0538787

Browse files
committed
feature: emit estimator transformer tags to model
1 parent 04410d8 commit 0538787

File tree

4 files changed

+12
-5
lines changed

4 files changed

+12
-5
lines changed

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
10611061
container_def = model.prepare_container_def(instance_type)
10621062
model_name = model.name or name_from_image(container_def['Image'])
10631063
vpc_config = model.vpc_config
1064-
self.sagemaker_session.create_model(model_name, role, container_def, vpc_config)
1064+
self.sagemaker_session.create_model(model_name, role, container_def, vpc_config, tags=tags)
10651065
transform_env = model.env.copy()
10661066
if env is not None:
10671067
transform_env.update(env)

src/sagemaker/session.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,7 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary
643643
str: The name of the created ``Model``.
644644
"""
645645
training_job = self.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
646+
tags = self.sagemaker_client.list_tags(ResourceArn=training_job['TrainingJobArn'])['Tags']
646647
name = name or training_job_name
647648
role = role or training_job['RoleArn']
648649
env = env or {}
@@ -651,7 +652,7 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary
651652
model_data_url=model_data_url or training_job['ModelArtifacts']['S3ModelArtifacts'],
652653
env=env)
653654
vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override)
654-
return self.create_model(name, role, primary_container, vpc_config=vpc_config)
655+
return self.create_model(name, role, primary_container, vpc_config=vpc_config, tags=tags)
655656

656657
def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data):
657658
"""Create a SageMaker Model Package from the results of training with an Algorithm Package

tests/unit/test_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,7 @@ def test_framework_transformer_creation(name_from_image, sagemaker_session):
624624
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
625625

626626
name_from_image.assert_called_with(MODEL_IMAGE)
627-
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF, None)
627+
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF, None, tags=None)
628628

629629
assert isinstance(transformer, Transformer)
630630
assert transformer.sagemaker_session == sagemaker_session
@@ -659,7 +659,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
659659
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
660660
volume_kms_key=kms_key, env=env, role=new_role, model_server_workers=1)
661661

662-
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF, vpc_config)
662+
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF, vpc_config, tags=TAGS)
663663
assert transformer.strategy == strategy
664664
assert transformer.assemble_with == assemble_with
665665
assert transformer.output_path == OUTPUT_PATH

tests/unit/test_session.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ def test_s3_input_all_arguments():
291291
}
292292

293293
COMPLETED_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
294+
COMPLETED_DESCRIBE_JOB_RESULT.update({'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/' + JOB_NAME})
294295
COMPLETED_DESCRIBE_JOB_RESULT.update({'TrainingJobStatus': 'Completed'})
295296
COMPLETED_DESCRIBE_JOB_RESULT.update(
296297
{'ModelArtifacts': {
@@ -861,18 +862,21 @@ def test_create_model_failure(expand_container_def, sagemaker_session):
861862
def test_create_model_from_job(sagemaker_session):
862863
ims = sagemaker_session
863864
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
865+
ims.sagemaker_client.list_tags.return_value = {'Tags': TAGS}
864866
ims.create_model_from_job(JOB_NAME)
865867

866868
assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
867869
ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
868870
ModelName=JOB_NAME,
869871
PrimaryContainer=PRIMARY_CONTAINER,
870-
VpcConfig=VPC_CONFIG)
872+
VpcConfig=VPC_CONFIG,
873+
Tags=TAGS)
871874

872875

873876
def test_create_model_from_job_with_image(sagemaker_session):
874877
ims = sagemaker_session
875878
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
879+
ims.sagemaker_client.list_tags.return_value = {'Tags': TAGS}
876880
ims.create_model_from_job(JOB_NAME, primary_container_image='some-image')
877881
[create_model_call] = ims.sagemaker_client.create_model.call_args_list
878882
assert dict(create_model_call[1]['PrimaryContainer'])['Image'] == 'some-image'
@@ -881,6 +885,7 @@ def test_create_model_from_job_with_image(sagemaker_session):
881885
def test_create_model_from_job_with_container_def(sagemaker_session):
882886
ims = sagemaker_session
883887
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
888+
ims.sagemaker_client.list_tags.return_value = {'Tags': TAGS}
884889
ims.create_model_from_job(JOB_NAME, primary_container_image='some-image', model_data_url='some-data',
885890
env={'a': 'b'})
886891
[create_model_call] = ims.sagemaker_client.create_model.call_args_list
@@ -895,6 +900,7 @@ def test_create_model_from_job_with_vpc_config_override(sagemaker_session):
895900

896901
ims = sagemaker_session
897902
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
903+
ims.sagemaker_client.list_tags.return_value = {'Tags': TAGS}
898904
ims.create_model_from_job(JOB_NAME, vpc_config_override=vpc_config_override)
899905
assert ims.sagemaker_client.create_model.call_args[1]['VpcConfig'] == vpc_config_override
900906

0 commit comments

Comments
 (0)