Summary on Supporting PyTorch
Kelang edited this page Jul 23, 2020 · 9 revisions
Client command: elasticdl train --image_name=elasticdl:mnist (see tutorials/elasticdl_local.md)
Setup entry point: elasticdl=elasticdl_client.main:main
Worker: runs tasks (training/evaluation/prediction). It only computes the gradients and reports them to the PS.
Worker code: elastic/python/worker/worker.py
Push parameters to PS: elastic/python/worker/ps_client.py
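The division of labor described above (the worker only computes gradients, the PS owns the parameters and applies updates) can be sketched in plain Python. This is an illustrative sketch only: ParameterServer, pull, push_gradients and worker_step are hypothetical names, not ElasticDL's API, and the "model" is a one-dimensional linear fit so the gradients can be written by hand.

```python
class ParameterServer:
    """Hypothetical in-process stand-in for a PS: holds parameters, applies SGD."""

    def __init__(self, params, lr=0.1):
        self.params = dict(params)
        self.lr = lr
        self.version = 0  # incremented once per applied gradient report

    def pull(self):
        # Hand the worker a copy of the current parameters and their version.
        return dict(self.params), self.version

    def push_gradients(self, grads):
        # The update happens here, on the PS side, not on the worker.
        for name, grad in grads.items():
            self.params[name] -= self.lr * grad
        self.version += 1
        return self.version


def worker_step(ps, x, y):
    """Compute gradients of 0.5 * (w*x + b - y)**2 and report them to the PS."""
    params, _ = ps.pull()
    err = params["w"] * x + params["b"] - y
    grads = {"w": err * x, "b": err}
    # The worker never modifies the parameters itself.
    return ps.push_gradients(grads)


ps = ParameterServer({"w": 0.0, "b": 0.0}, lr=0.1)
for _ in range(200):
    worker_step(ps, x=2.0, y=4.0)
final_params, final_version = ps.pull()
```

In a real deployment the pull and push calls would go over the network and several workers would report concurrently; the sketch keeps everything in one process to show only who computes what.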
In PyTorch, training usually goes through an optimizer, which applies the gradients locally:
# training and testing
for epoch in range(EPOCH):
    # train_loader yields batches; x is normalized during iteration
    for step, (b_x, b_y) in enumerate(train_loader):
        output = cnn(b_x)[0]            # cnn output
        loss = loss_func(output, b_y)   # cross entropy loss
        optimizer.zero_grad()           # clear gradients for this training step
        loss.backward()                 # backpropagation, compute gradients
        optimizer.step()                # apply gradients
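If the worker is only supposed to compute gradients and report them, one way to adapt the loop is to stop after loss.backward() and read the gradients off the parameters instead of calling optimizer.step(). The following is a minimal sketch under assumptions: the linear model and random batch are stand-ins, and collect_gradients is a hypothetical helper, not part of PyTorch or ElasticDL.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model and batch; not taken from ElasticDL.
model = nn.Linear(4, 3)
loss_func = nn.CrossEntropyLoss()
b_x = torch.randn(8, 4)
b_y = torch.randint(0, 3, (8,))


def collect_gradients(module):
    """Return {parameter name: detached gradient copy} after backward()."""
    return {
        name: param.grad.detach().clone()
        for name, param in module.named_parameters()
        if param.grad is not None
    }


model.zero_grad()                    # clear stale gradients
loss = loss_func(model(b_x), b_y)    # forward pass
loss.backward()                      # compute gradients only
grads = collect_gradients(model)     # what a worker would report to the PS

# No optimizer.step() here: in this sketch the update is left to the PS.
```

The gradients are cloned so that a later zero_grad() or backward() call on the worker does not change the tensors already handed off for reporting.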