import argparse
import logging
import os
import time
import zipfile

import mxnet as mx
import horovod.mxnet as hvd
from mxnet import autograd, gluon
from mxnet.test_utils import download

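# Tornasole (an early name for Amazon SageMaker Debugger) records tensors such as
# weights, gradients, and biases while the model trains.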
from tornasole import SaveConfig, modes
from tornasole.mxnet import TornasoleHook

# Training settings
parser = argparse.ArgumentParser(description="MXNet MNIST Example")

parser.add_argument("--batch-size", type=int, default=64, help="training batch size (default: 64)")
parser.add_argument(
    "--dtype", type=str, default="float32", help="training data type (default: float32)"
)
parser.add_argument("--epochs", type=int, default=5, help="number of training epochs (default: 5)")
parser.add_argument("--lr", type=float, default=0.01, help="learning rate (default: 0.01)")
parser.add_argument("--momentum", type=float, default=0.9, help="SGD momentum (default: 0.9)")
parser.add_argument(
    "--no-cuda", action="store_true", default=False, help="disable training on GPU (default: False)"
)
parser.add_argument(
    "--output-uri",
    type=str,
    default="/opt/ml/output/tensors/tornasole",
    help="S3 URI or local path where tensor data will be stored",
)
args = parser.parse_args()

if not args.no_cuda:
    # Disable CUDA if there are no GPUs.
    if not mx.test_utils.list_gpus():
        args.no_cuda = True

logging.basicConfig(level=logging.INFO)
logging.info(args)


# Function to get MNIST train/val iterators for a given rank
def get_mnist_iterator(rank):
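    """Download MNIST into a per-rank directory and return sharded train/val iterators."""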
    data_dir = "data-%d" % rank
    if not os.path.isdir(data_dir):
        os.makedirs(data_dir)
    zip_file_path = download("http://data.mxnet.io/mxnet/data/mnist.zip", dirname=data_dir)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(data_dir)

    input_shape = (1, 28, 28)
    batch_size = args.batch_size

    train_iter = mx.io.MNISTIter(
        image="%s/train-images-idx3-ubyte" % data_dir,
        label="%s/train-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        shuffle=True,
        flat=False,
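        # num_parts/part_index shard the training set so each Horovod worker
        # reads a distinct partition of the data.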
        num_parts=hvd.size(),
        part_index=hvd.rank(),
    )

    val_iter = mx.io.MNISTIter(
        image="%s/t10k-images-idx3-ubyte" % data_dir,
        label="%s/t10k-labels-idx1-ubyte" % data_dir,
        input_shape=input_shape,
        batch_size=batch_size,
        flat=False,
    )

    return train_iter, val_iter


# Function to define neural network
def conv_nets():
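    """A small LeNet-style convolutional network for 10-class MNIST."""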
    net = gluon.nn.HybridSequential()
    with net.name_scope():
        net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation="relu"))
        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
        net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation="relu"))
        net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
        net.add(gluon.nn.Flatten())
        net.add(gluon.nn.Dense(512, activation="relu"))
        net.add(gluon.nn.Dense(10))
    return net


# Function to evaluate accuracy for a model
def evaluate(model, data_iter, context):
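    """Run one pass over data_iter and return the metric's (name, value) tuple."""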
    data_iter.reset()
    metric = mx.metric.Accuracy()
    for batch in data_iter:
        data = batch.data[0].as_in_context(context)
        label = batch.label[0].as_in_context(context)
        output = model(data.astype(args.dtype, copy=False))
        metric.update([label], [output])

    return metric.get()


# Initialize Horovod
hvd.init()

# Horovod: pin context to local rank
context = mx.cpu(hvd.local_rank()) if args.no_cuda else mx.gpu(hvd.local_rank())

# Load training and validation data
train_data, val_data = get_mnist_iterator(hvd.rank())

# Build model
model = conv_nets()
model.cast(args.dtype)
model.hybridize()

# Create optimizer
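# Horovod convention: scale the single-worker learning rate by the number of workers.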
optimizer_params = {"momentum": args.momentum, "learning_rate": args.lr * hvd.size()}
opt = mx.optimizer.create("sgd", **optimizer_params)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type="gaussian", factor_type="in", magnitude=2)
model.initialize(initializer, ctx=context)

# Horovod: broadcast rank-0 parameters so all workers start from the same weights
params = model.collect_params()
hvd.broadcast_parameters(params, root_rank=0)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()


def create_tornasole_hook():
    # With save_interval=1, tensors are saved at every step
    # (step indexing starts at 0).
    save_config = SaveConfig(save_interval=1)

    # Create a hook that logs weights, biases and gradients while training the model.
    ts_hook = TornasoleHook(
        out_dir=args.output_uri,
        save_config=save_config,
        include_collections=["weights", "gradients", "biases"],
    )
    return ts_hook


# Create the Tornasole hook once and register it on the model; re-creating and
# re-registering a hook every epoch would attach duplicate forward hooks.
hook = create_tornasole_hook()
hook.register_hook(model)

# Train model
for epoch in range(args.epochs):
    tic = time.time()
    train_data.reset()
    metric.reset()
    hook.set_mode(modes.TRAIN)

    for nbatch, batch in enumerate(train_data, start=1):
        data = batch.data[0].as_in_context(context)
        label = batch.label[0].as_in_context(context)
        with autograd.record():
            output = model(data.astype(args.dtype, copy=False))
            loss = loss_fn(output, label)
        loss.backward()
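        # Horovod's DistributedTrainer allreduces gradients across workers inside step().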
        trainer.step(args.batch_size)
        metric.update([label], [output])

        if nbatch % 100 == 0:
            name, acc = metric.get()
            logging.info("[Epoch %d Batch %d] Training: %s=%f", epoch, nbatch, name, acc)

    if hvd.rank() == 0:
        elapsed = time.time() - tic
        speed = nbatch * args.batch_size * hvd.size() / elapsed
        logging.info("Epoch[%d]\tSpeed=%.2f samples/s\tTime cost=%f", epoch, speed, elapsed)

    # Evaluate model accuracy
    hook.set_mode(modes.EVAL)
    _, train_acc = metric.get()
    name, val_acc = evaluate(model, val_data, context)
    if hvd.rank() == 0:
        logging.info(
            "Epoch[%d]\tTrain: %s=%f\tValidation: %s=%f", epoch, name, train_acc, name, val_acc
        )

    if hvd.rank() == 0 and epoch == args.epochs - 1:
        assert val_acc > 0.96, "Achieved accuracy (%f) is lower than expected (0.96)" % val_acc
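
# A typical multi-worker launch uses Horovod's CLI (the script name here is illustrative):
#
#   horovodrun -np 4 python mnist_mxnet_hvd.py --epochs 5 --lr 0.01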