
Commit e6c9ffc

ruhanprasad, prasadru, and atqy
authored
Add examples for using SMDDP in sharded data parallel training with FSDP and Deepspeed (#4466)
* Deepspeed and FSDP example notebooks
* Make subnets and ecr image optional
* fall back to nccl if smddp is not present
* bug fixes
* linting
* grammar fixes
* grammar fix
* Fixing badging
* Info about enabling SMDDP
* permission changes
* Header change

---------

Co-authored-by: prasadru <[email protected]>
Co-authored-by: atqy <[email protected]>
1 parent 62aaed1 commit e6c9ffc

File tree

10 files changed (+1443, -0 lines changed)

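As context for the diffs below, here is a minimal sketch (not part of this commit's notebooks) of how the training script in this diff might be launched with SMDDP enabled on SageMaker. The entry point, source directory, framework/Python versions, instance count, and hyperparameter values are illustrative assumptions; the distribution flag is the documented switch that routes collectives through SMDDP.

import sagemaker
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="train.py",          # assumed name of the training script shown below
    source_dir="scripts",            # assumed directory containing train.py and utils.py
    role=sagemaker.get_execution_role(),
    framework_version="2.0.1",       # assumed; pick an SMDDP-supported PyTorch version
    py_version="py310",
    instance_type="ml.p4d.24xlarge", # SMDDP targets p4d/p4de instance types
    instance_count=2,
    hyperparameters={
        "deepspeed_config": "ds_config.json",  # assumed name of the ZeRO-3 JSON below
        "batch_size": 1,
        "max_steps": 10,
    },
    # Enables the SageMaker distributed data parallel (SMDDP) collectives.
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
)
estimator.fit()

If SMDDP is not present in the container, the training script below falls back to NCCL, so the same entry point can also be smoke-tested outside SageMaker.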
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
{
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": 5e8,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": false
    },
    "gradient_accumulation_steps": 1,
    "steps_per_print": 2000,
    "train_micro_batch_size_per_gpu": 4,
    "wall_clock_breakdown": false
}
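A quick worked check of this config (not part of the commit): DeepSpeed requires the effective train_batch_size to equal train_micro_batch_size_per_gpu x gradient_accumulation_steps x number of data-parallel ranks, so the node count below is an assumption used only to illustrate the arithmetic.

# Global batch size implied by the config above; the two-node p4d layout
# (8 GPUs per node) is an assumption for illustration.
micro_batch_per_gpu = 4      # "train_micro_batch_size_per_gpu"
grad_accum_steps = 1         # "gradient_accumulation_steps"
world_size = 2 * 8           # assumed: 2 x p4d.24xlarge nodes

global_batch = micro_batch_per_gpu * grad_accum_steps * world_size
print(global_batch)          # 64 sequences per optimizer step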
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
transformers==4.31
datasets
accelerate>=0.21
bitsandbytes
peft
deepspeed==0.9.2
Pydantic==1.10.13
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
Pydantic==1.10.13
deepspeed==0.9.2
Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
import argparse
import math
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, get_scheduler, SchedulerType
from datasets import load_from_disk
import torch
import torch.distributed as dist

from utils import create_dataloaders, StubDataset
import functools
import deepspeed

# Prefer the SageMaker distributed data parallel (SMDDP) backend when the
# library is available; otherwise fall back to standard NCCL.
try:
    backend = "smddp"
    import smdistributed.dataparallel.torch.torch_smddp
except ModuleNotFoundError:
    backend = "nccl"
    print("Warning: SMDDP not found on this image, falling back to NCCL!")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Llama-2-7b-chat-hf",
        help="Model id to use for training.",
    )
    parser.add_argument("--epochs", type=int, default=2, help="Number of epochs to train for.")
    parser.add_argument("--max_steps", type=int, default=None, help="Maximum number of training steps to run.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size to use for training.",
    )
    parser.add_argument("--lr", type=float, default=3e-5, help="Learning rate to use for training.")
    parser.add_argument("--optimizer", type=str, default="adamw_hf", help="Optimizer to use for training.")
    parser.add_argument("--seed", type=int, default=42, help="Seed to use for training.")
    parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.")

    parser.add_argument(
        "--gradient_checkpointing",
        type=bool,
        default=True,
        help="Whether to use gradient checkpointing to save memory.",
    )
    parser.add_argument(
        "--bf16",
        type=bool,
        default=True if torch.cuda.get_device_capability()[0] == 8 else False,
        help="Whether to use bf16.",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--deepspeed_config", type=str, help="Path to the DeepSpeed config JSON.")

    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--model_dir", type=str, default="/opt/ml/model")
    parser.add_argument("--cache_dir", type=str, default=None)
    args = parser.parse_known_args()
    return args

def training_function(args):
    # smddp example specifically tailored for p4d(e) instance types (8 GPUs per node).
    local_rank = dist.get_rank() % 8
    seed = args.seed
    set_seed(seed)
    torch.cuda.set_device(local_rank)

    dataset = {
        'train': StubDataset(),
        'validation': StubDataset()
    }

    dtype = torch.bfloat16

    from transformers import LlamaConfig
    configuration = LlamaConfig(use_cache=False)
    from transformers.models.llama import LlamaForCausalLM
    # deepspeed.zero.Init partitions parameters across ranks at construction time (ZeRO stage 3).
    with deepspeed.zero.Init(dtype=dtype, enabled=True):
        model = AutoModelForCausalLM.from_config(configuration)
    model.gradient_checkpointing_enable()

    train_dataset = dataset["train"]
    eval_dataset = dataset["validation"]
    train_dataloader, eval_dataloader = create_dataloaders(
        train_dataset, eval_dataset, dist.get_rank(), dist.get_world_size(),
        seed, args.batch_size, args.batch_size)

    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
    optimizer_grouped_parameters = [{
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    }, {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    }]

    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if dist.get_rank() == 0:
        print(f"Number of update steps per epoch {num_update_steps_per_epoch}")
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    model, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        model_parameters=model.parameters(),
        config=args.deepspeed_config
    )
    device = torch.device(f"cuda:{local_rank}")
    for epoch in range(args.num_train_epochs):
        model.train()
        total_steps = 0
        ds_loss = torch.zeros(2).to(local_rank)

        for batch_idx, batch in enumerate(train_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            output = model(**batch)
            if dist.get_rank() == 0: print(f"Processing training batch {batch_idx}")
            loss = output["loss"]
            # Run backward and the parameter update through the DeepSpeed engine so
            # ZeRO-3 gradient partitioning is handled correctly; engine.step() also
            # zeroes the gradients.
            model.backward(loss)
            ds_loss[0] += loss.item()
            ds_loss[1] += len(batch["input_ids"])
            model.step()
            lr_scheduler.step()
            total_steps += 1
            if args.max_steps is not None and total_steps > args.max_steps:
                break

        torch.distributed.all_reduce(ds_loss, op=torch.distributed.ReduceOp.SUM)
        train_loss = ds_loss[0] / ds_loss[1]
        train_ppl = torch.exp(train_loss)

        if dist.get_rank() == 0:
            print(f"******{epoch=}: {train_ppl=} {train_loss=}******")

        model.eval()
        eval_loss = 0
        ds_eval_loss = torch.zeros(2).to(local_rank)
        for steps, batch in enumerate(eval_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}

            if dist.get_rank() == 0: print(f"Performing validation on batch {steps}")
            with torch.no_grad():
                outputs = model(**batch)
            loss = outputs["loss"]
            ds_eval_loss[0] += loss.item()
            ds_eval_loss[1] += len(batch["input_ids"])
            if args.max_steps is not None and steps > args.max_steps:
                break

        torch.distributed.all_reduce(ds_eval_loss, op=torch.distributed.ReduceOp.SUM)
        eval_loss = ds_eval_loss[0] / ds_eval_loss[1]
        eval_ppl = torch.exp(eval_loss)

        if dist.get_rank() == 0:
            print(f"*******{epoch=}: {eval_ppl=} {eval_loss=}*******")

        if args.max_steps is not None and total_steps > args.max_steps:
            break

    if dist.get_rank() == 0:
        print("Training done!")
    dist.barrier()


def main():
    # init_distributed uses SMDDP when the import at the top succeeded, NCCL otherwise.
    deepspeed.init_distributed(dist_backend=backend)

    args, _ = parse_args()
    training_function(args)


if __name__ == "__main__":
    main()
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from transformers import default_data_collator

from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler


# dummy dataset for this example
class StubDataset(Dataset):
    def __len__(self):
        return dist.get_world_size() * 2

    def __getitem__(self, index):
        block_size = 4096
        return {
            'input_ids': torch.randint(1, 31580, (block_size,)),
            'attention_mask': torch.randint(0, 2, (block_size,)),
            'labels': torch.randint(1, 31579, (block_size,))
        }


def create_dataloaders(train_dataset, eval_dataset, rank, world_size, seed,
                       train_batch_size, eval_batch_size):
    train_sampler = torch.utils.data.DistributedSampler(
        train_dataset, shuffle=True, seed=seed, rank=rank, num_replicas=world_size,
        drop_last=True,)
    eval_sampler = torch.utils.data.DistributedSampler(
        eval_dataset, shuffle=True, seed=seed, rank=rank, num_replicas=world_size,
        drop_last=True,)

    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, collate_fn=default_data_collator,
        batch_size=train_batch_size, pin_memory=True, drop_last=True)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, collate_fn=default_data_collator,
        batch_size=eval_batch_size, pin_memory=True, drop_last=True)
    return train_dataloader, eval_dataloader
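A small single-process smoke test (not part of the commit) can confirm what these helpers produce without GPUs: it initializes a one-rank gloo process group so dist.get_world_size() works, then inspects one batch. The master address/port values are arbitrary local defaults.

# Single-process smoke test for StubDataset / create_dataloaders (assumed usage).
import os
import torch.distributed as dist
from utils import create_dataloaders, StubDataset

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

train_dl, eval_dl = create_dataloaders(
    StubDataset(), StubDataset(), rank=0, world_size=1, seed=42,
    train_batch_size=1, eval_batch_size=1)
batch = next(iter(train_dl))
print({k: tuple(v.shape) for k, v in batch.items()})  # each tensor is (1, 4096)
dist.destroy_process_group()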
