Skip to content

Commit dd9851f

Browse files
author
Ignacio Quintero
committed
Fix integ test + address some comments
1 parent 94c7ead commit dd9851f

File tree

4 files changed

+7
-8
lines changed

4 files changed

+7
-8
lines changed

src/sagemaker/local/entities.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,18 +285,18 @@ def _perform_batch_inference(self, input_data, output_data, **kwargs):
285285
copy_directory_structure(working_dir, relative_path)
286286
destination_path = os.path.join(working_dir, relative_path, filename + '.out')
287287

288-
with open(destination_path, 'w') as f:
288+
with open(destination_path, 'wb') as f:
289289
for item in batch_provider.pad(file, max_payload):
290290
# call the container and add the result to inference.
291291
response = self.local_session.sagemaker_runtime_client.invoke_endpoint(
292292
item, '', input_data['ContentType'], accept)
293293

294294
response_body = response['Body']
295-
data = response_body.read()
295+
data = response_body.read().strip()
296296
response_body.close()
297297
f.write(data)
298298
if 'AssembleWith' in output_data and output_data['AssembleWith'] == 'Line':
299-
f.write('\n')
299+
f.write(b'\n')
300300

301301
move_to_destination(working_dir, output_data['S3OutputPath'], self.local_session)
302302
self.container.stop_serving()

src/sagemaker/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(self, model_name, instance_count, instance_type, strategy=None, ass
3131
instance_count (int): Number of EC2 instances to use.
3232
instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
3333
strategy (str): The strategy used to decide how to batch records in a single request (default: None).
34-
Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
34+
Valid values: 'MultiRecord' and 'SingleRecord'.
3535
assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
3636
output_path (str): S3 location for saving the transform result. If not specified, results are stored to
3737
a default bucket.

tests/integ/test_local_mode.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,8 @@ def test_local_transform_mxnet(sagemaker_local_session, tmpdir):
390390
key_prefix=transform_input_key_prefix)
391391

392392
output_path = 'file://%s' % (str(tmpdir))
393-
transformer = mx.transformer(1, 'local', assemble_with='Line', max_payload=1, output_path=output_path)
393+
transformer = mx.transformer(1, 'local', assemble_with='Line', max_payload=1,
394+
strategy='SingleRecord', output_path=output_path)
394395
transformer.transform(transform_input, content_type='text/csv', split_type='Line')
395396
transformer.wait()
396397

tests/unit/test_local_data.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,7 @@ def test_get_splitter_instance_with_valid_types():
103103

104104
def test_get_splitter_instance_with_invalid_types():
105105
with pytest.raises(ValueError):
106-
# something invalid
107-
sagemaker.local.data.get_splitter_instance('JSON')
106+
sagemaker.local.data.get_splitter_instance('SomethingInvalid')
108107

109108

110109
def test_none_splitter(tmpdir):
@@ -133,7 +132,6 @@ def test_line_splitter(tmpdir):
133132

134133

135134
def test_recordio_splitter(tmpdir):
136-
137135
test_file_path = tmpdir.join('recordio_test.txt')
138136
with test_file_path.open('wb') as f:
139137
for i in range(10):

0 commit comments

Comments
 (0)