Skip to content

Commit 7330773

Browse files
author
Macdonald
committed
Adding more device options to kvstore. Fixing sentiment kvstore usage
1 parent 7ab4ad6 commit 7330773

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

sagemaker-python-sdk/mxnet_gluon_mnist/mnist.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
import mxnet as mx
5-
from mxnet import gluon, autograd, kv
5+
from mxnet import gluon, autograd
66
from mxnet.gluon import nn
77
import numpy as np
88
import json
@@ -16,7 +16,7 @@
1616
# ------------------------------------------------------------ #
1717

1818

19-
def train(channel_input_dirs, hyperparameters, hosts, **kwargs):
19+
def train(channel_input_dirs, hyperparameters, hosts, num_gpus, **kwargs):
2020
# SageMaker passes num_cpus, num_gpus and other args we can use to tailor training to
2121
# the current container environment, but here we just use simple cpu context.
2222
ctx = mx.cpu()
@@ -41,10 +41,15 @@ def train(channel_input_dirs, hyperparameters, hosts, **kwargs):
4141
# Collect all parameters from net and its children, then initialize them.
4242
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
4343
# Trainer is for updating parameters with gradient.
44-
store = kv.create('dist_sync' if len(hosts) > 1 else 'local')
44+
45+
if len(hosts) == 1:
46+
kvstore = 'device' if num_gpus > 0 else 'local'
47+
else:
48+
kvstore = 'dist_device_sync' if num_gpus > 0 else 'dist_sync'
49+
4550
trainer = gluon.Trainer(net.collect_params(), 'sgd',
4651
{'learning_rate': learning_rate, 'momentum': momentum},
47-
kvstore=store)
52+
kvstore=kvstore)
4853
metric = mx.metric.Accuracy()
4954
loss = gluon.loss.SoftmaxCrossEntropyLoss()
5055

sagemaker-python-sdk/mxnet_gluon_sentiment/sentiment.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir
3131
if len(hosts) == 1:
3232
kvstore = 'device' if num_gpus > 0 else 'local'
3333
else:
34-
kvstore = 'dist_sync'
34+
kvstore = 'dist_device_sync' if num_gpus > 0 else 'dist_sync'
3535

3636
ctx = mx.gpu() if num_gpus > 0 else mx.cpu()
3737

@@ -56,7 +56,8 @@ def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir
5656
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
5757
# Trainer is for updating parameters with gradient.
5858
trainer = gluon.Trainer(net.collect_params(), 'adam',
59-
{'learning_rate': learning_rate})
59+
{'learning_rate': learning_rate},
60+
kvstore=kvstore)
6061
metric = mx.metric.Accuracy()
6162
loss = gluon.loss.SoftmaxCrossEntropyLoss()
6263

0 commit comments

Comments (0)