
Commit 167ae8e

Add PyTorch hyperparameter tuning integ test (#318)
1 parent 04ead1f commit 167ae8e

2 files changed: +54 -6 lines changed


tests/data/pytorch_mnist/mnist.py

Lines changed: 12 additions & 6 deletions
@@ -39,14 +39,14 @@ def forward(self, x):
         return F.log_softmax(x, dim=1)


-def _get_train_data_loader(training_dir, is_distributed, **kwargs):
+def _get_train_data_loader(training_dir, is_distributed, batch_size, **kwargs):
     logger.info('Get train data loader')
     dataset = datasets.MNIST(training_dir, train=True, transform=transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1307,), (0.3081,))
     ]))
     train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
-    train_loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=train_sampler is None,
+    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=train_sampler is None,
                                                sampler=train_sampler, **kwargs)
     return train_sampler, train_loader

@@ -94,7 +94,7 @@ def train(args):
     if use_cuda:
         torch.cuda.manual_seed(seed)

-    train_sampler, train_loader = _get_train_data_loader(args.data_dir, is_distributed, **kwargs)
+    train_sampler, train_loader = _get_train_data_loader(args.data_dir, is_distributed, args.batch_size, **kwargs)
     test_loader = _get_test_data_loader(args.data_dir, **kwargs)

     logger.debug('Processes {}/{} ({:.0f}%) of train data'.format(
@@ -142,9 +142,11 @@ def train(args):
             logger.debug('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(
                 epoch, batch_idx * len(data), len(train_loader.sampler),
                 100. * batch_idx / len(train_loader), loss.item()))
-    test(model, test_loader, device)
+    accuracy = test(model, test_loader, device)
     save_model(model, args.model_dir)

+    logger.debug('Overall test accuracy: {}'.format(accuracy))
+

 def test(model, test_loader, device):
     model.eval()
@@ -159,9 +161,12 @@ def test(model, test_loader, device):
             correct += pred.eq(target.view_as(pred)).sum().item()

     test_loss /= len(test_loader.dataset)
+    accuracy = 100. * correct / len(test_loader.dataset)
+
     logger.debug('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+        test_loss, correct, len(test_loader.dataset), accuracy))
+
+    return accuracy


 def model_fn(model_dir):
@@ -181,6 +186,7 @@ def save_model(model, model_dir):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--epochs', type=int, default=1, metavar='N')
+    parser.add_argument('--batch-size', type=int, default=64, metavar='N')

     # Container environment
     parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
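
The new '--batch-size' argument is what lets the tuner vary the batch size between training jobs: the sampled hyperparameter reaches the entry point as a command-line flag (hence the argparse change), argparse exposes it as args.batch_size, and train() passes it through to _get_train_data_loader. A minimal sketch of that handoff, using a hand-built argv for illustration (in a real job the argv is assembled by the SageMaker training container):

# Sketch only: simulate the CLI flag a tuning job might pass to mnist.py.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=1, metavar='N')
parser.add_argument('--batch-size', type=int, default=64, metavar='N')

# A job that samples batch-size=87 would be launched roughly as:
#   python mnist.py --batch-size 87 ...
args = parser.parse_args(['--batch-size', '87'])
print(args.batch_size)  # 87, later handed to _get_train_data_loader as batch_size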

tests/integ/test_tuner.py

Lines changed: 42 additions & 0 deletions
@@ -31,6 +31,7 @@
 from sagemaker.estimator import Estimator
 from sagemaker.mxnet.estimator import MXNet
 from sagemaker.predictor import json_deserializer
+from sagemaker.pytorch import PyTorch
 from sagemaker.tensorflow import TensorFlow
 from sagemaker.tuner import IntegerParameter, ContinuousParameter, CategoricalParameter, HyperparameterTuner
 from tests.integ import DATA_DIR
@@ -314,6 +315,47 @@ def test_tuning_chainer(sagemaker_session):
     assert len(output) == batch_size


+@pytest.mark.continuous_testing
+def test_attach_tuning_pytorch(sagemaker_session):
+    mnist_dir = os.path.join(DATA_DIR, 'pytorch_mnist')
+    mnist_script = os.path.join(mnist_dir, 'mnist.py')
+
+    estimator = PyTorch(entry_point=mnist_script, role='SageMakerRole', train_instance_count=1,
+                        train_instance_type='ml.c4.xlarge', sagemaker_session=sagemaker_session)
+
+    with timeout(minutes=15):
+        objective_metric_name = 'evaluation-accuracy'
+        metric_definitions = [{'Name': 'evaluation-accuracy', 'Regex': 'Overall test accuracy: (\d+)'}]
+        hyperparameter_ranges = {'batch-size': IntegerParameter(50, 100)}
+
+        tuner = HyperparameterTuner(estimator, objective_metric_name, hyperparameter_ranges, metric_definitions,
+                                    max_jobs=2, max_parallel_jobs=2)
+
+        training_data = estimator.sagemaker_session.upload_data(path=os.path.join(mnist_dir, 'training'),
+                                                                key_prefix='integ-test-data/pytorch_mnist/training')
+        tuner.fit({'training': training_data})
+
+        tuning_job_name = tuner.latest_tuning_job.name
+
+        print('Started hyperparameter tuning job with name:' + tuning_job_name)
+
+        time.sleep(15)
+        tuner.wait()
+
+    attached_tuner = HyperparameterTuner.attach(tuning_job_name, sagemaker_session=sagemaker_session)
+    best_training_job = tuner.best_training_job()
+    with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session, minutes=20):
+        predictor = attached_tuner.deploy(1, 'ml.c4.xlarge')
+        data = np.zeros(shape=(1, 1, 28, 28), dtype=np.float32)
+        predictor.predict(data)
+
+        batch_size = 100
+        data = np.random.rand(batch_size, 1, 28, 28).astype(np.float32)
+        output = predictor.predict(data)
+
+        assert output.shape == (batch_size, 10)
+
+
 @pytest.mark.continuous_testing
 def test_tuning_byo_estimator(sagemaker_session):
     """Use Factorization Machines algorithm as an example here.
