Skip to content

Commit 1b403e4

Browse files
committed
feature: handler for stopping transform job
1 parent d50b24f commit 1b403e4

File tree

4 files changed

+90
-16
lines changed

4 files changed

+90
-16
lines changed

src/sagemaker/session.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,27 @@ def wait_for_transform_job(self, job, poll=5):
10071007
self._check_job_status(job, desc, "TransformJobStatus")
10081008
return desc
10091009

1010+
def stop_transform_job(self, name):
1011+
"""Stop the Amazon SageMaker hyperparameter tuning job with the specified name.
1012+
1013+
Args:
1014+
name (str): Name of the Amazon SageMaker batch transform job.
1015+
1016+
Raises:
1017+
ClientError: If an error occurs while trying to stop the batch transform job.
1018+
"""
1019+
try:
1020+
LOGGER.info('Stopping transform job: {}'.format(name))
1021+
self.sagemaker_client.stop_transform_job(TransformJobName=name)
1022+
except ClientError as e:
1023+
error_code = e.response['Error']['Code']
1024+
# allow to pass if the job already stopped
1025+
if error_code == 'ValidationException':
1026+
LOGGER.info('Transform job: {} is already stopped or not running.'.format(name))
1027+
else:
1028+
LOGGER.error('Error occurred while attempting to stop transform job: {}.'.format(name))
1029+
raise
1030+
10101031
def _check_job_status(self, job, desc, status_key_name):
10111032
"""Check to see if the job completed successfully and, if not, construct and
10121033
raise a ValueError.

src/sagemaker/transformer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,12 @@ def wait(self):
204204
self._ensure_last_transform_job()
205205
self.latest_transform_job.wait()
206206

207+
def stop_transform_job(self):
208+
"""Stop latest running batch transform job.
209+
"""
210+
self._ensure_last_transform_job()
211+
self.latest_transform_job.stop()
212+
207213
def _ensure_last_transform_job(self):
208214
if self.latest_transform_job is None:
209215
raise ValueError("No transform job available")
@@ -303,6 +309,9 @@ def start_new(
303309
def wait(self):
304310
self.sagemaker_session.wait_for_transform_job(self.job_name)
305311

312+
def stop(self):
313+
self.sagemaker_session.stop_transform_job(name=self.job_name)
314+
306315
@staticmethod
307316
def _load_config(data, data_type, content_type, compression_type, split_type, transformer):
308317
input_config = _TransformJob._format_inputs_to_input_config(

tests/integ/test_transformer.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
import pickle
1818
import sys
19+
import time
1920

2021
import pytest
2122

@@ -301,20 +302,48 @@ def test_transform_byo_estimator(sagemaker_session):
301302
assert tags == model_tags
302303

303304

304-
def _create_transformer_and_transform_job(
305-
estimator,
306-
transform_input,
307-
volume_kms_key=None,
308-
input_filter=None,
309-
output_filter=None,
310-
join_source=None,
311-
):
312-
transformer = estimator.transformer(1, "ml.m4.xlarge", volume_kms_key=volume_kms_key)
313-
transformer.transform(
314-
transform_input,
315-
content_type="text/csv",
316-
input_filter=input_filter,
317-
output_filter=output_filter,
318-
join_source=join_source,
319-
)
305+
def test_stop_transform_job(sagemaker_session, mxnet_full_version):
306+
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
307+
script_path = os.path.join(data_path, 'mnist.py')
308+
tags = [{'Key': 'some-tag', 'Value': 'value-for-tag'}]
309+
310+
mx = MXNet(entry_point=script_path, role='SageMakerRole', train_instance_count=1,
311+
train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session,
312+
framework_version=mxnet_full_version)
313+
314+
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
315+
key_prefix='integ-test-data/mxnet_mnist/train')
316+
test_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'test'),
317+
key_prefix='integ-test-data/mxnet_mnist/test')
318+
job_name = unique_name_from_base('test-mxnet-transform')
319+
320+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
321+
mx.fit({'train': train_input, 'test': test_input}, job_name=job_name)
322+
323+
transform_input_path = os.path.join(data_path, 'transform', 'data.csv')
324+
transform_input_key_prefix = 'integ-test-data/mxnet_mnist/transform'
325+
transform_input = mx.sagemaker_session.upload_data(path=transform_input_path,
326+
key_prefix=transform_input_key_prefix)
327+
328+
transformer = mx.transformer(1, 'ml.m4.xlarge', tags=tags)
329+
transformer.transform(transform_input, content_type='text/csv')
330+
331+
time.sleep(15)
332+
333+
latest_transform_job_name = transformer.latest_transform_job.name
334+
335+
print('Attempting to stop {}'.format(latest_transform_job_name))
336+
337+
transformer.stop_transform_job()
338+
339+
desc = transformer.latest_transform_job.sagemaker_session.sagemaker_client \
340+
.describe_transform_job(TransformJobName=latest_transform_job_name)
341+
assert desc['TransformJobStatus'] == 'Stopping'
342+
343+
344+
def _create_transformer_and_transform_job(estimator, transform_input, volume_kms_key=None,
345+
input_filter=None, output_filter=None, join_source=None):
346+
transformer = estimator.transformer(1, 'ml.m4.xlarge', volume_kms_key=volume_kms_key)
347+
transformer.transform(transform_input, content_type='text/csv', input_filter=input_filter,
348+
output_filter=output_filter, join_source=join_source)
320349
return transformer

tests/unit/test_transformer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,18 @@ def test_transform_job_wait(sagemaker_session):
437437
job.wait()
438438

439439
assert sagemaker_session.wait_for_transform_job.called_once
440+
441+
442+
def test_stop_transform_job(sagemaker_session, transformer):
443+
sagemaker_session.stop_transform_job = Mock(name='stop_transform_job')
444+
transformer.latest_transform_job = _TransformJob(sagemaker_session, JOB_NAME)
445+
446+
transformer.stop_transform_job()
447+
448+
sagemaker_session.stop_transform_job.assert_called_once_with(name=JOB_NAME)
449+
450+
451+
def test_stop_transform_job_no_transform_job(transformer):
452+
with pytest.raises(ValueError) as e:
453+
transformer.stop_transform_job()
454+
assert 'No transform job available' in str(e)

0 commit comments

Comments
 (0)