
Commit 2d37d27

Address PR comments; try/except in launch_dist_inference; added comments

1 parent b8f88fd commit 2d37d27

File tree: 1 file changed, +26 -11 lines

torchchat/distributed/generate.py (26 additions & 11 deletions)
@@ -39,7 +39,7 @@ def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
 def _launch_distributed_inference(
     model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs
 ) -> tuple[List]:
-    # create programmatic elastic launch
+    # launch distributed inference worker, each worker gets a pipe to communicate with the main process
     logger.info("Launching distributed inference ...")

     num_processes_per_node = builder_args.pp * builder_args.tp
@@ -50,17 +50,25 @@ def _launch_distributed_inference(

     pipes = []
     procs = []
-    for rank in range(num_processes_per_node):
-        server_pipe, client_pipe = mp.Pipe(duplex=True)
-        pipes.append(server_pipe)
-        proc = mp.Process(
-            target=partial(_setup_env, num_processes_per_node, rank, main),
-            args=(model_name, builder_args, tokenizer_args, client_pipe),
-        )
-        proc.start()
+    try:
+        for rank in range(num_processes_per_node):
+            server_pipe, client_pipe = mp.Pipe(duplex=True)
+            pipes.append(server_pipe)
+            procs.append(
+                mp.Process(
+                    target=partial(_setup_env, num_processes_per_node, rank, main),
+                    args=(model_name, builder_args, tokenizer_args, client_pipe),
+                )
+            )
+            procs[-1].start()

-    for pipe in pipes:
-        response = pipe.recv()
+        for pipe in pipes:
+            assert pipe.recv() == "ready", "Starting the worker failed"
+    except Exception as e:
+        logger.error(f"Error during distributed inference: {str(e)}")
+        for p in procs:
+            p.kill()
+        raise e

     logger.info(
         f"Done launching distributed inference on {num_processes_per_node} GPUs."
@@ -105,11 +113,13 @@ def __init__(
         self.loop = loop

     def schedule_request(self, req: Request):
+        # add request to queue and create deque and async event for response
         self.req_to_states[req.request_id] = asyncio.Event()
         self.req_to_results[req.request_id] = deque()
         self.request_queue.put(req)

     def process_requests_loop(self):
+        # Continuously process requests (one at a time for now), results are routed into the requests deque
         while True:
             req = self.request_queue.get()
             if req == "stop":
@@ -127,6 +137,7 @@ def process_requests_loop(self):
                 running &= not outputs[0].is_finished

     async def wait_for_request(self, req: Request) -> Output:
+        # Wait for request to deliver result, uses event to trigger and reads from left side of deque
        is_finished = False
        while not is_finished:
            await self.req_to_states[req.request_id].wait()
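Taken together, schedule_request, process_requests_loop, and wait_for_request implement a hand-off between a plain scheduler thread and async callers: each request gets a result deque plus an asyncio.Event, the thread appends results and sets the event, and the waiter drains the deque from the left. A stripped-down sketch of that pattern is shown below (MiniScheduler and the string results are hypothetical, not the torchchat class).

import asyncio
import queue
import threading
from collections import deque


class MiniScheduler:
    # hypothetical stand-in for the scheduler in generate.py
    def __init__(self, loop: asyncio.AbstractEventLoop):
        self.loop = loop
        self.request_queue = queue.Queue()
        self.req_to_states = {}
        self.req_to_results = {}

    def schedule_request(self, req_id: str) -> None:
        # register an event + result deque, then hand the request to the thread
        self.req_to_states[req_id] = asyncio.Event()
        self.req_to_results[req_id] = deque()
        self.request_queue.put(req_id)

    def process_requests_loop(self) -> None:
        # runs on a plain thread; results are routed into the request's deque
        while True:
            req_id = self.request_queue.get()
            if req_id == "stop":
                break
            self.req_to_results[req_id].append(f"result for {req_id}")
            # asyncio.Event is not thread-safe, so set it via the event loop
            self.loop.call_soon_threadsafe(self.req_to_states[req_id].set)

    async def wait_for_request(self, req_id: str) -> str:
        # wait for the event, then read from the left side of the deque
        await self.req_to_states[req_id].wait()
        return self.req_to_results[req_id].popleft()


async def demo() -> None:
    sched = MiniScheduler(asyncio.get_running_loop())
    threading.Thread(target=sched.process_requests_loop, daemon=True).start()
    sched.schedule_request("req-0")
    print(await sched.wait_for_request("req-0"))
    sched.request_queue.put("stop")


asyncio.run(demo())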
@@ -138,6 +149,7 @@ async def wait_for_request(self, req: Request) -> Output:
         del self.req_to_results[req.request_id]

     def step(self) -> List[Output]:
+        # Make a prefill or decoding step and receive results
         responses = []
         # TODO: Implement a scheduler to handle the requests
         if len(self.in_flight_requests) > 0:
@@ -166,6 +178,7 @@ def step(self) -> List[Output]:
                 text, token_ids = v
                 outputs.append(
                     Output(
+                        # TODO: Look for tokenizer.eos_id as well
                         is_finished=self.current_step >= self.generator_args.max_new_tokens,
                         text=text,
                         token=token_ids,
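The new TODO notes that is_finished currently only checks the max_new_tokens budget; an eos-aware check would also stop once the model emits the end-of-sequence token. A hedged sketch of what that condition might look like (tokenizer.eos_id and the helper name are assumptions taken from the TODO, not the actual implementation):

def is_request_finished(current_step: int, max_new_tokens: int,
                        last_token: int, eos_id: int) -> bool:
    # finished when the step budget is exhausted or the model emitted EOS
    return current_step >= max_new_tokens or last_token == eos_id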
@@ -218,6 +231,7 @@ def __init__(
         atexit.register(self.shutdown)

     def shutdown(self):
+        # Stop all processes and threads
         self.scheduler.request_queue.put("stop")
         self.scheduler_thread.join()

@@ -227,6 +241,7 @@ def shutdown(self):
             p.kill()

     def generate(self, text):
+        # Function to generate text from prompt
         req = Request.new_request(text)
         self.scheduler.schedule_request(req)

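The shutdown path uses a "stop" sentinel to unblock the scheduler thread before the worker processes are killed, and atexit makes sure this runs even if the caller never calls shutdown explicitly. A minimal sketch of that teardown pattern, with illustrative names rather than the torchchat generator class:

import atexit
import queue
import threading


class MiniGenerator:
    # hypothetical stand-in for the generator that owns the scheduler thread
    def __init__(self, procs):
        self.procs = procs  # worker mp.Process handles
        self.request_queue = queue.Queue()
        self.scheduler_thread = threading.Thread(target=self._loop)
        self.scheduler_thread.start()
        atexit.register(self.shutdown)  # tear down even on interpreter exit

    def _loop(self):
        # blocks on the queue until the "stop" sentinel arrives
        while self.request_queue.get() != "stop":
            pass

    def shutdown(self):
        # unblock and join the scheduler thread, then kill the workers
        self.request_queue.put("stop")
        self.scheduler_thread.join()
        for p in self.procs:
            p.kill()


gen = MiniGenerator(procs=[])
gen.shutdown()  # also safe when atexit invokes it again at exit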
