Skip to content

Commit 8e6c4f2

Browse files
authored
Fix Keras test (#132)
This test is only configured to run with 'local'. Change it to use the correct instance type accordingly.
1 parent 962f15b commit 8e6c4f2

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

test/integration/local/test_keras.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717

1818
import numpy as np
19+
import pytest
1920
from sagemaker.tensorflow import serving, TensorFlow
2021

2122
from test.integration import RESOURCE_PATH
@@ -24,15 +25,21 @@
2425
logging.basicConfig(level=logging.DEBUG)
2526

2627

27-
def test_keras_training(sagemaker_local_session, docker_image, tmpdir):
28+
@pytest.fixture
29+
def local_mode_instance_type(processor):
30+
instance_type = 'local' if processor == 'cpu' else 'local_gpu'
31+
return instance_type
32+
33+
34+
def test_keras_training(sagemaker_local_session, docker_image, tmpdir, local_mode_instance_type):
2835
entry_point = os.path.join(RESOURCE_PATH, 'keras_inception.py')
2936
output_path = 'file://{}'.format(tmpdir)
3037

3138
estimator = TensorFlow(
3239
entry_point=entry_point,
3340
role='SageMakerRole',
3441
train_instance_count=1,
35-
train_instance_type='local',
42+
train_instance_type=local_mode_instance_type,
3643
image_name=docker_image,
3744
sagemaker_session=sagemaker_local_session,
3845
model_dir='/opt/ml/model',

0 commit comments

Comments
 (0)