Skip to content

Commit 95dd7e4

Browse files
committed
Add Torchrun script and enable distributed for that script
1 parent 1e0e48e commit 95dd7e4

File tree

4 files changed

+39
-1
lines changed

4 files changed

+39
-1
lines changed

build/builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def from_args(cls, args): # -> BuilderArgs:
142142
device=args.device,
143143
precision=dtype,
144144
setup_caches=(args.output_dso_path or args.output_pte_path),
145-
use_distributed=False,
145+
use_distributed=args.distributed,
146146
is_chat_model=is_chat_model,
147147
)
148148

@@ -347,6 +347,7 @@ def _load_model(builder_args, only_config=False):
347347
else:
348348
model = _load_model_default(builder_args)
349349

350+
# TODO: ongoing work to support loading model from checkpoint
350351
if builder_args.use_distributed:
351352
# init distributed
352353
world_size = int(os.environ["WORLD_SIZE"])

cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ def add_arguments_for_verb(parser, verb: str):
5656
action="store_true",
5757
help="Whether to start an interactive chat session",
5858
)
59+
parser.add_argument(
60+
"--distributed",
61+
action="store_true",
62+
help="Whether to enable distributed inference",
63+
)
5964
parser.add_argument(
6065
"--gui",
6166
action="store_true",

config/model_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
6+
67
import json
78
from dataclasses import dataclass, field
89
from enum import Enum

distributed/run_dist_inference.sh

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/usr/bin/bash
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
set -ex
9+
10+
# libUV is a scalable backend for TCPStore which is used in processGroup
11+
# rendezvous. This is the recommended backend for distributed training.
12+
export USE_LIBUV=1
13+
14+
# use envs as local overrides for convenience
15+
# e.g.
16+
# LOG_RANK=0,1 NGPU=4 ./run_dist_inference.sh
17+
18+
NGPU=${NGPU:-"8"}
19+
20+
# TODO: We need to decide how to log for inference.
21+
# by default log just rank 0 output,
22+
LOG_RANK=${LOG_RANK:-0}
23+
24+
overrides=""
25+
if [ $# -ne 0 ]; then
26+
overrides="$*"
27+
fi
28+
29+
torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
30+
--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
31+
torchchat.py chat llama3 --distributed $overrides

0 commit comments

Comments
 (0)