@@ -39,7 +39,7 @@ def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs):
 def _launch_distributed_inference(
     model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs
 ) -> tuple[List]:
-    # create programmatic elastic launch
+    # Launch distributed inference workers; each worker gets a pipe to communicate with the main process
     logger.info("Launching distributed inference ...")

     num_processes_per_node = builder_args.pp * builder_args.tp
@@ -50,17 +50,25 @@ def _launch_distributed_inference(

     pipes = []
     procs = []
-    for rank in range(num_processes_per_node):
-        server_pipe, client_pipe = mp.Pipe(duplex=True)
-        pipes.append(server_pipe)
-        proc = mp.Process(
-            target=partial(_setup_env, num_processes_per_node, rank, main),
-            args=(model_name, builder_args, tokenizer_args, client_pipe),
-        )
-        proc.start()
+    try:
+        for rank in range(num_processes_per_node):
+            server_pipe, client_pipe = mp.Pipe(duplex=True)
+            pipes.append(server_pipe)
+            procs.append(
+                mp.Process(
+                    target=partial(_setup_env, num_processes_per_node, rank, main),
+                    args=(model_name, builder_args, tokenizer_args, client_pipe),
+                )
+            )
+            procs[-1].start()

-    for pipe in pipes:
-        response = pipe.recv()
+        for pipe in pipes:
+            assert pipe.recv() == "ready", "Starting the worker failed"
+    except Exception as e:
+        logger.error(f"Error during distributed inference: {str(e)}")
+        for p in procs:
+            p.kill()
+        raise e

     logger.info(
         f"Done launching distributed inference on {num_processes_per_node} GPUs."
@@ -105,11 +113,13 @@ def __init__(
         self.loop = loop

     def schedule_request(self, req: Request):
+        # Add the request to the queue and create a deque and async event for the response
         self.req_to_states[req.request_id] = asyncio.Event()
         self.req_to_results[req.request_id] = deque()
         self.request_queue.put(req)

     def process_requests_loop(self):
+        # Continuously process requests (one at a time for now); results are routed into the request's deque
         while True:
             req = self.request_queue.get()
             if req == "stop":
@@ -127,6 +137,7 @@ def process_requests_loop(self):
                 running &= not outputs[0].is_finished

     async def wait_for_request(self, req: Request) -> Output:
+        # Wait for the request to deliver a result; an event acts as the trigger and results are read from the left side of the deque
         is_finished = False
         while not is_finished:
             await self.req_to_states[req.request_id].wait()
@@ -138,6 +149,7 @@ async def wait_for_request(self, req: Request) -> Output:
        del self.req_to_results[req.request_id]

     def step(self) -> List[Output]:
+        # Make a prefill or decoding step and receive results
         responses = []
         # TODO: Implement a scheduler to handle the requests
         if len(self.in_flight_requests) > 0:
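schedule_request and wait_for_request together form a small cross-thread handoff: the scheduler thread appends results to a per-request deque and sets an asyncio.Event, while the async caller awaits the event and drains the deque from the left. A stripped-down, runnable sketch of that pattern, independent of the PR's classes (note that asyncio.Event is not thread-safe, so the producing thread should set it via call_soon_threadsafe, which is presumably what the scheduler's loop attribute is for):

import asyncio
import threading
from collections import deque

async def main():
    results = deque()           # per-request result queue
    event = asyncio.Event()     # per-request wake-up signal
    loop = asyncio.get_running_loop()

    def produce():
        # Scheduler-thread side: push a result, then wake the async waiter.
        for i in range(3):
            results.append((f"token-{i}", i == 2))   # (payload, is_finished)
            loop.call_soon_threadsafe(event.set)

    threading.Thread(target=produce).start()

    finished = False
    while not finished:          # wait_for_request side
        await event.wait()
        event.clear()
        while results:
            payload, finished = results.popleft()
            print(payload)

asyncio.run(main())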
@@ -166,6 +178,7 @@ def step(self) -> List[Output]:
             text, token_ids = v
             outputs.append(
                 Output(
+                    # TODO: Look for tokenizer.eos_id as well
                     is_finished=self.current_step >= self.generator_args.max_new_tokens,
                     text=text,
                     token=token_ids,
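The TODO above notes that is_finished currently only checks the max_new_tokens budget. A hypothetical helper showing the kind of finish condition the TODO hints at; tokenizer.eos_id and the exact rule are assumptions, not code from this PR:

from typing import List, Optional

def is_request_finished(current_step: int, max_new_tokens: int,
                        token_ids: List[int], eos_id: Optional[int]) -> bool:
    # Finished either by exhausting the token budget or by emitting the EOS token.
    if current_step >= max_new_tokens:
        return True
    return eos_id is not None and len(token_ids) > 0 and token_ids[-1] == eos_id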
@@ -218,6 +231,7 @@ def __init__(
         atexit.register(self.shutdown)

     def shutdown(self):
+        # Stop all processes and threads
         self.scheduler.request_queue.put("stop")
         self.scheduler_thread.join()
@@ -227,6 +241,7 @@ def shutdown(self):
             p.kill()

     def generate(self, text):
+        # Generate text from a prompt
         req = Request.new_request(text)
         self.scheduler.schedule_request(req)
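generate is a plain synchronous method, while wait_for_request is a coroutine owned by the scheduler's event loop, so a sync caller would have to hand the coroutine to that loop and block on the result. The rest of generate is not shown in this hunk; the following is only a generic sketch of that sync-to-async bridge, not the PR's implementation:

import asyncio
import threading

loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def wait_for_result():
    await asyncio.sleep(0.1)     # stand-in for scheduler.wait_for_request(req)
    return "generated text"

# The sync caller submits the coroutine to the loop thread and blocks on its result.
future = asyncio.run_coroutine_threadsafe(wait_for_result(), loop)
print(future.result())
loop.call_soon_threadsafe(loop.stop)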