Skip to content

add chainer notebooks #265

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sagemaker-python-sdk/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ These examples focus on the Amazon SageMaker Python SDK which allows you to writ
- [cifar 10 with MXNet Gluon](mxnet_gluon_cifar10)
- [MNIST with MXNet Gluon](mxnet_gluon_mnist)
- [MNIST with MXNet](mxnet_mnist)
- [CIFAR-10 with Chainer and ChainerMN](chainer_cifar10)
- [Sentiment Analysis with Chainer](chainer_sentiment_analysis)
- [MNIST with Chainer](chainer_mnist)
- [Sentiment Analysis with MXNet Gluon](mxnet_gluon_sentiment)
- [TensorFlow Neural Networks with Layers](tensorflow_abalone_age_predictor_using_layers)
- [TensorFlow Networks with Keras](tensorflow_abalone_age_predictor_using_keras)
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
38 changes: 38 additions & 0 deletions sagemaker-python-sdk/chainer_cifar10/s3_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import boto3
import tarfile
from urllib.parse import urlparse
import os

def retrieve_output_from_s3(s3_url, output_dir):
"""
Downloads output artifacts from s3 and extracts them into the given directory.

Args:
s3_url: S3 URL to the output artifacts
output_dir: directory to write artifacts to
"""
o = urlparse(s3_url)
s3 = boto3.resource('s3')
output_data_path = os.path.join(output_dir)
output_file_name = os.path.join(output_data_path, 'output.tar.gz')
try:
os.makedirs(output_data_path)
except FileExistsError:
pass
s3.Bucket(o.netloc).download_file(o.path.lstrip('/'), output_file_name)
tar = tarfile.open(output_file_name)
tar.extractall(output_data_path)
tar.close()
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import print_function, absolute_import

import argparse
import os

import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
import chainermn
from chainer import initializers
from chainer import serializers
from chainer import training
from chainer.training import extensions

import net


if __name__=='__main__':

num_gpus = int(os.environ['SM_NUM_GPUS'])

parser = argparse.ArgumentParser()

# retrieve the hyperparameters we set from the client (with some defaults)
parser.add_argument('--epochs', type=int, default=30)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--learning-rate', type=float, default=0.05)
parser.add_argument('--communicator', type=str, default='pure_nccl' if num_gpus > 0 else 'naive')

# Data, model, and output directories. These are required.
parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])

args, _ = parser.parse_known_args()

train_data = np.load(os.path.join(args.train, 'train.npz'))['data']
train_labels = np.load(os.path.join(args.train, 'train.npz'))['labels']

test_data = np.load(os.path.join(args.test, 'test.npz'))['data']
test_labels = np.load(os.path.join(args.test, 'test.npz'))['labels']

train = chainer.datasets.TupleDataset(train_data, train_labels)
test = chainer.datasets.TupleDataset(test_data, test_labels)

# Set up a neural network to train.
# Classifier reports softmax cross entropy loss and accuracy at every
# iteration, which will be used by the PrintReport extension below.
model = L.Classifier(net.VGG(10))

comm = chainermn.create_communicator(args.communicator)

# comm.inter_rank gives the rank of the node. This should only print on one node.
if comm.inter_rank == 0:
print('# Minibatch-size: {}'.format(args.batch_size))
print('# epoch: {}'.format(args.epochs))
print('# communicator: {}'.format(args.communicator))

# Set up a neural network to train.
# Classifier reports softmax cross entropy loss and accuracy at every
# iteration, which will be used by the PrintReport extension below.

# comm.intra_rank gives the rank of the process on a given node.
device = comm.intra_rank if num_gpus > 0 else -1
if device >= 0:
chainer.cuda.get_device_from_id(device).use()

optimizer = chainermn.create_multi_node_optimizer(chainer.optimizers.MomentumSGD(args.learning_rate), comm)
optimizer.setup(model)
optimizer.add_hook(chainer.optimizer.WeightDecay(5e-4))

num_loaders = 2
train_iter = chainer.iterators.MultiprocessIterator(train, args.batch_size, n_processes=num_loaders)
test_iter = chainer.iterators.MultiprocessIterator(test, args.batch_size, repeat=False, n_processes=num_loaders)

# Set up a trainer
updater = training.StandardUpdater(train_iter, optimizer, device=device)
trainer = training.Trainer(updater, (args.epochs, 'epoch'), out=args.output_data_dir)

# Evaluate the model with the test dataset for each epoch

evaluator = extensions.Evaluator(test_iter, model, device=device)
evaluator = chainermn.create_multi_node_evaluator(evaluator, comm)
trainer.extend(evaluator)

# Reduce the learning rate by half every 25 epochs.
trainer.extend(extensions.ExponentialShift('lr', 0.5), trigger=(25, 'epoch'))

# Dump a computational graph from 'loss' variable at the first iteration
# The "main" refers to the target link of the "main" optimizer.
trainer.extend(extensions.dump_graph('main/loss'))

# Write a log of evaluation statistics for each epoch
trainer.extend(extensions.LogReport())
if comm.rank == 0:
if extensions.PlotReport.available():
trainer.extend(
extensions.PlotReport(['main/loss', 'validation/main/loss'],
'epoch', file_name='loss.png'))
trainer.extend(
extensions.PlotReport(
['main/accuracy', 'validation/main/accuracy'],
'epoch', file_name='accuracy.png'))

trainer.extend(extensions.dump_graph('main/loss'))

trainer.extend(extensions.PrintReport(
['epoch', 'main/loss', 'validation/main/loss',
'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

# Run the training
trainer.run()

# Save the model (only on one host).
if comm.rank == 0:
serializers.save_npz(os.path.join(args.model_dir, 'model.npz'), model)



def model_fn(model_dir):
"""
This function is called by the Chainer container during hosting when running on SageMaker with
values populated by the hosting environment.

This function loads models written during training into `model_dir`.


Args:
model_dir (str): path to the directory containing the saved model artifacts

Returns:
a loaded Chainer model

For more on `model_fn`, please visit the sagemaker-python-sdk repository:
https://github.com/aws/sagemaker-python-sdk

For more on the Chainer container, please visit the sagemaker-chainer-containers repository:
https://github.com/aws/sagemaker-chainer-containers
"""
chainer.config.train = False
model = L.Classifier(net.VGG(10))
serializers.load_npz(os.path.join(model_dir, 'model.npz'), model)
return model.predictor
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from __future__ import print_function, absolute_import

import argparse
import os

import numpy as np

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer import serializers
from chainer.training import extensions

import net

if __name__ =='__main__':

parser = argparse.ArgumentParser()

# retrieve the hyperparameters we set from the client (with some defaults)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--learning-rate', type=float, default=0.05)

# Data, model, and output directories These are required.
parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])

args, _ = parser.parse_known_args()

num_gpus = int(os.environ['SM_NUM_GPUS'])

train_data = np.load(os.path.join(args.train, 'train.npz'))['data']
train_labels = np.load(os.path.join(args.train, 'train.npz'))['labels']

test_data = np.load(os.path.join(args.test, 'test.npz'))['data']
test_labels = np.load(os.path.join(args.test, 'test.npz'))['labels']

train = chainer.datasets.TupleDataset(train_data, train_labels)
test = chainer.datasets.TupleDataset(test_data, test_labels)

print('# Minibatch-size: {}'.format(args.batch_size))
print('# epoch: {}'.format(args.epochs))
print('# learning rate: {}'.format(args.learning_rate))

# Set up a neural network to train.
# Classifier reports softmax cross entropy loss and accuracy at every
# iteration, which will be used by the PrintReport extension below.
model = L.Classifier(net.VGG(10))

optimizer = chainer.optimizers.MomentumSGD(args.learning_rate)
optimizer.setup(model)
optimizer.add_hook(chainer.optimizer.WeightDecay(5e-4))

# Set up a trainer
device = 0 if num_gpus > 0 else -1 # -1 indicates CPU, 0 indicates first GPU device.
if num_gpus > 1:
devices = range(num_gpus)
train_iters = [chainer.iterators.MultiprocessIterator(i, args.batch_size, n_processes=4) \
for i in chainer.datasets.split_dataset_n_random(train, len(devices))]
test_iter = chainer.iterators.MultiprocessIterator(test, args.batch_size, repeat=False, n_processes=num_gpus)
updater = training.updaters.MultiprocessParallelUpdater(train_iters, optimizer, devices=range(num_gpus))
else:
train_iter = chainer.iterators.MultiprocessIterator(train, args.batch_size)
test_iter = chainer.iterators.MultiprocessIterator(test, args.batch_size, repeat=False)
updater = training.updater.StandardUpdater(train_iter, optimizer, device=device)

stop_trigger = (args.epochs, 'epoch')
trainer = training.Trainer(updater, stop_trigger, out=args.output_data_dir)
# Evaluate the model with the test dataset for each epoch
trainer.extend(extensions.Evaluator(test_iter, model, device=device))

# Reduce the learning rate by half every 25 epochs.
trainer.extend(extensions.ExponentialShift('lr', 0.5), trigger=(25, 'epoch'))

# Dump a computational graph from 'loss' variable at the first iteration
# The "main" refers to the target link of the "main" optimizer.
trainer.extend(extensions.dump_graph('main/loss'))

# Write a log of evaluation statistics for each epoch
trainer.extend(extensions.LogReport())

if extensions.PlotReport.available():
trainer.extend(
extensions.PlotReport(['main/loss', 'validation/main/loss'],
'epoch', file_name='loss.png'))
trainer.extend(
extensions.PlotReport(
['main/accuracy', 'validation/main/accuracy'],
'epoch', file_name='accuracy.png'))

# Print selected entries of the log to stdout
# Here "main" refers to the target link of the "main" optimizer again, and
# "validation" refers to the default name of the Evaluator extension.
# Entries other than 'epoch' are reported by the Classifier link, called by
# either the updater or the evaluator.
trainer.extend(extensions.PrintReport(
['epoch', 'main/loss', 'validation/main/loss',
'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

# Run the training
trainer.run()

# Save the model to model_dir. It's loaded below in `model_fn`.
serializers.save_npz(os.path.join(args.model_dir, 'model.npz'), model)


def model_fn(model_dir):
"""
This function is called by the Chainer container during hosting when running on SageMaker with
values populated by the hosting environment.

This function loads models written during training into `model_dir`.

Args:
model_dir (str): path to the directory containing the saved model artifacts

Returns:
a loaded Chainer model

For more on `model_fn`, please visit the sagemaker-python-sdk repository:
https://github.com/aws/sagemaker-python-sdk

For more on the Chainer container, please visit the sagemaker-chainer-containers repository:
https://github.com/aws/sagemaker-chainer-containers
"""
chainer.config.train = False
model = L.Classifier(net.VGG(10))
serializers.load_npz(os.path.join(model_dir, 'model.npz'), model)
return model.predictor
Loading