@@ -289,24 +289,32 @@ def _populate_mode_to_tensor_name_map(self, tensor: TensorLocation) -> None:
289
289
self .mode_to_tensors_map [tensor .mode ].add (tensor .tensorname )
290
290
291
291
def _add_tensor (self , step_num , worker , tensor_object : TensorLocation ):
292
- to = tensor_object
293
- # self.worker_set.add(worker)
294
- if REDUCTIONS_PREFIX in to .tensorname :
295
- tname , red_name , abs = reverse_reduction_tensor_name (to .tensorname )
292
+ is_reduction = False
293
+
294
+ if REDUCTIONS_PREFIX in tensor_object .tensorname :
295
+ tname , red_name , abs = reverse_reduction_tensor_name (tensor_object .tensorname )
296
+ tensor_object .tensorname = tname
297
+ is_reduction = True
296
298
else :
297
- tname = to .tensorname
299
+ tname = tensor_object .tensorname
300
+
298
301
if tname not in self ._tensors :
299
- t = Tensor (tname , trial = self , cache = self .cache )
300
- self ._tensors [tname ] = t
301
- t = self . _tensors [ tname ]
302
- self ._populate_step_dict ( to , step_num )
303
- self . _populate_global_step_to_tensor_name_map ( to , step_num )
304
- self . _populate_workers_for_global_step ( step_num , worker )
305
- self . _populate_mode_to_tensor_name_map ( to )
306
- if REDUCTIONS_PREFIX in to . tensorname :
307
- t . add_reduction_step ( to . mode , to . mode_step , worker , red_name , abs , to )
302
+ tensor = Tensor (tname , trial = self , cache = self .cache )
303
+ self ._tensors [tname ] = tensor
304
+
305
+ tensor = self ._tensors [ tname ]
306
+
307
+ if is_reduction :
308
+ tensor . add_reduction_step (
309
+ tensor_object . mode , tensor_object . mode_step , worker , red_name , abs , tensor_object
310
+ )
308
311
else :
309
- t .add_step (to .mode , to .mode_step , worker , to )
312
+ tensor .add_step (tensor_object .mode , tensor_object .mode_step , worker , tensor_object )
313
+
314
+ self ._populate_step_dict (tensor_object , step_num )
315
+ self ._populate_global_step_to_tensor_name_map (tensor_object , step_num )
316
+ self ._populate_workers_for_global_step (step_num , worker )
317
+ self ._populate_mode_to_tensor_name_map (tensor_object )
310
318
311
319
def _tensors_matching_regex (self , regex_list ) -> set :
312
320
matched_tensornames = set ()
0 commit comments