import argparse
import math
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed, get_scheduler, SchedulerType
from datasets import load_from_disk
import torch
import torch.distributed as dist

from utils import create_dataloaders, StubDataset
import functools
import deepspeed
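
# Prefer SageMaker's SMDDP collective backend when it is present on the image;
# otherwise fall back to plain NCCL so the script still runs elsewhere.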
try:
    backend = "smddp"
    import smdistributed.dataparallel.torch.torch_smddp
except ModuleNotFoundError:
    backend = "nccl"
    print("Warning: SMDDP not found on this image, falling back to NCCL!")

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Llama-2-7b-chat-hf",
        help="Model id to use for training.",
    )
    parser.add_argument("--epochs", type=int, default=2, help="Number of epochs to train for.")
    parser.add_argument("--max_steps", type=int, default=None, help="Maximum number of training steps to run before stopping early.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size to use for training.",
    )
    parser.add_argument("--lr", type=float, default=3e-5, help="Learning rate to use for training.")
    parser.add_argument("--optimizer", type=str, default="adamw_hf", help="Optimizer to use for training.")
    parser.add_argument("--seed", type=int, default=42, help="Seed to use for training.")
    parser.add_argument("--num_train_epochs", type=int, default=1, help="Total number of training epochs to perform.")

    parser.add_argument(
        "--gradient_checkpointing",
        type=bool,
        default=True,
        help="Whether to use gradient checkpointing to save memory.",
    )
    parser.add_argument(
        "--bf16",
        type=bool,
        # bf16 requires compute capability 8.0 or newer (A100/Ampere onwards).
        default=torch.cuda.get_device_capability()[0] >= 8,
        help="Whether to use bf16.",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--deepspeed_config", type=str, help="Path to the DeepSpeed config JSON."
    )

    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay to use.")
    parser.add_argument("--model_dir", type=str, default="/opt/ml/model")
    parser.add_argument("--cache_dir", type=str, default=None)
    args = parser.parse_known_args()
    return args

def training_function(args):
    # SMDDP example specifically tailored for p4d(e) instance types, which expose
    # 8 GPUs per node, so the local rank is the global rank modulo 8.
    local_rank = dist.get_rank() % 8
    seed = args.seed
    set_seed(seed)
    torch.cuda.set_device(local_rank)

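    # StubDataset comes from the accompanying utils module; it is assumed here to yield
    # synthetic tokenized batches so the loop can exercise training throughput without
    # loading a real dataset from disk.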
    dataset = {
        'train': StubDataset(),
        'validation': StubDataset()
    }

    dtype = torch.bfloat16

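    # Build the model from a config (random weights) rather than from_pretrained; the
    # default LlamaConfig corresponds to the 7B-sized architecture. use_cache=False
    # because the generation KV cache is incompatible with gradient checkpointing during
    # training. deepspeed.zero.Init allocates parameters in ZeRO-3 partitioned form as
    # they are created, so the full model never has to fit on a single GPU.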
    from transformers import LlamaConfig
    configuration = LlamaConfig(use_cache=False)
    from transformers.models.llama import LlamaForCausalLM
    with deepspeed.zero.Init(dtype=dtype, enabled=True):
        model = AutoModelForCausalLM.from_config(configuration)
    model.gradient_checkpointing_enable()

    train_dataset = dataset["train"]
    eval_dataset = dataset["validation"]
    train_dataloader, eval_dataloader = create_dataloaders(
        train_dataset, eval_dataset, dist.get_rank(), dist.get_world_size(),
        seed, args.batch_size, args.batch_size)

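    # Standard transformer fine-tuning recipe: biases and LayerNorm weights are excluded
    # from weight decay, everything else gets args.weight_decay.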
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
    optimizer_grouped_parameters = [{
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    }, {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    }]

    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if dist.get_rank() == 0:
        print(f"Number of update steps per epoch {num_update_steps_per_epoch}")
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

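    # Wrap model and optimizer in a DeepSpeed engine. ZeRO stage, micro-batch size,
    # bf16 settings, etc. are read from the JSON file passed via --deepspeed_config;
    # the lr_scheduler stays outside the engine and is stepped manually below.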
    model, optimizer, _, _ = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        model_parameters=model.parameters(),
        config=args.deepspeed_config
    )
    device = torch.device(f"cuda:{local_rank}")
    for epoch in range(args.num_train_epochs):
        model.train()
        total_steps = 0
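        # ds_loss accumulates [summed loss, number of samples] on this rank; it is
        # all-reduced across ranks after the epoch to compute a global average loss.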
        ds_loss = torch.zeros(2).to(local_rank)

        for batch_idx, batch in enumerate(train_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}
            output = model(**batch)
            if dist.get_rank() == 0: print(f"Processing training batch {batch_idx}")
            loss = output["loss"]
            # Let the DeepSpeed engine drive backward and the optimizer update;
            # engine.step() also handles gradient accumulation and zeroing gradients.
            model.backward(loss)
            ds_loss[0] += loss.item()
            ds_loss[1] += len(batch["input_ids"])
            model.step()
            lr_scheduler.step()
            total_steps += 1
            if args.max_steps is not None and total_steps > args.max_steps:
                break

        torch.distributed.all_reduce(ds_loss, op=torch.distributed.ReduceOp.SUM)
        train_loss = ds_loss[0] / ds_loss[1]
        train_ppl = torch.exp(train_loss)

        if dist.get_rank() == 0:
            print(f"******{epoch=}: {train_ppl=} {train_loss=}******")

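        # Evaluation pass: same loss/sample accumulation and all-reduce as above,
        # but with gradients disabled.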
        model.eval()
        eval_loss = 0
        ds_eval_loss = torch.zeros(2).to(local_rank)
        for steps, batch in enumerate(eval_dataloader):
            batch = {k: v.to(device) for k, v in batch.items()}

            if dist.get_rank() == 0: print(f"Performing validation on batch {steps}")
            with torch.no_grad():
                outputs = model(**batch)
            loss = outputs["loss"]
            ds_eval_loss[0] += loss.item()
            ds_eval_loss[1] += len(batch["input_ids"])
            if args.max_steps is not None and steps > args.max_steps:
                break

        torch.distributed.all_reduce(ds_eval_loss, op=torch.distributed.ReduceOp.SUM)
        eval_loss = ds_eval_loss[0] / ds_eval_loss[1]
        eval_ppl = torch.exp(eval_loss)

        if dist.get_rank() == 0:
            print(f"*******{epoch=}: {eval_ppl=} {eval_loss=}*******")

        if args.max_steps is not None and total_steps > args.max_steps:
            break

    if dist.get_rank() == 0:
        print("Training done!")
    dist.barrier()

def main():
    deepspeed.init_distributed(dist_backend=backend)

    args, _ = parse_args()
    training_function(args)

if __name__ == "__main__":
    main()
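
# Example launch (illustrative sketch; file names and world size are assumptions, not
# part of the original script): save this file as train.py next to utils.py and a ZeRO
# config such as ds_config.json, then on a single 8-GPU node run:
#
#   deepspeed --num_gpus 8 train.py --deepspeed_config ds_config.json --batch_size 1 --max_steps 20
#
# Outside SageMaker the smdistributed import falls back to the NCCL backend automatically.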