Commit e7e9b46

PyTorch Distributed Training Support & Test (aws#272)
* Initial commit
* Add test for PT-DT
* Rename file
* Remove extraneous file
* Add function back
* Add check for horovod.torch
* Fix f-string
* catch ImportError
* Remove try-catch
* Address Rahul's comment
* Check that torch.distributed is available
* Add trial.workers() check
* Wrap race condition
* Add reset_collections() to test cases
* peace offering to CI
1 parent 4f8dd64 commit e7e9b46
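
In short, this change makes the PyTorch hook name each process after its torch.distributed rank. A minimal sketch of the intended usage, assuming the tornasole.pytorch and tornasole.trials APIs exercised by the new test below (paths and printed values are illustrative, not taken verbatim from the diff):

    import torch
    import torch.distributed as dist
    import tornasole.pytorch as ts
    from tornasole.trials import create_trial

    # Inside each spawned process, after dist.init_process_group(...) has run:
    model = torch.nn.Linear(1, 1)
    hook = ts.TornasoleHook(out_dir='/tmp/run', save_all=True)
    hook.register_hook(model)
    print(hook.get_worker_name())   # "worker_0", "worker_1", ... one per rank

    # Back in the parent, after every process has finished and cleaned up:
    trial = create_trial(path='/tmp/run')
    print(trial.workers())          # e.g. ['worker_0', 'worker_1']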

File tree

8 files changed (+226, -23 lines)

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
"""
Tests core functionality of naming workers when there are multiple processes.
See https://pytorch.org/tutorials/intermediate/ddp_tutorial.html to decide
how we want to support DistributedDataParallel with limited user configuration.

The key method is
torch.distributed.get_rank() - when manually spawning processes
"""
import numpy as np
import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
import torch.optim as optim
import shutil


import tornasole.pytorch as ts
from tornasole.trials import Trial, create_trial

out_dir = '/tmp/run'

class Net(nn.Module):
    """Returns f(x) = sigmoid(w*x + b)"""
    def __init__(self):
        super().__init__()
        self.add_module('fc', nn.Linear(1, 1))

    def forward(self, x):
        x = self.fc(x)
        x = F.sigmoid(x)
        return x

def dataset(batch_size=4):
    """Return a dataset of (data, target)."""
    data = torch.rand(batch_size, 1)
    target = F.sigmoid(2 * data + 1)
    return data, target

def train(model, device, optimizer, num_steps=10):
    """Runs the training loop; no explicit Tornasole calls here."""
    model.train()
    for i in range(num_steps):
        batch_size = 4
        data = torch.rand(batch_size, 1)
        target = F.sigmoid(2 * data + 1)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.mse_loss(output, target)
        loss.backward()
        optimizer.step()



def run(rank, size, num_epochs=10, batch_size=128, num_batches=10):
    """Distributed training loop run by each spawned process."""
    torch.manual_seed(1234)
    device = torch.device('cpu')
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=1)

    shutil.rmtree(out_dir, ignore_errors=True)
    hook = ts.TornasoleHook(
        out_dir=out_dir,
        save_config=ts.SaveConfig(save_steps=[0, 1, 5]),
        save_all=True,
    )
    hook.register_hook(model)

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for _ in range(num_batches):
            optimizer.zero_grad()
            data, target = dataset(batch_size)
            output = model(data)
            loss = F.mse_loss(output, target)
            epoch_loss += loss.item()
            loss.backward()
            average_gradients(model)
            optimizer.step()
        # print(f"Rank {dist.get_rank()}, epoch {epoch}: {epoch_loss / num_batches}")

    assert hook.get_worker_name() == f"worker_{dist.get_rank()}"
    # Race condition here where both workers attempt to move
    # /tmp/{out_dir}/END_OF_JOB.ts to {out_dir}/END_OF_JOB.ts
    try:
        hook._cleanup()
    except FileNotFoundError:
        pass

def average_gradients(model):
    """Gradient averaging."""
    size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
        param.grad.data /= size

def init_processes(rank, size, fn, backend='gloo'):
    """Initialize the distributed environment."""
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

def test_run_net_single_process():
    """Runs a single linear layer."""
    ts.reset_collections()
    device = torch.device('cpu')
    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    shutil.rmtree(out_dir, ignore_errors=True)
    hook = ts.TornasoleHook(
        out_dir=out_dir,
        save_config=ts.SaveConfig(save_steps=[0, 1, 5]),
        save_all=True,
    )
    hook.register_hook(model)
    train(model=model, device=device, optimizer=optimizer)
    hook._cleanup()

    assert hook.get_worker_name() == "worker_0"

    trial = create_trial(path=out_dir)
    assert len(trial.workers()) == 1, f"trial.workers() = {trial.workers()}"
    assert len(trial.steps()) == 3, f"trial.steps() = {trial.steps()}"
    shutil.rmtree(out_dir, ignore_errors=True)

def test_run_net_distributed():
    """Runs a single linear layer on 2 processes."""
    # torch.distributed is empty on Mac on Torch <= 1.2
    if not hasattr(dist, 'is_initialized'):
        return

    ts.reset_collections()
    size = 2
    processes = []
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # WARNING: assert statements do not cause test failure inside subprocesses
    # https://stackoverflow.com/questions/13400546/py-test-how-to-automatically-detect-an-exception-in-a-child-process
    assert all([not p.exitcode for p in processes]), f"Some processes failed. processes={processes}"

    out_dir = '/tmp/run'
    trial = create_trial(path=out_dir)
    assert len(trial.workers()) == 2, f"trial.workers() = {trial.workers()}"
    assert len(trial.steps()) == 3, f"trial.steps() = {trial.steps()}"
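
The test above drives the raw module and averages gradients manually; DistributedDataParallel is imported but never exercised. As a rough, unverified sketch of what the DDP variant alluded to in the file's docstring might look like (reusing the imports and helpers defined above, and assuming dist.init_process_group has already run as in init_processes; whether register_hook should be attached to the wrapper or the underlying module is exactly the open question, so treat this as an assumption, not the supported path):

    def run_ddp(rank, size, num_batches=10):
        """Hypothetical DDP variant of run(); not part of the committed test."""
        torch.manual_seed(1234)
        model = DDP(Net().to(torch.device('cpu')))   # DDP averages gradients itself
        optimizer = optim.SGD(model.parameters(), lr=1)
        hook = ts.TornasoleHook(
            out_dir=out_dir,
            save_config=ts.SaveConfig(save_steps=[0, 1, 5]),
            save_all=True,
        )
        hook.register_hook(model)   # assumption: registering on the wrapper works
        for _ in range(num_batches):
            optimizer.zero_grad()
            data, target = dataset()
            loss = F.mse_loss(model(data), target)
            loss.backward()          # DDP all-reduces gradients during backward
            optimizer.step()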

tests/pytorch/test_loss.py

Lines changed: 2 additions & 1 deletion
@@ -30,6 +30,7 @@ def forward(self, x):

 def test_register_loss():
     """Test that the loss is saved as a tensor."""
+    ts.reset_collections()
     out_dir = '/tmp/pytorch_test_loss'
     shutil.rmtree(out_dir, ignore_errors=True)

@@ -63,7 +64,7 @@ def test_register_loss():
     # (like we do here). Then it'll crash, likewise in a Jupyter notebook.
     hook._cleanup()

-    trial = create_trial(path=out_dir, name='run')
+    trial = create_trial(path=out_dir)
     loss_coll = hook.collection_manager.get('losses')
     assert len(loss_coll.get_tensor_names()) == 3
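
The ts.reset_collections() calls added in this commit guard against collection state leaking between tests in the same pytest session. A possible way to centralize that (a sketch only, not part of this commit) would be an autouse fixture:

    import pytest
    import tornasole.pytorch as ts

    @pytest.fixture(autouse=True)
    def clean_collections():
        # Hypothetical fixture: reset Tornasole's global collection manager
        # before each test instead of calling reset_collections() inline.
        ts.reset_collections()
        yield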

tests/pytorch/test_modes.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 from torch.autograd import Variable
 from tornasole import modes, SaveConfig, SaveConfigMode
 from tornasole.pytorch.hook import *
-from tornasole.pytorch.torch_collection import *
+from tornasole.pytorch.collection import *
 from tornasole.pytorch import reset_collections
 from tornasole.core.json_config import TORNASOLE_CONFIG_FILE_PATH_ENV_STR
 import uuid

tests/pytorch/test_simple_write.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@

 from tornasole import SaveConfig
 from tornasole.pytorch.hook import *
-from tornasole.pytorch.torch_collection import *
+from tornasole.pytorch.collection import *
 from tornasole.pytorch import reset_collections
 from tornasole.core.json_config import TORNASOLE_CONFIG_FILE_PATH_ENV_STR
 import uuid

tornasole/core/index_reader.py

Lines changed: 29 additions & 4 deletions
@@ -1,5 +1,7 @@
+import numpy as np
 import os
 import json
+from typing import Any, Dict, List, Tuple
 from tornasole.core.locations import TensorLocation, IndexFileLocationUtils
 from tornasole.core.s3_utils import list_s3_objects
 from tornasole.core.access_layer.s3handler import ReadObjectRequest, S3Handler

@@ -65,7 +67,14 @@ def list_index_files_in_dir(dirname):
         return sorted(index_files)

     @staticmethod
-    def get_disk_responses(path, start_after_key=0, range_steps=None):
+    def get_disk_responses(path, start_after_key=0, range_steps=None) -> Tuple[List[bytes], List[int], int]:
+        """Read files like `trial_{datetime}/index/000/{step}_{worker}.json`.
+
+        Returns:
+            responses: List of the contents of each file, encoded as bytes.
+            steps: List of steps read.
+            start_after_key: An int referring to where to start reading next time.
+        """
         index_files = LocalIndexReader.list_index_files_in_dir(path)
         steps = []
         workers = []

@@ -86,7 +95,7 @@ def get_disk_responses(path, start_after_key=0, range_steps=None):
 class IndexReader:

     @staticmethod
-    def fetch_tensor_value(tensor_location):
+    def fetch_tensor_value(tensor_location: TensorLocation) -> np.ndarray:
         event_file_name = tensor_location.event_file_name
         start = tensor_location.start_idx
         length = tensor_location.length

@@ -107,7 +116,8 @@ def fetch_tensor_value(tensor_location):
         return tensor_data

     @staticmethod
-    def load_tensor_data_from_index_files(path, start_after_key=None, range_steps=None):
+    def load_tensor_data_from_index_files(path, start_after_key=None, range_steps=None) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], int]:
+        """Return a triply nested dict referring to tensor data."""
         s3, bucket_name, prefix_name = is_s3(path)
         if s3:
             if start_after_key == 0:

@@ -130,7 +140,22 @@ def _validate(index_dict):
             raise IndexReaderException('tensor_payload section is not present')

     @staticmethod
-    def _update_tensors_from_json(index_tensors_dict, step, response, path, worker):
+    def _update_tensors_from_json(index_tensors_dict, step, response: bytes, path, worker) -> Dict[str, Dict[int, Dict[str, TensorLocation]]]:
+        """Return a triply nested dict referring to tensor data.
+
+        Example:
+            {
+                'dense/bias:0': {
+                    0: {
+                        'tensor_location': <TensorLocation object>
+                    },
+                    2: { ... },
+                    ...
+                },
+                'conv2d/kernel:0': { ... },
+                ...
+            }
+        """
         index_dict = json.loads(response)
         IndexReader._validate(index_dict)
         index_meta = index_dict['meta']
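
For orientation, a rough sketch of how a reader might walk the index and materialize one value, assuming load_tensor_data_from_index_files returns the nested layout documented in the docstrings above (tensor name, then step, then a 'tensor_location' entry); in practice the trial classes wrap this for you:

    # Sketch only: the inner dict layout is taken from the docstring example above.
    index_tensors_dict, last_key = IndexReader.load_tensor_data_from_index_files('/tmp/run')
    for tname, steps in index_tensors_dict.items():
        for step, entry in steps.items():
            value = IndexReader.fetch_tensor_value(entry['tensor_location'])
            print(tname, step, value.shape)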

tornasole/pytorch/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 from .hook import TornasoleHook
-from .torch_collection import Collection, CollectionManager
+from .collection import Collection, CollectionManager

-from .torch_collection import get_collections, get_collection, \
+from .collection import get_collections, get_collection, \
     load_collections, \
     add_to_collection, add_to_default_collection, reset_collections
 from tornasole import SaveConfig, SaveConfigMode, ReductionConfig

tornasole/pytorch/torch_collection.py renamed to tornasole/pytorch/collection.py

Lines changed: 1 addition & 1 deletion
@@ -62,4 +62,4 @@ def get_collection(collection_name):
     return _collection_manager.get(collection_name, create=True)

 def get_collections():
-    return _collection_manager.collections
+    return _collection_manager.collections

(The removed and added lines are textually identical; beyond the rename, this hunk likely only adds a missing newline at the end of the file.)

tornasole/pytorch/hook.py

Lines changed: 33 additions & 13 deletions
@@ -1,10 +1,12 @@
+import importlib
 import torch
+import torch.distributed as dist
 import logging
 from tornasole.core.hook import CallbackHook
 from tornasole.core.collection import CollectionKeys
 from tornasole.core.logger import get_logger
 from tornasole.core.json_config import create_hook_from_json_config
-from tornasole.pytorch.torch_collection import get_collection_manager
+from tornasole.pytorch.collection import get_collection_manager
 from tornasole.pytorch.utils import get_reduction_of_data, make_numpy_array
 from tornasole.core.json_config import TORNASOLE_CONFIG_DEFAULT_WORKER_NAME

@@ -45,20 +47,38 @@ def __init__(self,
         self.module_maps = dict()

     def get_num_workers(self):
-        try:
-            import horovod.torch as hvd
-            if hvd.size():
-                return hvd.size()
-        except (ModuleNotFoundError, ValueError, ImportError):
-            return 1
+        """Check horovod and torch.distributed."""
+        # Try torch.distributed
+        # torch.distributed is empty on Mac on Torch <= 1.2
+        if hasattr(dist, 'is_initialized') and dist.is_initialized():
+            return torch.distributed.get_world_size()
+        # Try horovod
+        else:
+            try:
+                import horovod.torch as hvd
+                if hvd.size():
+                    return hvd.size()
+            except (ModuleNotFoundError, ValueError, ImportError):
+                pass
+        # Return default
+        return 1

     def get_worker_name(self):
-        try:
-            import horovod.torch as hvd
-            if hvd.size():
-                return f'worker_{hvd.rank()}'
-        except (ModuleNotFoundError, ValueError, ImportError):
-            return TORNASOLE_CONFIG_DEFAULT_WORKER_NAME
+        """Check horovod and torch.distributed."""
+        # Try torch.distributed
+        # torch.distributed is empty on Mac on Torch <= 1.2
+        if hasattr(dist, 'is_initialized') and dist.is_initialized():
+            return f"worker_{dist.get_rank()}"
+        # Try horovod
+        else:
+            try:
+                import horovod.torch as hvd
+                if hvd.size():
+                    return f"worker_{hvd.rank()}"
+            except (ModuleNotFoundError, ValueError, ImportError):
+                pass
+        # Return default
+        return TORNASOLE_CONFIG_DEFAULT_WORKER_NAME

     @classmethod
     def hook_from_config(cls):
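
A quick illustration of the precedence these two methods now implement (a sketch, assuming tornasole.pytorch is imported as ts as in the tests above; printed values are illustrative):

    # Fallback order:
    # 1. torch.distributed, if the module is usable and a process group is initialized
    # 2. horovod.torch, if importable and initialized
    # 3. the single-process default
    hook = ts.TornasoleHook(out_dir='/tmp/run')
    print(hook.get_num_workers())   # world size inside a dist group (2 in the test),
                                    # hvd.size() under Horovod, otherwise 1
    print(hook.get_worker_name())   # "worker_{rank}" when distributed, otherwise
                                    # TORNASOLE_CONFIG_DEFAULT_WORKER_NAME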
