
Commit 4196324

Author: Ignacio Quintero (committed)
Fix division for Python 3.
Also added a comment about ShardedByS3Key.
1 parent: 99bbaf6
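For context, the motivation for the `/` to `//` change: in Python 3, `/` is true division and always returns a float, so a shard size computed with it can no longer be used as a slice index. A minimal standalone sketch, using hypothetical data that is not part of this commit:

    hosts = ['algo-1', 'algo-2']            # hypothetical host list
    train_data = list(range(10))            # hypothetical training set

    shard_size = len(train_data) / len(hosts)    # Python 3: 5.0 (float)
    # train_data[0:shard_size]                   # would raise TypeError: slice indices must be integers

    shard_size = len(train_data) // len(hosts)   # 5 (int) on both Python 2 and Python 3
    print(train_data[0:shard_size])              # [0, 1, 2, 3, 4]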

3 files changed: +9 −5 lines

sagemaker-python-sdk/mxnet_gluon_mnist/mnist.py

Lines changed: 3 additions & 3 deletions
@@ -53,11 +53,11 @@ def train(current_host, channel_input_dirs, hyperparameters, hosts, num_gpus):
     metric = mx.metric.Accuracy()
     loss = gluon.loss.SoftmaxCrossEntropyLoss()
 
-    # shard the training data in case we are doing
-    # distributed training.
+    # shard the training data in case we are doing distributed training. As an alternative to splitting
+    # in memory, the data could be pre-split in S3 and ShardedByS3Key used for distributed training.
     if len(hosts) > 1:
         train_data = [x for x in train_data]
-        shard_size = len(train_data) / len(hosts)
+        shard_size = len(train_data) // len(hosts)
         for i, host in enumerate(hosts):
             if host == current_host:
                 start = shard_size * i
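For reference, a self-contained sketch of the in-memory sharding pattern this hunk touches; the function name and data below are illustrative, not taken from mnist.py:

    def select_shard(train_data, hosts, current_host):
        # Each host keeps one contiguous slice of the data; // keeps the slice bounds integral.
        shard_size = len(train_data) // len(hosts)
        for i, host in enumerate(hosts):
            if host == current_host:
                start = shard_size * i
                end = start + shard_size
                return train_data[start:end]

    # Example: the second of two hosts receives the second half of the data.
    print(select_shard(list(range(10)), ['algo-1', 'algo-2'], 'algo-2'))   # [5, 6, 7, 8, 9]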

sagemaker-python-sdk/mxnet_gluon_sentiment/sentiment.py

Lines changed: 3 additions & 1 deletion
@@ -46,7 +46,9 @@ def train(current_host, hosts, num_cpus, num_gpus, channel_input_dirs, model_dir
     train_sentences = [[vocab.get(token, 1) for token in line if len(line)>0] for line in train_sentences]
     val_sentences = [[vocab.get(token, 1) for token in line if len(line)>0] for line in val_sentences]
 
-    shard_size = len(train_sentences) / len(hosts)
+    # As an alternative to splitting in memory, the data could be pre-split in S3 and ShardedByS3Key
+    # used for parallel training.
+    shard_size = len(train_sentences) // len(hosts)
     for i, host in enumerate(hosts):
         if host == current_host:
             start = shard_size * i

sagemaker-python-sdk/mxnet_mnist/mnist.py

Lines changed: 3 additions & 1 deletion
@@ -39,7 +39,9 @@ def train(current_host, channel_input_dirs, hyperparameters, hosts, num_cpus, nu
     (train_labels, train_images) = load_data(os.path.join(channel_input_dirs['train']))
     (test_labels, test_images) = load_data(os.path.join(channel_input_dirs['test']))
 
-    shard_size = len(train_images) / len(hosts)
+    # As an alternative to splitting in memory, the data could be pre-split in S3 and ShardedByS3Key
+    # used for parallel training.
+    shard_size = len(train_images) // len(hosts)
     for i, host in enumerate(hosts):
         if host == current_host:
             start = shard_size * i
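The new comments point to ShardedByS3Key, the S3 data distribution setting in the SageMaker Python SDK, as an alternative to slicing in memory. A hedged sketch of how a training channel might request it; the s3_input call and estimator arguments follow the SDK v1 style, and the bucket, role, and framework version are placeholders, so treat the exact names as assumptions:

    import sagemaker
    from sagemaker.mxnet import MXNet

    # Each training host receives a distinct subset of the S3 keys instead of the full dataset.
    train_input = sagemaker.session.s3_input(
        's3://my-bucket/mnist/train',          # placeholder bucket/prefix
        distribution='ShardedByS3Key')

    estimator = MXNet(entry_point='mnist.py',
                      role='MySageMakerRole',  # placeholder IAM role
                      train_instance_count=2,  # two hosts, one shard each
                      train_instance_type='ml.m4.xlarge',
                      framework_version='1.1.0',
                      py_version='py3')
    estimator.fit({'train': train_input})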
