Skip to content

Commit 9f55759

Browse files
imujjwal96 authored and pengk19 committed
feature: emit estimator transformer tags to model (aws#815)
1 parent b7495cc commit 9f55759

File tree

5 files changed: +113 additions, -10 deletions

src/sagemaker/estimator.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -500,15 +500,16 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
500500
volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
501501
compute instance (default: None).
502502
"""
503+
tags = tags or self.tags
504+
503505
if self.latest_training_job is not None:
504-
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role)
506+
model_name = self.sagemaker_session.create_model_from_job(self.latest_training_job.name, role=role,
507+
tags=tags)
505508
else:
506509
logging.warning('No finished training job found associated with this estimator. Please make sure'
507510
'this estimator is only used for building workflow config')
508511
model_name = self._current_job_name
509512

510-
tags = tags or self.tags
511-
512513
return Transformer(model_name, instance_count, instance_type, strategy=strategy, assemble_with=assemble_with,
513514
output_path=output_path, output_kms_key=output_kms_key, accept=accept,
514515
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
@@ -1061,7 +1062,8 @@ def transformer(self, instance_count, instance_type, strategy=None, assemble_wit
10611062
container_def = model.prepare_container_def(instance_type)
10621063
model_name = model.name or name_from_image(container_def['Image'])
10631064
vpc_config = model.vpc_config
1064-
self.sagemaker_session.create_model(model_name, role, container_def, vpc_config)
1065+
tags = tags or self.tags
1066+
self.sagemaker_session.create_model(model_name, role, container_def, vpc_config, tags=tags)
10651067
transform_env = model.env.copy()
10661068
if env is not None:
10671069
transform_env.update(env)

src/sagemaker/session.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -624,7 +624,8 @@ def create_model(self, name, role, container_defs, vpc_config=None,
624624
return name
625625

626626
def create_model_from_job(self, training_job_name, name=None, role=None, primary_container_image=None,
627-
model_data_url=None, env=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT):
627+
model_data_url=None, env=None, vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
628+
tags=None):
628629
"""Create an Amazon SageMaker ``Model`` from a SageMaker Training Job.
629630
630631
Args:
@@ -642,6 +643,8 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary
642643
Default: use VpcConfig from training job.
643644
* 'Subnets' (list[str]): List of subnet ids.
644645
* 'SecurityGroupIds' (list[str]): List of security group ids.
646+
tags(List[dict[str, str]]): Optional. The list of tags to add to the model. For more, see
647+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
645648
646649
Returns:
647650
str: The name of the created ``Model``.
@@ -655,7 +658,7 @@ def create_model_from_job(self, training_job_name, name=None, role=None, primary
655658
model_data_url=model_data_url or training_job['ModelArtifacts']['S3ModelArtifacts'],
656659
env=env)
657660
vpc_config = _vpc_config_from_training_job(training_job, vpc_config_override)
658-
return self.create_model(name, role, primary_container, vpc_config=vpc_config)
661+
return self.create_model(name, role, primary_container, vpc_config=vpc_config, tags=tags)
659662

660663
def create_model_package_from_algorithm(self, name, description, algorithm_arn, model_data):
661664
"""Create a SageMaker Model Package from the results of training with an Algorithm Package

tests/integ/test_transformer.py

Lines changed: 84 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@
2222
from sagemaker import KMeans
2323
from sagemaker.mxnet import MXNet
2424
from sagemaker.transformer import Transformer
25+
from sagemaker.estimator import Estimator
2526
from sagemaker.utils import unique_name_from_base
2627
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, TRANSFORM_DEFAULT_TIMEOUT_MINUTES
2728
from tests.integ.kms_utils import get_or_create_kms_key
@@ -148,6 +149,89 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
148149
assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']
149150

150151

152+
def test_transform_mxnet_tags(sagemaker_session, mxnet_full_version):
153+
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
154+
script_path = os.path.join(data_path, 'mnist.py')
155+
tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
156+
157+
mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1,
158+
train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session,
159+
framework_version=mxnet_full_version)
160+
161+
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
162+
key_prefix='integ-test-data/mxnet_mnist/train')
163+
test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
164+
key_prefix='integ-test-data/mxnet_mnist/test')
165+
job_name = unique_name_from_base('test-mxnet-transform')
166+
167+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
168+
mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)
169+
170+
transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
171+
transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
172+
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
173+
key_prefix=transform_input_key_prefix)
174+
175+
transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags)
176+
transformer.transform(transform_input, content_type='text/csv')
177+
178+
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
179+
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
180+
transformer.wait()
181+
model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
182+
model_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags']
183+
assert tags == model_tags
184+
185+
186+
def test_transform_byo_estimator(sagemaker_session):
187+
data_path = os.path.join(DATA_DIR, 'one_p_mnist')
188+
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
189+
tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
190+
191+
# Load the data into memory as numpy arrays
192+
train_set_path = os.path.join(data_path, 'mnist.pkl.gz')
193+
with gzip.open(train_set_path, 'rb') as f:
194+
train_set, _, _ = pickle.load(f, **pickle_args)
195+
196+
kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
197+
train_instance_type='ml.c4.xlarge', k=10, sagemaker_session=sagemaker_session,
198+
output_path='s3://{}/'.format(sagemaker_session.default_bucket()))
199+
200+
# set kmeans specific hp
201+
kmeans.init_method = 'random'
202+
kmeans.max_iterators = 1
203+
kmeans.tol = 1
204+
kmeans.num_trials = 1
205+
kmeans.local_init_method = 'kmeans++'
206+
kmeans.half_life_time_size = 1
207+
kmeans.epochs = 1
208+
209+
records = kmeans.record_set(train_set[0][:100])
210+
211+
job_name = unique_name_from_base('test-kmeans-attach')
212+
213+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
214+
kmeans.fit(records, job_name=job_name)
215+
216+
transform_input_path = os.path.join(data_path, 'transform_input.csv')
217+
transform_input_key_prefix = 'integ-test-data/one_p_mnist/transform'
218+
transform_input = kmeans.sagemaker_session.upload_data(path=transform_input_path,
219+
key_prefix=transform_input_key_prefix)
220+
221+
estimator = Estimator.attach(training_job_name=job_name,
222+
sagemaker_session=sagemaker_session)
223+
224+
transformer = estimator.transformer(1, 'ml.m4.xlarge', tags=tags)
225+
transformer.transform(transform_input, content_type='text/csv')
226+
227+
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
228+
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
229+
transformer.wait()
230+
model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
231+
model_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags']
232+
assert tags == model_tags
233+
234+
151235
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
152236
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
153237
transformer.transform(transform_input, content_type='text/csv')

tests/unit/test_estimator.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -624,7 +624,7 @@ def test_framework_transformer_creation(name_from_image, sagemaker_session):
624624
transformer = fw.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
625625

626626
name_from_image.assert_called_with(MODEL_IMAGE)
627-
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF, None)
627+
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, ROLE, MODEL_CONTAINER_DEF, None, tags=None)
628628

629629
assert isinstance(transformer, Transformer)
630630
assert transformer.sagemaker_session == sagemaker_session
@@ -659,7 +659,7 @@ def test_framework_transformer_creation_with_optional_params(name_from_image, sa
659659
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
660660
volume_kms_key=kms_key, env=env, role=new_role, model_server_workers=1)
661661

662-
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF, vpc_config)
662+
sagemaker_session.create_model.assert_called_with(MODEL_IMAGE, new_role, MODEL_CONTAINER_DEF, vpc_config, tags=TAGS)
663663
assert transformer.strategy == strategy
664664
assert transformer.assemble_with == assemble_with
665665
assert transformer.output_path == OUTPUT_PATH
@@ -698,7 +698,7 @@ def test_estimator_transformer_creation(sagemaker_session):
698698

699699
transformer = estimator.transformer(INSTANCE_COUNT, INSTANCE_TYPE)
700700

701-
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None)
701+
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=None, tags=None)
702702
assert isinstance(transformer, Transformer)
703703
assert transformer.sagemaker_session == sagemaker_session
704704
assert transformer.instance_count == INSTANCE_COUNT
@@ -728,7 +728,7 @@ def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
728728
max_concurrent_transforms=max_concurrent_transforms, max_payload=max_payload,
729729
env=env, role=ROLE)
730730

731-
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE)
731+
sagemaker_session.create_model_from_job.assert_called_with(JOB_NAME, role=ROLE, tags=TAGS)
732732
assert transformer.strategy == strategy
733733
assert transformer.assemble_with == assemble_with
734734
assert transformer.output_path == OUTPUT_PATH

tests/unit/test_session.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -291,6 +291,7 @@ def test_s3_input_all_arguments():
291291
}
292292

293293
COMPLETED_DESCRIBE_JOB_RESULT = dict(DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
294+
COMPLETED_DESCRIBE_JOB_RESULT.update({'TrainingJobArn': 'arn:aws:sagemaker:us-west-2:336:training-job/' + JOB_NAME})
294295
COMPLETED_DESCRIBE_JOB_RESULT.update({'TrainingJobStatus': 'Completed'})
295296
COMPLETED_DESCRIBE_JOB_RESULT.update(
296297
{'ModelArtifacts': {
@@ -870,6 +871,19 @@ def test_create_model_from_job(sagemaker_session):
870871
VpcConfig=VPC_CONFIG)
871872

872873

874+
def test_create_model_from_job_with_tags(sagemaker_session):
875+
ims = sagemaker_session
876+
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
877+
ims.create_model_from_job(JOB_NAME, tags=TAGS)
878+
879+
assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
880+
ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
881+
ModelName=JOB_NAME,
882+
PrimaryContainer=PRIMARY_CONTAINER,
883+
VpcConfig=VPC_CONFIG,
884+
Tags=TAGS)
885+
886+
873887
def test_create_model_from_job_with_image(sagemaker_session):
874888
ims = sagemaker_session
875889
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT

0 commit comments

Comments
 (0)