@@ -2,7 +2,7 @@
 
 import logging
 import mxnet as mx
-from mxnet import gluon, autograd, kv
+from mxnet import gluon, autograd
 from mxnet.gluon import nn
 import numpy as np
 import json
@@ -16,7 +16,7 @@
 # ------------------------------------------------------------ #
 
 
-def train(channel_input_dirs, hyperparameters, hosts, **kwargs):
+def train(channel_input_dirs, hyperparameters, hosts, num_gpus, **kwargs):
     # SageMaker passes num_cpus, num_gpus and other args we can use to tailor training to
     # the current container environment, but here we just use simple cpu context.
     ctx = mx.cpu()
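
As the comment above notes, SageMaker also passes num_cpus and num_gpus into the script. A minimal sketch (not part of this commit) of how the num_gpus argument could select GPU contexts instead of the fixed CPU context; the num_gpus value below is an assumed example:

    import mxnet as mx

    num_gpus = 2  # hypothetical value; SageMaker supplies the real one at runtime
    ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    print(ctx)  # -> [gpu(0), gpu(1)]
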
@@ -41,10 +41,15 @@ def train(channel_input_dirs, hyperparameters, hosts, **kwargs):
     # Collect all parameters from net and its children, then initialize them.
     net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)
     # Trainer is for updating parameters with gradient.
-    store = kv.create('dist_sync' if len(hosts) > 1 else 'local')
+
+    if len(hosts) == 1:
+        kvstore = 'device' if num_gpus > 0 else 'local'
+    else:
+        kvstore = 'dist_device_sync' if num_gpus > 0 else 'dist_sync'
+
     trainer = gluon.Trainer(net.collect_params(), 'sgd',
                             {'learning_rate': learning_rate, 'momentum': momentum},
-                            kvstore=store)
+                            kvstore=kvstore)
     metric = mx.metric.Accuracy()
     loss = gluon.loss.SoftmaxCrossEntropyLoss()
 
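For reference, a minimal self-contained sketch (not part of this commit) of how the selected kvstore string is consumed by gluon.Trainer; the host list, GPU count, hyperparameter values, and the toy network below are assumptions for illustration only:

    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon import nn

    # Assumed stand-ins for the arguments SageMaker passes to train()
    hosts = ['algo-1']   # hypothetical single-host job
    num_gpus = 0         # hypothetical CPU-only container
    learning_rate, momentum = 0.1, 0.9  # assumed hyperparameter values

    # Same selection as the change above: local kvstores on a single host,
    # distributed synchronous kvstores when training spans several hosts.
    if len(hosts) == 1:
        kvstore = 'device' if num_gpus > 0 else 'local'
    else:
        kvstore = 'dist_device_sync' if num_gpus > 0 else 'dist_sync'

    # Toy network so the snippet is self-contained and runnable on CPU.
    net = nn.Sequential()
    with net.name_scope():
        net.add(nn.Dense(10, in_units=20))
    net.initialize(mx.init.Xavier(magnitude=2.24), ctx=mx.cpu())

    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': learning_rate, 'momentum': momentum},
                            kvstore=kvstore)
    print(kvstore)  # -> 'local' for this single-host, CPU-only configuration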