1
- #pragma once
2
-
3
- #include < ostream>
4
- #include < vector>
5
-
6
- #include " NvInfer.h"
7
- #include " core/ir/ir.h"
8
- #include " core/partitioning/PartitionInfo.h"
9
- #include " torch/csrc/jit/ir/ir.h"
10
-
11
- namespace torch_tensorrt {
12
- namespace core {
13
- namespace partitioning {
14
-
15
- struct SegmentedBlock {
16
- public:
17
- enum SegmentedBlockTarget {
18
- kTorch ,
19
- kTensorRT ,
20
- };
21
-
22
- static std::string target_to_str (SegmentedBlockTarget t) {
23
- if (t == SegmentedBlockTarget::kTorch ) {
24
- return " Torch" ;
25
- } else {
26
- return " TensorRT" ;
27
- }
28
- }
29
-
30
- using BlockID = uint64_t ;
31
-
32
- SegmentedBlock () = default ;
33
- SegmentedBlock (SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
34
- SegmentedBlock (SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
35
- SegmentedBlock (SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
36
- SegmentedBlock (BlockID id, SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
37
-
38
- torch::jit::Value* getOrAddInputForValue (torch::jit::Value* v);
39
- torch::jit::Node* cloneNode (torch::jit::Node* node);
40
- void appendNode (torch::jit::Node* n) {
41
- cloneNode (n);
42
- }
43
- void registerOutput (torch::jit::Value* raw_output);
44
- torch::jit::graph_node_list nodes () {
45
- return g_->nodes ();
46
- }
47
- const std::vector<torch::jit::Node*>& raw_nodes () const {
48
- return nodes_;
49
- }
50
- torch::jit::Block* block () {
51
- return g_->block ();
52
- }
53
- std::shared_ptr<torch::jit::Graph>& g () {
54
- return g_;
55
- }
56
- void update_graph (std::shared_ptr<torch::jit::Graph> new_g) {
57
- g_ = new_g;
58
- }
59
- c10::ArrayRef<torch::jit::Value*> inputs () {
60
- return g_->inputs ();
61
- }
62
- c10::ArrayRef<torch::jit::Value*> outputs () {
63
- return g_->outputs ();
64
- }
65
- const std::vector<torch::jit::Value*>& raw_inputs () const {
66
- return inputs_;
67
- }
68
- const std::vector<torch::jit::Value*>& raw_outputs () const {
69
- return outputs_;
70
- }
71
- void eraseInput (size_t i);
72
- void eraseOutput (size_t i);
73
- bool contain_raw_value (torch::jit::Value* input) {
74
- return old_to_new_.count (input);
75
- }
76
- void register_inshapes (std::vector<ir::Input>& in_shapes) {
77
- in_shapes_ = in_shapes;
78
- }
79
- const std::vector<ir::Input>& in_shapes () const {
80
- return in_shapes_;
81
- }
82
- void register_intypes (std::vector<at::ScalarType>& in_types) {
83
- in_types_ = in_types;
84
- }
85
- const std::vector<at::ScalarType>& in_types () const {
86
- return in_types_;
87
- }
88
- void update_target (SegmentedBlockTarget new_target) {
89
- target_ = new_target;
90
- }
91
- enum SegmentedBlockTarget target () {
92
- return target_;
93
- }
94
-
95
- friend std::ostream& operator <<(std::ostream& os, const SegmentedBlock& b);
96
-
97
- private:
98
- BlockID id_;
99
- SegmentedBlockTarget target_;
100
- std::vector<ir::Input> in_shapes_;
101
- std::vector<at::ScalarType> in_types_;
102
- std::vector<torch::jit::Value*> inputs_;
103
- std::vector<torch::jit::Value*> outputs_;
104
- std::vector<torch::jit::Node*> nodes_;
105
- std::shared_ptr<torch::jit::Graph> g_;
106
- std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
107
- };
108
-
109
- std::ostream& operator <<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t);
110
-
111
- } // namespace partitioning
112
- } // namespace core
1
+ #pragma once
2
+
3
+ #include < ostream>
4
+ #include < vector>
5
+
6
+ #include " NvInfer.h"
7
+ #include " core/ir/ir.h"
8
+ #include " core/partitioning/PartitionInfo.h"
9
+ #include " torch/csrc/jit/ir/ir.h"
10
+
11
+ namespace torch_tensorrt {
12
+ namespace core {
13
+ namespace partitioning {
14
+
15
+ struct SegmentedBlock {
16
+ public:
17
+ enum SegmentedBlockTarget {
18
+ kTorch ,
19
+ kTensorRT ,
20
+ };
21
+
22
+ static std::string target_to_str (SegmentedBlockTarget t) {
23
+ if (t == SegmentedBlockTarget::kTorch ) {
24
+ return " Torch" ;
25
+ } else {
26
+ return " TensorRT" ;
27
+ }
28
+ }
29
+
30
+ using BlockID = uint64_t ;
31
+
32
+ SegmentedBlock () = default ;
33
+ SegmentedBlock (SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
34
+ SegmentedBlock (SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
35
+ SegmentedBlock (SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
36
+ SegmentedBlock (BlockID id, SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
37
+
38
+ torch::jit::Value* getOrAddInputForValue (torch::jit::Value* v);
39
+ torch::jit::Node* cloneNode (torch::jit::Node* node);
40
+ void appendNode (torch::jit::Node* n) {
41
+ cloneNode (n);
42
+ }
43
+ void registerOutput (torch::jit::Value* raw_output);
44
+ torch::jit::graph_node_list nodes () {
45
+ return g_->nodes ();
46
+ }
47
+ const std::vector<torch::jit::Node*>& raw_nodes () const {
48
+ return nodes_;
49
+ }
50
+ torch::jit::Block* block () {
51
+ return g_->block ();
52
+ }
53
+ std::shared_ptr<torch::jit::Graph>& g () {
54
+ return g_;
55
+ }
56
+ void update_graph (std::shared_ptr<torch::jit::Graph> new_g) {
57
+ g_ = new_g;
58
+ }
59
+ c10::ArrayRef<torch::jit::Value*> inputs () {
60
+ return g_->inputs ();
61
+ }
62
+ c10::ArrayRef<torch::jit::Value*> outputs () {
63
+ return g_->outputs ();
64
+ }
65
+ const std::vector<torch::jit::Value*>& raw_inputs () const {
66
+ return inputs_;
67
+ }
68
+ const std::vector<torch::jit::Value*>& raw_outputs () const {
69
+ return outputs_;
70
+ }
71
+ void eraseInput (size_t i);
72
+ void eraseOutput (size_t i);
73
+ bool contain_raw_value (torch::jit::Value* input) {
74
+ return old_to_new_.count (input);
75
+ }
76
+ void register_inshapes (std::vector<ir::Input>& in_shapes) {
77
+ in_shapes_ = in_shapes;
78
+ }
79
+ const std::vector<ir::Input>& in_shapes () const {
80
+ return in_shapes_;
81
+ }
82
+ void register_intypes (std::vector<at::ScalarType>& in_types) {
83
+ in_types_ = in_types;
84
+ }
85
+ const std::vector<at::ScalarType>& in_types () const {
86
+ return in_types_;
87
+ }
88
+ void update_id (BlockID new_id) {
89
+ id_ = new_id;
90
+ }
91
+ void update_target (SegmentedBlockTarget new_target) {
92
+ target_ = new_target;
93
+ }
94
+ enum SegmentedBlockTarget target () {
95
+ return target_;
96
+ }
97
+
98
+ friend std::ostream& operator <<(std::ostream& os, const SegmentedBlock& b);
99
+
100
+ private:
101
+ BlockID id_;
102
+ SegmentedBlockTarget target_;
103
+ std::vector<ir::Input> in_shapes_;
104
+ std::vector<at::ScalarType> in_types_;
105
+ std::vector<torch::jit::Value*> inputs_;
106
+ std::vector<torch::jit::Value*> outputs_;
107
+ std::vector<torch::jit::Node*> nodes_;
108
+ std::shared_ptr<torch::jit::Graph> g_;
109
+ std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
110
+ };
111
+
112
+ std::ostream& operator <<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t);
113
+
114
+ } // namespace partitioning
115
+ } // namespace core
113
116
} // namespace torch_tensorrt
0 commit comments