
Commit 89f631b
Merge pull request #1090 from pytorch/partitioning_documentation
doc: add the explanation for partition phases on docs
2 parents c12fc1e + 41dd042

1 file changed: +85 −0 lines changed
docsrc/contributors/partitioning.rst

@@ -6,3 +6,88 @@ Partitioning Phase
The phase is optional and enabled by the user. It instructs the compiler to separate nodes into ones that should run in PyTorch and ones that should run in TensorRT.
Criteria for separation include: lack of a converter, the operator being explicitly set to run in PyTorch by the user, or the node carrying a flag set by the module fallback passes which tells partitioning to run it in PyTorch.

At a high level, the Torch-TensorRT partitioning phase does the following:

* Segmentation. Go through the set of operators in order and verify whether there is a converter for each one. Then, roughly separate the graph into parts that Torch-TensorRT can support and parts it cannot.
* Dependency Analysis. Every operator to be compiled needs a "complete dependency graph", meaning that every input can be traced back to a Tensor or TensorList input. After segmentation, go through all segments and do dependency analysis to ensure that TensorRT segments have only Tensor/TensorList inputs and outputs.
* Shape Analysis. For each segment, figure out the input and output shapes, starting from the input shapes provided by the user. Shapes can be calculated by running the graphs with JIT.
* Conversion. Every TensorRT segment is converted to a TensorRT engine. This part is done in compiler.cpp, but it is still a phase of the partitioning process.
* Stitching. Stitch all TensorRT engines together with the PyTorch nodes.

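The segmentation step can be sketched in a few lines of Python (an illustrative sketch only, not the actual C++ implementation; ``is_supported`` here stands in for the converter-registry lookup):

```python
def segment(nodes, is_supported):
    """Greedily group consecutive nodes into TensorRT and PyTorch segments."""
    segments = []
    for node in nodes:
        target = "tensorrt" if is_supported(node) else "pytorch"
        if segments and segments[-1][0] == target:
            segments[-1][1].append(node)  # extend the current segment
        else:
            segments.append((target, [node]))  # start a new segment
    return segments

# A graph with one unsupported (hypothetical) custom op splits into 3 segments.
ops = ["aten::conv2d", "aten::relu", "my::custom_op", "aten::linear"]
print(segment(ops, lambda op: not op.startswith("my::")))
```

The real pass additionally enforces the dependency and minimum-size constraints described elsewhere in this document, so actual segment boundaries can differ from this greedy grouping.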
Here is a brief description of each file:

PartitionInfo.h/.cpp
***********************************

`core/partitioning/PartitionInfo.h <https://github.com/pytorch/TensorRT/blob/master/core/partitioning/PartitionInfo.h>`_

The automatic fallback APIs used for partitioning.

SegmentedBlock.h/.cpp
***********************************

`core/partitioning/SegmentedBlock.h <https://github.com/pytorch/TensorRT/blob/master/core/partitioning/SegmentedBlock.h>`_

The main data structure used to maintain information for each segment after segmentation.

shape_analysis.h/.cpp
***********************************

`core/partitioning/shape_analysis.h <https://github.com/pytorch/TensorRT/blob/master/core/partitioning/shape_analysis.h>`_

Code implementation to get the shapes for each segment by running them in JIT.

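The idea can be illustrated with a toy sketch (pure Python; the hand-written shape rules below are assumptions standing in for actually executing each segment with JIT):

```python
import math

# Toy per-op shape rules; in reality shapes come from running the segment in JIT.
SHAPE_FNS = {
    "aten::relu": lambda s: s,                            # elementwise: shape unchanged
    "aten::flatten": lambda s: (s[0], math.prod(s[1:])),  # collapse all but batch dim
}

def segment_shapes(segments, input_shape):
    """Propagate shapes through each segment, recording (input, output) shapes."""
    io_shapes = []
    shape = input_shape
    for ops in segments:
        in_shape = shape
        for op in ops:
            shape = SHAPE_FNS[op](shape)
        io_shapes.append((in_shape, shape))
    return io_shapes

print(segment_shapes([["aten::relu"], ["aten::flatten"]], (1, 3, 224, 224)))
```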
partitioning.h/.cpp
***********************************

`core/partitioning/partitioning.h <https://github.com/pytorch/TensorRT/blob/master/core/partitioning/partitioning.h>`_

APIs and main code implementation for the partitioning phase.

Automatic Fallback
====================

To enable the automatic fallback feature, you can set the following attributes in Python:

.. code-block:: python

    import torch
    import torch_tensorrt as torchtrt

    ...
    model = MyModel()
    ts_model = torch.jit.script(model)
    trt_model = torchtrt.ts.compile(ts_model, **{
        ...
        "min_block_size" : 3,
        "torch_executed_ops": ["aten::add"],
        "torch_executed_modules": [],
    })

* enabled: By default, automatic fallback is off. It is enabled by setting this to True.
* min_block_size: The minimum number of consecutive supported operations required for a segment to be converted to TensorRT. For example, if it is set to 3, there must be 3 consecutive supported operators for a segment to be converted.
* forced_fallback_ops: A list of strings naming the operations that the user explicitly wants to run as PyTorch nodes.
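The effect of ``min_block_size`` can be sketched as follows (illustrative only; segments are assumed here to be represented as (target, ops) pairs, which is not the actual data structure):

```python
def apply_min_block_size(segments, min_block_size):
    """Demote TensorRT segments with fewer than min_block_size ops to PyTorch."""
    result = []
    for target, ops in segments:
        if target == "tensorrt" and len(ops) < min_block_size:
            target = "pytorch"  # too small to be worth a TensorRT engine
        result.append((target, ops))
    return result

segments = [
    ("tensorrt", ["aten::conv2d", "aten::relu"]),                   # only 2 ops
    ("pytorch", ["aten::add"]),
    ("tensorrt", ["aten::linear", "aten::relu", "aten::softmax"]),  # 3 ops
]
print(apply_min_block_size(segments, 3))
```

With ``min_block_size`` set to 3, the first segment falls back to PyTorch while the third remains for TensorRT conversion.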

The same settings can be configured through the C++ API:

.. code-block:: cpp

    #include "torch/script.h"
    #include "torch_tensorrt/torch_tensorrt.h"

    ...
    auto in = torch::randn({1, 3, 224, 224}, {torch::kCUDA});

    auto mod = torch::jit::load("trt_ts_module.ts");
    auto input_sizes = std::vector<torchtrt::InputRange>{{in.sizes()}};
    torchtrt::ts::CompileSpec cfg(input_sizes);
    cfg.min_block_size = 2;
    cfg.torch_executed_ops.push_back("aten::relu");
    auto trt_mod = torchtrt::ts::compile(mod, cfg);
    auto out = trt_mod.forward({in});
