Skip to content

Commit 096a5d4

Browse files
committed
refactor(//core/partitioning): Refactor partitioning
to use LF line endings and fix block ids Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 20e5d41 commit 096a5d4

File tree

2 files changed

+543
-537
lines changed

2 files changed

+543
-537
lines changed

core/partitioning/SegmentedBlock.h

Lines changed: 115 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,113 +1,116 @@
1-
#pragma once
2-
3-
#include <ostream>
4-
#include <vector>
5-
6-
#include "NvInfer.h"
7-
#include "core/ir/ir.h"
8-
#include "core/partitioning/PartitionInfo.h"
9-
#include "torch/csrc/jit/ir/ir.h"
10-
11-
namespace torch_tensorrt {
12-
namespace core {
13-
namespace partitioning {
14-
15-
struct SegmentedBlock {
16-
public:
17-
enum SegmentedBlockTarget {
18-
kTorch,
19-
kTensorRT,
20-
};
21-
22-
static std::string target_to_str(SegmentedBlockTarget t) {
23-
if (t == SegmentedBlockTarget::kTorch) {
24-
return "Torch";
25-
} else {
26-
return "TensorRT";
27-
}
28-
}
29-
30-
using BlockID = uint64_t;
31-
32-
SegmentedBlock() = default;
33-
SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
34-
SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
35-
SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
36-
SegmentedBlock(BlockID id, SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
37-
38-
torch::jit::Value* getOrAddInputForValue(torch::jit::Value* v);
39-
torch::jit::Node* cloneNode(torch::jit::Node* node);
40-
void appendNode(torch::jit::Node* n) {
41-
cloneNode(n);
42-
}
43-
void registerOutput(torch::jit::Value* raw_output);
44-
torch::jit::graph_node_list nodes() {
45-
return g_->nodes();
46-
}
47-
const std::vector<torch::jit::Node*>& raw_nodes() const {
48-
return nodes_;
49-
}
50-
torch::jit::Block* block() {
51-
return g_->block();
52-
}
53-
std::shared_ptr<torch::jit::Graph>& g() {
54-
return g_;
55-
}
56-
void update_graph(std::shared_ptr<torch::jit::Graph> new_g) {
57-
g_ = new_g;
58-
}
59-
c10::ArrayRef<torch::jit::Value*> inputs() {
60-
return g_->inputs();
61-
}
62-
c10::ArrayRef<torch::jit::Value*> outputs() {
63-
return g_->outputs();
64-
}
65-
const std::vector<torch::jit::Value*>& raw_inputs() const {
66-
return inputs_;
67-
}
68-
const std::vector<torch::jit::Value*>& raw_outputs() const {
69-
return outputs_;
70-
}
71-
void eraseInput(size_t i);
72-
void eraseOutput(size_t i);
73-
bool contain_raw_value(torch::jit::Value* input) {
74-
return old_to_new_.count(input);
75-
}
76-
void register_inshapes(std::vector<ir::Input>& in_shapes) {
77-
in_shapes_ = in_shapes;
78-
}
79-
const std::vector<ir::Input>& in_shapes() const {
80-
return in_shapes_;
81-
}
82-
void register_intypes(std::vector<at::ScalarType>& in_types) {
83-
in_types_ = in_types;
84-
}
85-
const std::vector<at::ScalarType>& in_types() const {
86-
return in_types_;
87-
}
88-
void update_target(SegmentedBlockTarget new_target) {
89-
target_ = new_target;
90-
}
91-
enum SegmentedBlockTarget target() {
92-
return target_;
93-
}
94-
95-
friend std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b);
96-
97-
private:
98-
BlockID id_;
99-
SegmentedBlockTarget target_;
100-
std::vector<ir::Input> in_shapes_;
101-
std::vector<at::ScalarType> in_types_;
102-
std::vector<torch::jit::Value*> inputs_;
103-
std::vector<torch::jit::Value*> outputs_;
104-
std::vector<torch::jit::Node*> nodes_;
105-
std::shared_ptr<torch::jit::Graph> g_;
106-
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
107-
};
108-
109-
std::ostream& operator<<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t);
110-
111-
} // namespace partitioning
112-
} // namespace core
1+
#pragma once
2+
3+
#include <ostream>
4+
#include <vector>
5+
6+
#include "NvInfer.h"
7+
#include "core/ir/ir.h"
8+
#include "core/partitioning/PartitionInfo.h"
9+
#include "torch/csrc/jit/ir/ir.h"
10+
11+
namespace torch_tensorrt {
12+
namespace core {
13+
namespace partitioning {
14+
15+
struct SegmentedBlock {
16+
public:
17+
enum SegmentedBlockTarget {
18+
kTorch,
19+
kTensorRT,
20+
};
21+
22+
static std::string target_to_str(SegmentedBlockTarget t) {
23+
if (t == SegmentedBlockTarget::kTorch) {
24+
return "Torch";
25+
} else {
26+
return "TensorRT";
27+
}
28+
}
29+
30+
using BlockID = uint64_t;
31+
32+
SegmentedBlock() = default;
33+
SegmentedBlock(SegmentedBlockTarget blk_target) : target_(blk_target), g_(std::make_shared<torch::jit::Graph>()) {}
34+
SegmentedBlock(SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
35+
SegmentedBlock(SegmentedBlockTarget blk_target, std::shared_ptr<torch::jit::Graph> g) : target_(blk_target), g_(g) {}
36+
SegmentedBlock(BlockID id, SegmentedBlockTarget blk_target, const std::vector<torch::jit::Node*>& nodes);
37+
38+
torch::jit::Value* getOrAddInputForValue(torch::jit::Value* v);
39+
torch::jit::Node* cloneNode(torch::jit::Node* node);
40+
void appendNode(torch::jit::Node* n) {
41+
cloneNode(n);
42+
}
43+
void registerOutput(torch::jit::Value* raw_output);
44+
torch::jit::graph_node_list nodes() {
45+
return g_->nodes();
46+
}
47+
const std::vector<torch::jit::Node*>& raw_nodes() const {
48+
return nodes_;
49+
}
50+
torch::jit::Block* block() {
51+
return g_->block();
52+
}
53+
std::shared_ptr<torch::jit::Graph>& g() {
54+
return g_;
55+
}
56+
void update_graph(std::shared_ptr<torch::jit::Graph> new_g) {
57+
g_ = new_g;
58+
}
59+
c10::ArrayRef<torch::jit::Value*> inputs() {
60+
return g_->inputs();
61+
}
62+
c10::ArrayRef<torch::jit::Value*> outputs() {
63+
return g_->outputs();
64+
}
65+
const std::vector<torch::jit::Value*>& raw_inputs() const {
66+
return inputs_;
67+
}
68+
const std::vector<torch::jit::Value*>& raw_outputs() const {
69+
return outputs_;
70+
}
71+
void eraseInput(size_t i);
72+
void eraseOutput(size_t i);
73+
bool contain_raw_value(torch::jit::Value* input) {
74+
return old_to_new_.count(input);
75+
}
76+
void register_inshapes(std::vector<ir::Input>& in_shapes) {
77+
in_shapes_ = in_shapes;
78+
}
79+
const std::vector<ir::Input>& in_shapes() const {
80+
return in_shapes_;
81+
}
82+
void register_intypes(std::vector<at::ScalarType>& in_types) {
83+
in_types_ = in_types;
84+
}
85+
const std::vector<at::ScalarType>& in_types() const {
86+
return in_types_;
87+
}
88+
void update_id(BlockID new_id) {
89+
id_ = new_id;
90+
}
91+
void update_target(SegmentedBlockTarget new_target) {
92+
target_ = new_target;
93+
}
94+
enum SegmentedBlockTarget target() {
95+
return target_;
96+
}
97+
98+
friend std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b);
99+
100+
private:
101+
BlockID id_;
102+
SegmentedBlockTarget target_;
103+
std::vector<ir::Input> in_shapes_;
104+
std::vector<at::ScalarType> in_types_;
105+
std::vector<torch::jit::Value*> inputs_;
106+
std::vector<torch::jit::Value*> outputs_;
107+
std::vector<torch::jit::Node*> nodes_;
108+
std::shared_ptr<torch::jit::Graph> g_;
109+
std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_;
110+
};
111+
112+
std::ostream& operator<<(std::ostream& os, const SegmentedBlock::SegmentedBlockTarget& t);
113+
114+
} // namespace partitioning
115+
} // namespace core
113116
} // namespace torch_tensorrt

0 commit comments

Comments
 (0)