Skip to content

Commit fd8770e

Browse files
committed
feature: handler for stopping transform job
1 parent 686569e commit fd8770e

File tree

4 files changed

+85
-0
lines changed

4 files changed

+85
-0
lines changed

src/sagemaker/session.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,27 @@ def wait_for_transform_job(self, job, poll=5):
902902
self._check_job_status(job, desc, 'TransformJobStatus')
903903
return desc
904904

905+
def stop_transform_job(self, name):
906+
"""Stop the Amazon SageMaker hyperparameter tuning job with the specified name.
907+
908+
Args:
909+
name (str): Name of the Amazon SageMaker batch transform job.
910+
911+
Raises:
912+
ClientError: If an error occurs while trying to stop the batch transform job.
913+
"""
914+
try:
915+
LOGGER.info('Stopping transform job: {}'.format(name))
916+
self.sagemaker_client.stop_transform_job(TransformJobName=name)
917+
except ClientError as e:
918+
error_code = e.response['Error']['Code']
919+
# allow to pass if the job already stopped
920+
if error_code == 'ValidationException':
921+
LOGGER.info('Transform job: {} is already stopped or not running.'.format(name))
922+
else:
923+
LOGGER.error('Error occurred while attempting to stop transform job: {}.'.format(name))
924+
raise
925+
905926
def _check_job_status(self, job, desc, status_key_name):
906927
"""Check to see if the job completed successfully and, if not, construct and
907928
raise a ValueError.

src/sagemaker/transformer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ def wait(self):
156156
self._ensure_last_transform_job()
157157
self.latest_transform_job.wait()
158158

159+
def stop_transform_job(self):
160+
"""Stop latest running batch transform job.
161+
"""
162+
self._ensure_last_transform_job()
163+
self.latest_transform_job.stop()
164+
159165
def _ensure_last_transform_job(self):
160166
if self.latest_transform_job is None:
161167
raise ValueError('No transform job available')
@@ -230,6 +236,9 @@ def start_new(cls, transformer, data, data_type, content_type, compression_type,
230236
def wait(self):
231237
self.sagemaker_session.wait_for_transform_job(self.job_name)
232238

239+
def stop(self):
240+
self.sagemaker_session.stop_transform_job(name=self.job_name)
241+
233242
@staticmethod
234243
def _load_config(data, data_type, content_type, compression_type, split_type, transformer):
235244
input_config = _TransformJob._format_inputs_to_input_config(data, data_type, content_type,

tests/integ/test_transformer.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import pickle
1818
import sys
19+
import time
1920

2021
import pytest
2122

@@ -232,6 +233,45 @@ def test_transform_byo_estimator(sagemaker_session):
232233
assert tags == model_tags
233234

234235

236+
def test_stop_transform_job(sagemaker_session, mxnet_full_version):
237+
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
238+
script_path = os.path.join(data_path, 'mnist.py')
239+
tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
240+
241+
mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1,
242+
train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session,
243+
framework_version=mxnet_full_version)
244+
245+
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
246+
key_prefix='integ-test-data/mxnet_mnist/train')
247+
test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
248+
key_prefix='integ-test-data/mxnet_mnist/test')
249+
job_name = unique_name_from_base('test-mxnet-transform')
250+
251+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
252+
mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)
253+
254+
transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
255+
transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
256+
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
257+
key_prefix=transform_input_key_prefix)
258+
259+
transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags)
260+
transformer.transform(transform_input, content_type='text/csv')
261+
262+
time.sleep(15)
263+
264+
latest_transform_job_name = transformer.latest_transform_job.name
265+
266+
print('Attempting to stop {}'.format(latest_transform_job_name))
267+
268+
transformer.stop_transform_job()
269+
270+
desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client \
271+
.describe_transform_job(TransformJobName=latest_transform_job_name)
272+
assert desc['TransformJobStatus'] == 'Stopping'
273+
274+
235275
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None):
236276
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
237277
transformer.transform(transform_input, content_type='text/csv')

tests/unit/test_transformer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,3 +397,18 @@ def test_transform_job_wait(sagemaker_session):
397397
job.wait()
398398

399399
assert sagemaker_session.wait_for_transform_job.called_once
400+
401+
402+
def test_stop_transform_job(sagemaker_session, transformer):
403+
sagemaker_session.stop_transform_job = Mock(name='stop_transform_job')
404+
transformer.latest_transform_job = _TransformJob(sagemaker_session, JOB_NAME)
405+
406+
transformer.stop_transform_job()
407+
408+
sagemaker_session.stop_transform_job.assert_called_once_with(name=JOB_NAME)
409+
410+
411+
def test_stop_transform_job_no_transform_job(transformer):
412+
with pytest.raises(ValueError) as e:
413+
transformer.stop_transform_job()
414+
assert 'No transform job available' in str(e)

0 commit comments

Comments
 (0)