@@ -231,11 +231,11 @@ std::unordered_map<torch::jit::Value*, usage_info> getInputUsageCounts(
231
231
return usage_counts;
232
232
}
233
233
234
- std::unordered_map<size_t , std::list<SegmentedBlock>::iterator>
235
- getIdxtoIterMap ( std::list<SegmentedBlock> & segmented_blocks_list) {
234
+ std::unordered_map<size_t , std::list<SegmentedBlock>::iterator> getIdxtoIterMap (
235
+ std::list<SegmentedBlock>& segmented_blocks_list) {
236
236
std::unordered_map<size_t , std::list<SegmentedBlock>::iterator> idx_to_iter;
237
237
auto iter = segmented_blocks_list.begin ();
238
- for (int i = 0 ; i < segmented_blocks_list.size (); ++i, ++iter) {
238
+ for (uint64_t i = 0 ; i < segmented_blocks_list.size (); ++i, ++iter) {
239
239
idx_to_iter[i] = iter;
240
240
}
241
241
return idx_to_iter;
@@ -283,22 +283,24 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
283
283
}
284
284
285
285
void resolveTensorListInputBlocks (PartitionedGraph& segmented_blocks) {
286
- // usage_counts is a map with key as non-tensor/tensorlist inputs and value as the idx of segmented block which produces/contains it.
287
- auto usage_counts = getInputUsageCounts (
288
- segmented_blocks, [](torch::jit::Value* input) -> bool { return isTensorList (input); });
286
+ // usage_counts is a map with key as non-tensor/tensorlist inputs and value as the idx of segmented block which
287
+ // produces/contains it.
288
+ auto usage_counts =
289
+ getInputUsageCounts (segmented_blocks, [](torch::jit::Value* input) -> bool { return isTensorList (input); });
289
290
290
291
// Get idx of the segblock to its iterator mapping
291
292
std::list<SegmentedBlock> segmented_blocks_list (segmented_blocks.cbegin (), segmented_blocks.cend ());
292
293
auto idx_to_iter = getIdxtoIterMap (segmented_blocks_list);
293
294
294
295
std::unordered_set<int > updated_segments;
295
296
// we need to re-segment TensorRT segments whose inputs are TensorLists
296
- for (auto & use : usage_counts) {
297
+ for (auto & use : usage_counts) {
297
298
auto use_info = use.second ;
298
299
// For a particular tensorlist input, traverse through all ids of segmented blocks whose target is TensorRT
299
300
for (auto i : use_info.tensorrt_use_id ) {
300
301
if (!updated_segments.count (i)) {
301
- // tensorlistinput_to_segblock is a mapping from {tensorlist input : segmented block which produced this tensorlist input}
302
+ // tensorlistinput_to_segblock is a mapping from {tensorlist input : segmented block which produced this
303
+ // tensorlist input}
302
304
std::unordered_map<torch::jit::Value*, SegmentedBlock> tensorlistinput_to_segblock;
303
305
for (auto input : segmented_blocks[i].raw_inputs ()) {
304
306
if (isTensorList (input)) {
@@ -308,18 +310,20 @@ void resolveTensorListInputBlocks(PartitionedGraph& segmented_blocks) {
308
310
309
311
// For each tensorlist input in tensorlistinput_to_segblock, get the node which actually uses this input.
310
312
// Once we retrieve the node, we remove it from the current TensorRT segmented_blocks[i]. This node should be
311
- // added to block that generated/produced (can be obtained via produce_id) this tensorlist input in the first place.
313
+ // added to block that generated/produced (can be obtained via produce_id) this tensorlist input in the first
314
+ // place.
312
315
auto seg_blocks = segmentBlocksWithTensorListInputs (segmented_blocks[i], tensorlistinput_to_segblock);
313
316
auto append_blocks = seg_blocks.first ;
314
317
auto trt_block = seg_blocks.second ;
315
- // Remove the current TensorRT seg_block and replace it with new TRT block (non empty) which has the node that uses tensorlist input removed.
318
+ // Remove the current TensorRT seg_block and replace it with new TRT block (non empty) which has the node that
319
+ // uses tensorlist input removed.
316
320
auto next_iter = segmented_blocks_list.erase (idx_to_iter[i]);
317
321
if (trt_block.raw_nodes ().size () > 0 ) {
318
322
segmented_blocks_list.insert (next_iter, trt_block);
319
323
}
320
324
321
325
// append blocks' nodes to the producer seg_block
322
- for (auto append_block: append_blocks) {
326
+ for (auto append_block : append_blocks) {
323
327
auto input = append_block.first ; // corresponds to the tensorlist input
324
328
auto block = append_block.second ;
325
329
// append nodes to segmented_blocks_list
0 commit comments