
Integrate distributed inference into torchchat cli #1327


Merged
merged 25 commits on Oct 25, 2024

Commits (25)
481e00b
add pp_dim, distributed, num_gpus, num_nodes as cmd line args
lessw2020 Oct 15, 2024
2f1787c
add tp_dim
lessw2020 Oct 15, 2024
fd3ddcd
add elastic_launch
lessw2020 Oct 17, 2024
bf79697
working, can now launch from cli
lessw2020 Oct 17, 2024
26a9455
Remove numpy < 2.0 pin to align with pytorch (#1301)
larryliu0820 Oct 16, 2024
5f0ca00
Update torchtune pin to 0.4.0-dev20241010 (#1300)
vmpuri Oct 16, 2024
598caf5
Unbreak gguf util CI job by fixing numpy version (#1307)
larryliu0820 Oct 16, 2024
6fe1646
Remove apparently-unused import torchvision in model.py (#1305)
swolchok Oct 17, 2024
78debce
remove global var for tokenizer type + patch tokenizer to allow list …
mreso Oct 17, 2024
2eefb13
make pp tp visible in interface
mreso Oct 17, 2024
e8bb076
Add llama 3.1 to dist_run.py
mreso Oct 17, 2024
1faa052
[WIP] Move dist inf into its own generator
mreso Oct 18, 2024
11f29fc
Add initial generator interface to dist inference
mreso Oct 21, 2024
adcf232
Added generate method and placeholder scheduler
mreso Oct 23, 2024
3836928
use prompt parameter for dist generation
mreso Oct 23, 2024
3f6fa2d
Enforce tp>=2
mreso Oct 24, 2024
fd9f704
Build tokenizer from TokenizerArgs
mreso Oct 24, 2024
e8f7c98
Disable torchchat format + constrain possible models for distributed
mreso Oct 24, 2024
9ec55fb
disable calling dist_run.py directly for now
mreso Oct 24, 2024
80f8138
Restore original dist_run.py for now
mreso Oct 24, 2024
abf0679
Merge branch 'main' into refactor/dist_run
mreso Oct 24, 2024
99606ab
disable _maybe_parallelize_model again
mreso Oct 24, 2024
4b8cdcb
Reenable arg.model_name in dist_run.py
mreso Oct 24, 2024
b8f88fd
Use singleton logger instead of print in generate
mreso Oct 24, 2024
2d37d27
Address PR comments; try/except in launch_dist_inference; added comments
mreso Oct 24, 2024
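
Several of the commits above ("add elastic_launch", "[WIP] Move dist inf into its own generator", "try/except in launch_dist_inference") route the CLI through torchelastic so that a single torchchat invocation can spawn one process per rank. Below is a minimal, illustrative sketch of that launch pattern using PyTorch's public elastic_launch API; the worker body, parallel degrees, and rendezvous endpoint are placeholders and not the actual torchchat code.

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def _inference_worker(prompt: str) -> None:
    # torchelastic sets RANK / LOCAL_RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT
    # for every process it spawns, so env:// initialization just works.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    # ... build the pipeline/tensor-parallel stages and generate for `prompt` ...
    dist.destroy_process_group()


if __name__ == "__main__":
    pp, tp = 2, 2  # pipeline- and tensor-parallel degrees (illustrative values)
    launcher = elastic_launch(
        LaunchConfig(
            min_nodes=1,
            max_nodes=1,
            nproc_per_node=pp * tp,           # one process per GPU
            rdzv_backend="c10d",
            rdzv_endpoint="localhost:29500",  # placeholder rendezvous endpoint
            max_restarts=0,
        ),
        _inference_worker,
    )
    launcher("Hello, torchchat")  # positional args are forwarded to every rank
```

This is the general idea behind launching distributed inference from within the CLI process itself, rather than requiring a separate torchrun command.
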
10 changes: 7 additions & 3 deletions dist_run.py
@@ -20,14 +20,14 @@
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs

from torchchat.distributed.logging_utils import SingletonLogger

# TODO - these are not distributed specific, consider moving to new package
from torchchat.distributed.checkpoint_utils import (
get_hf_config_file,
load_weights_from_hf_format,
load_weights_from_torchchat_format,
)

from torchchat.distributed.logging_utils import SingletonLogger
from torchchat.distributed.utils import (
bytes_to_readable,
Color as color,
@@ -153,7 +153,9 @@ def _load_model_weights(
# This format stands for:
# single binary file, OR
# multiple binary files without index files.
load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
load_weights_from_torchchat_format(
stage_module, distribution, device, model_config
)
else:
raise ValueError(f"Unknown checkpoint format: {chpt_from}")

@@ -593,9 +595,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
parser.add_argument(
"model_name",
type=str,
default="llama3",
help="Name of the model to load",
choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
)

parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
parser.add_argument(
"--ntokens",
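
For context on the hunk above: _load_model_weights dispatches on the checkpoint format selected by the new --chpt-from option. A condensed sketch of that dispatch follows; the function signature and the "hf" branch are assumptions based on the imports shown earlier in the diff, since only the "torchchat" branch and the error case are visible here.

```python
from torchchat.distributed.checkpoint_utils import (
    load_weights_from_hf_format,
    load_weights_from_torchchat_format,
)


def _load_model_weights(stage_module, distribution, device, model_config, chpt_from):
    if chpt_from == "hf":
        # HuggingFace-style checkpoints (e.g. safetensors shards with an index file).
        load_weights_from_hf_format(stage_module, distribution, device, model_config)
    elif chpt_from == "torchchat":
        # This format stands for: single binary file, OR
        # multiple binary files without index files.
        load_weights_from_torchchat_format(
            stage_module, distribution, device, model_config
        )
    else:
        raise ValueError(f"Unknown checkpoint format: {chpt_from}")
```
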
58 changes: 36 additions & 22 deletions torchchat/cli/builder.py
@@ -16,20 +16,14 @@
import torch._inductor.config
import torch.nn as nn

from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune

from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.utils.distributed import get_free_port

from torchtune.models.convert_weights import meta_to_tune

from torchtune.training import set_default_dtype
from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama

from torchchat.model import Model, ModelArgs, ModelType

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchchat.model_config.model_config import resolve_model_config
from torchchat.utils.build_utils import (
device_sync,
@@ -40,6 +34,14 @@
from torchchat.utils.measure_time import measure_time
from torchchat.utils.quantize import quantize_model

from torchtune.models.convert_weights import meta_to_tune

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune

from torchtune.training import set_default_dtype


@dataclass
class BuilderArgs:
@@ -55,7 +57,10 @@ class BuilderArgs:
device: Optional[str] = None
precision: torch.dtype = torch.float32
setup_caches: bool = False
use_distributed: bool = False
distributed: bool = False
pp: int = 1
tp: int = 1
chpt_from: str = "hf"
is_chat_model: bool = False
prefill_possible: bool = False
dynamic_shapes: bool = False
@@ -87,7 +92,9 @@ def __post_init__(self):
]
for param, param_msg in ignored_params:
if param:
print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified")
print(
f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified"
)
else:
self.prefill_possible = True

@@ -153,7 +160,11 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
dtype = torch.float16
else:
dtype = name_to_dtype(args.dtype, args.device)

# distributed args
distributed = getattr(args, "distributed", False)
pp = getattr(args, "pp", 1)
tp = getattr(args, "tp", 1)
chpt_from = getattr(args, "chpt_from", "hf")
return cls(
checkpoint_dir=checkpoint_dir,
checkpoint_path=checkpoint_path,
@@ -167,7 +178,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
device=args.device,
precision=dtype,
setup_caches=(output_dso_path or output_pte_path),
use_distributed=args.distributed,
distributed=distributed,
pp=pp,
tp=tp,
chpt_from=chpt_from,
is_chat_model=is_chat_model,
dynamic_shapes=getattr(args, "dynamic_shapes", False),
max_seq_length=getattr(args, "max_seq_length", None),
@@ -397,10 +411,10 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
# does not host any actual values, need to reinitialize them in the actual
# device. Only do those buffer initialization, without initializing the entire
# model.
decoder_config = model.config.transformer_args['decoder']
head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
max_seq_len = decoder_config['max_seq_len']
rope_base = decoder_config['rope_base']
decoder_config = model.config.transformer_args["decoder"]
head_dim = decoder_config["embed_dim"] // decoder_config["num_heads"]
max_seq_len = decoder_config["max_seq_len"]
rope_base = decoder_config["rope_base"]
for submodule in model.modules():
if isinstance(submodule, Llama3ScaledRoPE):
submodule.__init__(head_dim, max_seq_len, rope_base)
@@ -476,18 +490,19 @@ def _maybe_parallelize_model(


def _load_model(builder_args: BuilderArgs) -> Model:
world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
# world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
Contributor review comment: this code is now effectively dead and we should just remove it, but in a later PR.

if builder_args.gguf_path:
model = _load_model_gguf(builder_args)
elif builder_args.use_distributed:
model = _init_model_on_meta_device(builder_args)
# elif builder_args.use_distributed:
# model = _init_model_on_meta_device(builder_args)
else:
model = _load_model_default(builder_args)
model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
# model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)

model = model.to(device=builder_args.device, dtype=builder_args.precision)
return model.eval()


def _initialize_model(
builder_args: BuilderArgs,
quantize,
@@ -496,7 +511,6 @@ def _initialize_model(
support_tensor_subclass: bool = True,
) -> Model:
print("Loading model...")

if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
print("Setting gguf_kwargs for generate.")
is_dso = builder_args.dso_path is not None
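
The builder.py changes above replace the old use_distributed flag with four BuilderArgs fields that from_args reads via getattr with defaults. A condensed sketch of just that slice of the dataclass (the real class carries many more fields):

```python
import argparse
from dataclasses import dataclass


@dataclass
class BuilderArgs:
    distributed: bool = False  # enable distributed inference
    pp: int = 1                # pipeline-parallel degree
    tp: int = 1                # tensor-parallel degree
    chpt_from: str = "hf"      # checkpoint format: "hf" or "torchchat"

    @classmethod
    def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
        # getattr with defaults keeps entry points working even when the
        # distributed flags were never added to their parser.
        return cls(
            distributed=getattr(args, "distributed", False),
            pp=getattr(args, "pp", 1),
            tp=getattr(args, "tp", 1),
            chpt_from=getattr(args, "chpt_from", "hf"),
        )
```

Using getattr defaults rather than direct attribute access means subcommands that do not register the distributed arguments can still construct a BuilderArgs.
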
28 changes: 24 additions & 4 deletions torchchat/cli/cli.py
@@ -399,8 +399,7 @@ def _add_distributed_args(parser) -> None:
parser.add_argument(
"--distributed",
action="store_true",
help=argparse.SUPPRESS,
# "Whether to enable distributed inference",
help="Whether to enable distributed inference",
)
parser.add_argument(
"--dcp-dir",
@@ -409,6 +408,27 @@
help=argparse.SUPPRESS,
# "Use the specified model checkpoint directory",
)
parser.add_argument(
"--pp",
"--pipeline-parallel",
type=int,
default=1,
help="Pipeline parallel degree",
)
parser.add_argument(
"--tp",
"--tensor-parallel",
type=int,
default=2,
help="Tensor parallel degree",
)
parser.add_argument(
"--chpt-from",
type=str,
default="hf", # TODO: change to torchchat once we support it well
help="Checkpoint format to load from",
choices=["hf", "torchchat"],
)


# Add CLI Args related to custom model inputs
@@ -425,13 +445,13 @@ def _add_custom_model_args(parser) -> None:
"--params-path",
type=Path,
default=None,
help= "Use the specified parameter file, instead of one specified under torchchat.model_params",
help="Use the specified parameter file, instead of one specified under torchchat.model_params",
)
parser.add_argument(
"--tokenizer-path",
type=Path,
default=None,
help= "Use the specified model tokenizer file, instead of the one downloaded from HuggingFace",
help="Use the specified model tokenizer file, instead of the one downloaded from HuggingFace",
)


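
Taken together, _add_distributed_args now exposes --distributed, --pp/--pipeline-parallel, --tp/--tensor-parallel, and --chpt-from. The self-contained sketch below mirrors those definitions to show how the flags parse; the stand-alone parser is illustrative, since in torchchat these are added to the shared CLI parser.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--distributed",
    action="store_true",
    help="Whether to enable distributed inference",
)
parser.add_argument(
    "--pp", "--pipeline-parallel",
    type=int,
    default=1,
    help="Pipeline parallel degree",
)
parser.add_argument(
    "--tp", "--tensor-parallel",
    type=int,
    default=2,
    help="Tensor parallel degree",
)
parser.add_argument(
    "--chpt-from",
    type=str,
    default="hf",
    choices=["hf", "torchchat"],
    help="Checkpoint format to load from",
)

# e.g. the flag portion of `torchchat generate llama3.1 --distributed --pp 2 --tp 2`
args = parser.parse_args(["--distributed", "--pp", "2", "--tp", "2"])
assert (args.distributed, args.pp, args.tp, args.chpt_from) == (True, 2, 2, "hf")
world_size = args.pp * args.tp  # ranks the launcher typically needs to spawn
```

The --tp default of 2 lines up with the "Enforce tp>=2" commit, and the total number of processes required is typically the product of the pipeline- and tensor-parallel degrees.
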