Skip to content

Commit a8603f6

Browse files
author
Ignacio Quintero
committed
Add docker-compose and unit tests.
docker-compose was not a dependency of the SDk but it is required for local mode.
1 parent a2eec9e commit a8603f6

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def read(fname):
4444
],
4545

4646
# Declare minimal set for installation
47-
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0', 'urllib3>=1.2',
48-
'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5'],
47+
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0', 'urllib3 >=1.21, <1.23',
48+
'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5', 'docker-compose>=1.21.0'],
4949

5050
extras_require={
5151
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist',

tests/unit/test_image.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,27 @@ def test_train_local_code(download_folder, _cleanup, popen, _stream_output,
334334
assert '%s:/opt/ml/shared' % shared_folder_path in volumes
335335

336336

337+
def test_container_has_gpu_support(tmpdir, sagemaker_session):
338+
instance_count = 1
339+
image = 'my-image'
340+
sagemaker_container = _SageMakerContainer('local_gpu', instance_count, image,
341+
sagemaker_session=sagemaker_session)
342+
343+
docker_host = sagemaker_container._create_docker_host('host-1', {}, set(), 'train', [])
344+
assert 'runtime' in docker_host
345+
assert docker_host['runtime'] == 'nvidia'
346+
347+
348+
def test_container_does_not_enable_nvidia_docker_for_cpu_containers(tmpdir, sagemaker_session):
349+
instance_count = 1
350+
image = 'my-image'
351+
sagemaker_container = _SageMakerContainer('local', instance_count, image,
352+
sagemaker_session=sagemaker_session)
353+
354+
docker_host = sagemaker_container._create_docker_host('host-1', {}, set(), 'train', [])
355+
assert 'runtime' not in docker_host
356+
357+
337358
@patch('sagemaker.local.image._HostingContainer.run')
338359
@patch('shutil.copy')
339360
@patch('shutil.copytree')

0 commit comments

Comments
 (0)