Skip to content

Commit 8feff4a

Browse files
committed
Add integration test for transformer tags
1 parent 87904c4 commit 8feff4a

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

tests/integ/test_transformer.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,40 @@ def test_transform_mxnet_vpc(sagemaker_session, mxnet_full_version):
148148
assert [security_group_id] == model_desc['VpcConfig']['SecurityGroupIds']
149149

150150

151+
def test_transform_mxnet_logs(sagemaker_session, mxnet_full_version):
152+
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
153+
script_path = os.path.join(data_path, 'mnist.py')
154+
tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
155+
156+
mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1,
157+
train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session,
158+
framework_version=mxnet_full_version)
159+
160+
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
161+
key_prefix='integ-test-data/mxnet_mnist/train')
162+
test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
163+
key_prefix='integ-test-data/mxnet_mnist/test')
164+
job_name = unique_name_from_base('test-mxnet-transform')
165+
166+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
167+
mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)
168+
169+
transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
170+
transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
171+
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
172+
key_prefix=transform_input_key_prefix)
173+
174+
transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags)
175+
transformer.transform(transform_input, content_type='text/csv')
176+
177+
with timeout_and_delete_model_with_transformer(transformer, sagemaker_session,
178+
minutes=TRANSFORM_DEFAULT_TIMEOUT_MINUTES):
179+
transformer.wait()
180+
model_desc = sagemaker_session.sagemaker_client.describe_model(ModelName=transformer.model_name)
181+
model_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=model_desc['ModelArn'])['Tags']
182+
assert tags == model_tags
183+
184+
151185
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
152186
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
153187
transformer.transform(transform_input, content_type='text/csv')

0 commit comments

Comments
 (0)