
Commit 7dad56f

mreso authored and Jack-Khuu committed
Integrate distributed inference with chat/server (#1381)
* Integrate distributed inference without introducing abstraction
* Cleanup old distributed inference integration
* Read distribution from model_config
* Declare distribution_path if args.model is not given
* Address some nits from PR review
* Added comment on model size all reduce + type hint
* Apply suggestions from code review (Co-authored-by: Jack-Khuu <[email protected]>)
* Make sure speculative decoding is disabled for pp > 1 and remark this in the comments as well
* Refactor conditions in pp
* Rename and alter signature of setup_env to reflect that it also runs the target
* Rename setup_env in server + fix condition
* Update generate.py
* Add default value to add_generation_prompt to preserve backward compatibility

Co-authored-by: Jack-Khuu <[email protected]>
1 parent 582e558 commit 7dad56f
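The commit message notes that speculative decoding is disabled whenever pipeline parallelism is active (pp > 1). The actual check lives in generate.py, which is not shown in this diff; below is only a minimal sketch of what such a guard could look like, with `builder_args` and `speculative_builder_args` as assumed names rather than the real code:

# Hypothetical guard (not the actual generate.py code): reject speculative
# decoding when more than one pipeline stage is configured.
def check_speculative_compatible(builder_args, speculative_builder_args):
    if builder_args.pp > 1 and speculative_builder_args is not None:
        raise ValueError(
            "Speculative decoding is not supported with pipeline parallelism "
            "(pp > 1); remove the draft model or run with pp=1."
        )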

8 files changed: +596 / -956 lines


torchchat/cli/builder.py

Lines changed: 109 additions & 2 deletions
@@ -14,10 +14,17 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
-import torch.nn as nn
+import torch.distributed as dist

-from torchchat.model import Model, ModelArgs, ModelType
+from torchchat.distributed.utils import (
+    Color as color,
+    CUDATrackTime,
+    init_distributed,
+    GPUMemoryMonitor,
+)
+from torchchat.distributed.logging_utils import SingletonLogger

+from torchchat.model import Model, ModelArgs, ModelType, Transformer, TransformerArgs
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -28,6 +35,7 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model

+
 from torchtune.models.convert_weights import meta_to_tune

 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
@@ -56,6 +64,7 @@ class BuilderArgs:
     pp: int = 1
     tp: int = 1
     chpt_from: str = "hf"
+    distribution_path: Optional[str] = None
     is_chat_model: bool = False
     prefill_possible: bool = False
     dynamic_shapes: bool = False
@@ -107,6 +116,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":

         checkpoint_path = args.checkpoint_path
         params_table = args.params_table
+        distribution_path = None
         if args.model:  # Using a named, well-known model
             model_config = resolve_model_config(args.model)

@@ -121,6 +131,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
                 model_config.transformer_params_key or model_config.name.split("/")[-1]
             )

+            distribution_path = model_config.distribution_path
+
         dso_path = getattr(args, "dso_path", None)
         pte_path = getattr(args, "pte_path", None)
         aoti_package_path = getattr(args, "aoti_package_path", None)
@@ -186,6 +198,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             pp=pp,
             tp=tp,
             chpt_from=chpt_from,
+            distribution_path=distribution_path,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
@@ -601,6 +614,100 @@ def do_nothing(max_batch_size, max_seq_length):
             model = PTEModel(config, builder_args.pte_path)
         except Exception:
             raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}")
+    elif builder_args.distributed:
+        pp_degree = builder_args.pp
+        tp_degree = builder_args.tp
+
+        init_distributed()
+        rank = dist.get_rank()
+        torch.cuda.set_device(rank % torch.cuda.device_count())
+
+        logger = SingletonLogger.get_logger()
+
+        gpu_memory_monitor = GPUMemoryMonitor("cuda")
+        logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")
+
+        # Model-level config
+        if builder_args.params_table:
+            model_config = ModelArgs.from_table(builder_args.params_table)
+        else:
+            raise NotImplementedError()
+        # Transformer-level config
+        config = TransformerArgs.from_params(model_config.transformer_args["text"])
+        logger.info(f"Transformer Config: {config}")
+
+        # TODO: Move into head of file after solving circular import
+        from torchchat.distributed.checkpoint_utils import (
+            load_model_weights,
+        )
+
+        # Validate pipeline degree
+        assert config.n_layers % pp_degree == 0
+
+        # Create device mesh
+        device_mesh = dist.init_device_mesh(
+            "cuda",
+            (pp_degree, tp_degree),
+            mesh_dim_names=("pp", "tp")
+        )
+        tp_mesh = device_mesh["tp"]
+        pp_mesh = device_mesh["pp"]
+        logger.info(f"Created device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}")
+
+        pp_rank = pp_mesh.get_local_rank()
+        logger.info(f"{pp_degree=}, {tp_degree=}")
+
+        # Assuming same number of GPUs per node
+        device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
+
+        # Fill in PP configs
+        config.stage_idx = pp_rank
+        config.n_stages = pp_degree
+
+        with torch.device("meta"):
+            # TODO: we should create model instead of Transformer
+            model = Transformer(config)
+
+        # Distribute model on TP mesh
+        # (Surprisingly, this works even though model is on meta device and mesh is of
+        # cuda devices)
+        model.distribute(tp_mesh)
+        if rank == 0:
+            logger.info(f"Model: {model}")
+
+        # Load weights
+        logger.info(f"Loading weights for {pp_rank=} on {device=}")
+        with CUDATrackTime() as timer:
+            load_model_weights(model, builder_args.distribution_path, device, config, builder_args.chpt_from)
+
+        logger.info(
+            f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}"
+        )
+
+        # Setup KV caches (after model distribution)
+        # The number of cache lanes is the same as the maximum number of
+        # micro-batches that can be "in flight" in parallel -- imagine each
+        # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces.
+        # When decoding is done for certain micro-batches, we can reuse the KV cache
+        # lanes.
+        # TODO: bump up the lane count
+        pipeline_lanes = 1
+        seqlen_prefill = 1024
+        with device:
+            model.setup_caches(1, seqlen_prefill, cache_lanes=pipeline_lanes)
+
+        # info on stage size and params
+        # stage_size = get_module_size(model)
+        # stage_size_formatted = bytes_to_readable(stage_size)
+        # stage_num_params = get_num_params(model)
+        # logger.info(
+        #     f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}"
+        # )
+        model.eval()
+
+        model.text_transformer_args = None
+        model.config.model_type = model_config.model_type
+        model.device_mesh = device_mesh
     else:
         with measure_time("Time to load model: {time:.02f} seconds"):
             model = _load_model(builder_args)
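For context on the distributed branch above: the 2-D device mesh it builds uses the stock PyTorch `init_device_mesh` API. The following standalone sketch shows the same pattern under assumed degrees (2 pipeline stages x 2 tensor-parallel ranks on a 4-GPU node); it is illustrative only and not part of the commit:

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Assumed for illustration: launched with `torchrun --nproc-per-node 4 ...`
pp_degree, tp_degree = 2, 2

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# 2-D mesh: the "pp" dimension groups ranks into pipeline stages,
# the "tp" dimension groups the ranks that shard one stage's weights.
device_mesh = init_device_mesh(
    "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp")
)
tp_mesh = device_mesh["tp"]            # used to distribute (shard) the stage's modules
pp_mesh = device_mesh["pp"]            # used to place pipeline stages
stage_idx = pp_mesh.get_local_rank()   # which contiguous slice of layers this rank owns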

torchchat/distributed/checkpoint_utils.py

Lines changed: 32 additions & 0 deletions
@@ -17,6 +17,7 @@
 from torch.distributed._tensor import DTensor
 from torchchat.distributed.dtensor_utils import convert_to_dtensor
 from torchchat.cli.builder import BuilderArgs, _load_checkpoint
+from torchchat.model import ModelArgs


 _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
@@ -450,3 +451,34 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
     # Fill state dict into stage module
     stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
     logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
+
+
+def load_model_weights(
+    stage_module: torch.nn.Module,
+    distribution: str,
+    device: torch.device,
+    model_config: ModelArgs,
+    chpt_from: str,
+):
+    """Load the weights from the safetensor file(s) into the model stage.
+    Model config is needed b/c we permute wq and wk weights based on attn heads.
+
+    Args:
+        stage_module (torch.nn.Module): The model stage to load the weights into.
+        distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
+        device (torch.device): The device to load the weights onto.
+        model_config (ModelArgs): The model config.
+        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
+    """
+    if chpt_from == "hf":
+        # This format stands for: index file + multiple binary files
+        load_weights_from_hf_format(stage_module, distribution, device, model_config)
+    elif chpt_from == "torchchat":
+        # This format stands for:
+        #   single binary file, OR
+        #   multiple binary files without index files.
+        load_weights_from_torchchat_format(
+            stage_module, distribution, device, model_config
+        )
+    else:
+        raise ValueError(f"Unknown checkpoint format: {chpt_from}")
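A brief usage sketch of the new dispatcher, assuming a pipeline-stage module and model config built as in builder.py above; the distribution name and chpt_from value are example inputs only:

# Minimal sketch: `stage_module` and `model_config` are assumed to exist,
# e.g. the meta-initialized, TP-distributed Transformer stage from builder.py.
device = torch.device("cuda:0")

load_model_weights(
    stage_module=stage_module,
    distribution="meta-llama/Meta-Llama-3-8B-Instruct",  # example distribution name
    device=device,
    model_config=model_config,
    chpt_from="hf",  # "hf" = index file + sharded binaries; "torchchat" = binaries without an index
)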
