Skip to content

Commit 13a7027

Browse files
author
Ignacio Quintero
committed
Add docker-compose and unit tests.
docker-compose was not a dependency of the SDk but it is required for local mode.
1 parent a2eec9e commit 13a7027

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def read(fname):
4444
],
4545

4646
# Declare minimal set for installation
47-
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0', 'urllib3>=1.2',
48-
'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5'],
47+
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0', 'urllib3 >=1.21, <1.23',
48+
'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5', 'docker-compose>=1.21.0'],
4949

5050
extras_require={
5151
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist',

tests/unit/test_image.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,28 @@ def test_train_local_code(download_folder, _cleanup, popen, _stream_output,
334334
assert '%s:/opt/ml/shared' % shared_folder_path in volumes
335335

336336

337+
def test_container_has_gpu_support(tmpdir, sagemaker_session):
338+
instance_count = 1
339+
image = 'my-image'
340+
directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
341+
sagemaker_container = _SageMakerContainer('local_gpu', instance_count, image,
342+
sagemaker_session=sagemaker_session)
343+
344+
docker_host = sagemaker_container._create_docker_host('host-1', {}, set(), 'train', [])
345+
assert 'runtime' in docker_host
346+
assert docker_host['runtime'] == 'nvidia'
347+
348+
349+
def test_container_does_not_enable_nvidia_docker_for_cpu_containers(tmpdir, sagemaker_session):
350+
instance_count = 1
351+
image = 'my-image'
352+
sagemaker_container = _SageMakerContainer('local', instance_count, image,
353+
sagemaker_session=sagemaker_session)
354+
355+
docker_host = sagemaker_container._create_docker_host('host-1', {}, set(), 'train', [])
356+
assert 'runtime' not in docker_host
357+
358+
337359
@patch('sagemaker.local.image._HostingContainer.run')
338360
@patch('shutil.copy')
339361
@patch('shutil.copytree')

0 commit comments

Comments
 (0)