
Commit 272a9c7

refactor: Apply linting, fix warnings
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 609a697 commit 272a9c7

File tree: 5 files changed (+26, −20 lines)

core/lowering/passes/linear_to_addmm.cpp

Lines changed: 5 additions & 4 deletions
@@ -1,15 +1,15 @@
 
 #include <torch/csrc/jit/runtime/operator.h>
+#include "core/util/prelude.h"
+#include "torch/csrc/jit/api/function_impl.h"
 #include "torch/csrc/jit/ir/alias_analysis.h"
 #include "torch/csrc/jit/jit_log.h"
 #include "torch/csrc/jit/passes/constant_propagation.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/passes/guard_elimination.h"
 #include "torch/csrc/jit/passes/peephole.h"
-#include "torch/csrc/jit/runtime/graph_executor.h"
-#include "torch/csrc/jit/api/function_impl.h"
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"
-#include "core/util/prelude.h"
+#include "torch/csrc/jit/runtime/graph_executor.h"
 
 namespace torch_tensorrt {
 namespace core {
namespace core {
@@ -34,7 +34,8 @@ void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph)
         continue;
       } else {
         torch::jit::WithInsertPoint guard(*it);
-        std::shared_ptr<torch::jit::Graph> d_graph = toGraphFunction(decompose_funcs.get_function("linear")).graph();;
+        std::shared_ptr<torch::jit::Graph> d_graph = toGraphFunction(decompose_funcs.get_function("linear")).graph();
+        ;
         torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
         new_output->setType(it->output()->type());
         it->output()->replaceAllUsesWith(new_output);
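The hunk above is the usual TorchScript decomposition idiom: pin the insertion point at the matched node, splice the decomposition graph into the owning graph with the node's inputs via insertGraph, then redirect all uses of the old output to the new value. A minimal sketch of that idiom using only public torch::jit APIs; the helper name and the trailing destroy() call are illustrative, not the file's exact code:

#include <torch/csrc/jit/ir/ir.h>

// Inline a decomposition graph at `node` and rewire its consumers.
void inlineDecompositionAt(torch::jit::Node* node, const std::shared_ptr<torch::jit::Graph>& d_graph) {
  // New nodes are inserted where `node` currently sits.
  torch::jit::WithInsertPoint guard(node);
  // Splice the decomposition body in, feeding it `node`'s inputs;
  // insertGraph returns the spliced graph's outputs.
  torch::jit::Value* new_output =
      torch::jit::insertGraph(*node->owningGraph(), *d_graph, node->inputs()).at(0);
  // Preserve the original static type, then redirect all consumers.
  new_output->setType(node->output()->type());
  node->output()->replaceAllUsesWith(new_output);
  // The now-dead original node can be removed (illustrative; a real pass may
  // defer this to a dead-code-elimination pass or the iterator's destroyCurrent()).
  node->destroy();
}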

core/lowering/passes/reduce_gelu.cpp

Lines changed: 4 additions & 3 deletions
@@ -12,8 +12,8 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
       %out : Tensor = aten::gelu(%x)
       return (%out))IR";
 
-  // This gelu_approximate_pattern schema exists in 21.11, 21.12, 22.01 containers of pytorch. These container versions use
-  // an unmerged PR in pytorch : https://github.com/pytorch/pytorch/pull/61439. We reduce this to regular Gelu.
+  // This gelu_approximate_pattern schema exists in 21.11, 21.12, 22.01 containers of pytorch. These container versions
+  // use an unmerged PR in pytorch : https://github.com/pytorch/pytorch/pull/61439. We reduce this to regular Gelu.
   std::string gelu_approximate_pattern = R"IR(
     graph(%x : Tensor, %approx):
       %out : Tensor = aten::gelu(%x, %approx)
@@ -64,7 +64,8 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
   map_gelu_to_pointwise_ops.runOnGraph(graph);
 
   torch::jit::SubgraphRewriter map_gelu_approximate_to_pointwise_ops;
-  map_gelu_approximate_to_pointwise_ops.RegisterRewritePattern(gelu_approximate_pattern, gelu_reduce_multi_input_pattern);
+  map_gelu_approximate_to_pointwise_ops.RegisterRewritePattern(
+      gelu_approximate_pattern, gelu_reduce_multi_input_pattern);
   map_gelu_approximate_to_pointwise_ops.runOnGraph(graph);
 
   LOG_GRAPH("Post lowering of [aten::gelu] -> " << *graph);

core/partitioning/partitioning.cpp

Lines changed: 15 additions & 11 deletions
@@ -231,11 +231,11 @@ std::unordered_map<torch::jit::Value*, usage_info> getInputUsageCounts(
   return usage_counts;
 }
 
-std::unordered_map<size_t, std::list<SegmentedBlock>::iterator>
-getIdxtoIterMap(std::list<SegmentedBlock> &segmented_blocks_list) {
+std::unordered_map<size_t, std::list<SegmentedBlock>::iterator> getIdxtoIterMap(
+    std::list<SegmentedBlock>& segmented_blocks_list) {
   std::unordered_map<size_t, std::list<SegmentedBlock>::iterator> idx_to_iter;
   auto iter = segmented_blocks_list.begin();
-  for (int i = 0; i < segmented_blocks_list.size(); ++i, ++iter) {
+  for (uint64_t i = 0; i < segmented_blocks_list.size(); ++i, ++iter) {
     idx_to_iter[i] = iter;
   }
   return idx_to_iter;
@@ -283,22 +283,24 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
 }
 
 void resolveTensorListInputBlocks(PartitionedGraph& segmented_blocks) {
-  // usage_counts is a map with key as non-tensor/tensorlist inputs and value as the idx of segmented block which produces/contains it.
-  auto usage_counts = getInputUsageCounts(
-      segmented_blocks, [](torch::jit::Value* input) -> bool { return isTensorList(input); });
+  // usage_counts is a map with key as non-tensor/tensorlist inputs and value as the idx of segmented block which
+  // produces/contains it.
+  auto usage_counts =
+      getInputUsageCounts(segmented_blocks, [](torch::jit::Value* input) -> bool { return isTensorList(input); });
 
   // Get idx of the segblock to its iterator mapping
   std::list<SegmentedBlock> segmented_blocks_list(segmented_blocks.cbegin(), segmented_blocks.cend());
   auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list);
 
   std::unordered_set<int> updated_segments;
   // we need to re-segment TensorRT segments whose inputs are TensorLists
-  for (auto &use : usage_counts) {
+  for (auto& use : usage_counts) {
     auto use_info = use.second;
     // For a particular tensorlist input, traverse through all ids of segmented blocks whose target is TensorRT
     for (auto i : use_info.tensorrt_use_id) {
       if (!updated_segments.count(i)) {
-        // tensorlistinput_to_segblock is a mapping from {tensorlist input : segmented block which produced this tensorlist input}
+        // tensorlistinput_to_segblock is a mapping from {tensorlist input : segmented block which produced this
+        // tensorlist input}
         std::unordered_map<torch::jit::Value*, SegmentedBlock> tensorlistinput_to_segblock;
         for (auto input : segmented_blocks[i].raw_inputs()) {
           if (isTensorList(input)) {
@@ -308,18 +310,20 @@ void resolveTensorListInputBlocks(PartitionedGraph& segmented_blocks) {
 
         // For each tensorlist input in tensorlistinput_to_segblock, get the node which actually uses this input.
         // Once we retrieve the node, we remove it from the current TensorRT segmented_blocks[i]. This node should be
-        // added to block that generated/produced (can be obtained via produce_id) this tensorlist input in the first place.
+        // added to block that generated/produced (can be obtained via produce_id) this tensorlist input in the first
+        // place.
         auto seg_blocks = segmentBlocksWithTensorListInputs(segmented_blocks[i], tensorlistinput_to_segblock);
         auto append_blocks = seg_blocks.first;
         auto trt_block = seg_blocks.second;
-        // Remove the current TensorRT seg_block and replace it with new TRT block (non empty) which has the node that uses tensorlist input removed.
+        // Remove the current TensorRT seg_block and replace it with new TRT block (non empty) which has the node that
+        // uses tensorlist input removed.
         auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
         if (trt_block.raw_nodes().size() > 0) {
           segmented_blocks_list.insert(next_iter, trt_block);
         }
 
         // append blocks' nodes to the producer seg_block
-        for (auto append_block: append_blocks) {
+        for (auto append_block : append_blocks) {
           auto input = append_block.first; // corresponds to the tensorlist input
           auto block = append_block.second;
           // append nodes to segmented_blocks_list
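The int-to-uint64_t change in the first hunk of this file fixes a -Wsign-compare warning: std::list::size() returns an unsigned size_t, so comparing it against a signed int loop index trips exactly the kind of warning this commit cleans up. A self-contained sketch of the same index-to-iterator map pattern, with std::string standing in for the repo-internal SegmentedBlock type:

#include <cstdint>
#include <iostream>
#include <list>
#include <string>
#include <unordered_map>

// Map each block's position to its std::list iterator, as getIdxtoIterMap does.
// An unsigned index avoids the signed/unsigned comparison against size().
std::unordered_map<size_t, std::list<std::string>::iterator> getIdxToIterMap(
    std::list<std::string>& blocks) {
  std::unordered_map<size_t, std::list<std::string>::iterator> idx_to_iter;
  auto iter = blocks.begin();
  for (uint64_t i = 0; i < blocks.size(); ++i, ++iter) {
    idx_to_iter[i] = iter;
  }
  return idx_to_iter;
}

int main() {
  std::list<std::string> blocks = {"trt_block_0", "torch_block_1", "trt_block_2"};
  auto idx_to_iter = getIdxToIterMap(blocks);
  // The map lets later passes erase or insert by block index without
  // re-walking the list; std::list iterators stay valid across erasures.
  blocks.erase(idx_to_iter[1]);
  for (const auto& b : blocks) {
    std::cout << b << "\n"; // prints trt_block_0, then trt_block_2
  }
  return 0;
}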

core/partitioning/shape_analysis.cpp

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
-#include <ATen/ATen.h>
 #include "core/partitioning/shape_analysis.h"
+#include <ATen/ATen.h>
 #include "core/util/prelude.h"
 #include "torch/csrc/jit/api/module.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"

tests/util/util.h

Lines changed: 1 addition & 1 deletion
@@ -1,8 +1,8 @@
 #pragma once
 
+#include <ATen/ATen.h>
 #include <string>
 #include <vector>
-#include <ATen/ATen.h>
 #include "ATen/Tensor.h"
 #include "core/ir/ir.h"
 #include "core/util/prelude.h"
