Skip to content

Commit fa942e7

Browse files
author
Anurag Dixit
committed
Redesigned version
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 01a8920 commit fa942e7

File tree

3 files changed

+148
-57
lines changed

3 files changed

+148
-57
lines changed

core/compiler.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,9 @@ c10::FunctionSchema GenerateGraphSchema(
4747
void AddEngineToGraph(
4848
torch::jit::script::Module mod,
4949
std::shared_ptr<torch::jit::Graph>& g,
50-
std::string& serialized_engine) {
51-
runtime::CudaDevice device;
52-
53-
// Read current CUDA device properties
54-
runtime::get_cuda_device(device);
55-
56-
// Serialize current device information
57-
auto device_info = runtime::serialize_device(device);
58-
59-
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine, device_info);
50+
std::string& engine,
51+
CudaDevice& device_info) {
52+
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), engine, device_info);
6053
// Get required metadata about the engine out
6154
auto num_io = engine_ptr->num_io;
6255
auto name = engine_ptr->name;
@@ -166,12 +159,15 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
166159
// torch::jit::script::Module new_mod = mod.clone();
167160
torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
168161
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
162+
169163
for (const torch::jit::script::Method& method : mod.get_methods()) {
170164
// Don't convert hidden methods
171165
if (method.name().rfind("_", 0)) {
172166
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
173167
auto new_g = std::make_shared<torch::jit::Graph>();
174-
AddEngineToGraph(new_mod, new_g, engine);
168+
169+
auto cuda_device = runtime::spec_to_device(cfg->convert_info.engine_settings.device);
170+
AddEngineToGraph(new_mod, new_g, engine, cuda_device);
175171
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
176172
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
177173
new_mod.type()->addMethod(new_method);

core/runtime/TRTEngine.cpp

Lines changed: 113 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ std::string slugify(std::string s) {
1616
return s;
1717
}
1818

19-
// Builds an engine from an already-serialized TensorRT engine plus the device
// it should run on, using the placeholder module name "deserialized_trt".
//
// All work is forwarded to the (mod_name, serialized_engine, cuda_device)
// constructor through constructor delegation. The previous implementation
// first constructed `logger` and then re-ran construction over the same
// storage with placement new, which overwrote the already-built members
// without ever destroying them.
//
// serialized_engine: serialized TensorRT engine bytes.
// device:            device description handed to the three-argument constructor.
TRTEngine::TRTEngine(std::string serialized_engine, CudaDevice device)
    : TRTEngine(std::string("deserialized_trt"), serialized_engine, device) {}
2727

2828
TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
@@ -31,27 +31,23 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
3131
util::logging::get_logger().get_reportable_severity(),
3232
util::logging::get_logger().get_is_colored_output_on()) {
3333
std::string _name = "deserialized_trt";
34-
std::string device_info = serialized_info[0];
35-
std::string engine_info = serialized_info[1];
34+
std::string engine_info = serialized_info[EngineIdx];
3635

37-
new (this) TRTEngine(_name, engine_info, device_info);
36+
CudaDevice cuda_device = deserialize_device(serialized_info[DeviceIdx]);
37+
38+
new (this) TRTEngine(_name, engine_info, cuda_device);
3839
}
3940

4041
TRTEngine::TRTEngine(
4142
std::string mod_name,
4243
std::string serialized_engine,
43-
std::string serialized_device_info = std::string())
44+
CudaDevice cuda_device)
4445
: logger(
4546
std::string("[") + mod_name + std::string("_engine] - "),
4647
util::logging::get_logger().get_reportable_severity(),
4748
util::logging::get_logger().get_is_colored_output_on()) {
48-
CudaDevice cuda_device;
49-
// Deserialize device meta data if device_info is non-empty
50-
if (!serialized_device_info.empty()) {
51-
cuda_device = deserialize_device(serialized_device_info);
52-
// Set CUDA device as configured in serialized meta data
53-
set_cuda_device(cuda_device);
54-
}
49+
50+
set_cuda_device(cuda_device);
5551

5652
rt = nvinfer1::createInferRuntime(logger);
5753

@@ -120,41 +116,63 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
120116
// Adding device info related meta data to the serialized file
121117
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
122118

123-
CudaDevice cuda_device;
124-
get_cuda_device(cuda_device);
125119
std::vector<std::string> serialize_info;
126-
serialize_info.push_back(serialize_device(cuda_device));
120+
serialize_info.push_back(serialize_device(self.cuda_device));
127121
serialize_info.push_back(trt_engine);
128122
return serialize_info;
129123
},
130124
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
131125
return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
132126
});
133127

134-
int CudaDevice::get_id(void) {
128+
int64_t CudaDevice::get_id(void) {
135129
return this->id;
136130
}
137131

138-
void CudaDevice::set_id(int id) {
132+
void CudaDevice::set_id(int64_t id) {
139133
this->id = id;
140134
}
141135

142-
int CudaDevice::get_major(void) {
136+
int64_t CudaDevice::get_major(void) {
143137
return this->major;
144138
}
145139

146-
void CudaDevice::set_major(int major) {
140+
void CudaDevice::set_major(int64_t major) {
147141
this->major = major;
148142
}
149143

150-
int CudaDevice::get_minor(void) {
144+
int64_t CudaDevice::get_minor(void) {
151145
return this->minor;
152146
}
153147

154-
void CudaDevice::set_minor(int minor) {
148+
void CudaDevice::set_minor(int64_t minor) {
155149
this->minor = minor;
156150
}
157151

152+
nvinfer1::DeviceType get_device_type(void) {
153+
return this->device_type;
154+
}
155+
156+
void set_device_type(nvinfer1::DeviceType device_type) {
157+
this->device_type = device_type;
158+
}
159+
160+
std::string get_device_name(void) {
161+
return this->device_name;
162+
}
163+
164+
void set_device_name(std::string& name) {
165+
this->device_name = name;
166+
}
167+
168+
size_t get_device_name_len(void) {
169+
return this->device_name_len;
170+
}
171+
172+
void set_device_name_len(size_t size) {
173+
this->device_name_len = size;
174+
}
175+
158176
// Makes the device identified by `cuda_device.id` the current CUDA device.
// Fails the TRTORCH_CHECK if cudaSetDevice does not report cudaSuccess.
void set_cuda_device(CudaDevice& cuda_device) {
  const auto status = cudaSetDevice(cuda_device.id);
  TRTORCH_CHECK((status == cudaSuccess), "Unable to set device: " << cuda_device.id);
}
@@ -167,48 +185,106 @@ void get_cuda_device(CudaDevice& cuda_device) {
167185
"Unable to get CUDA properties from device:" << cuda_device.id);
168186
cuda_device.set_major(device_prop.major);
169187
cuda_device.set_minor(device_prop.minor);
188+
cuda_device.set_device_name(std::string(device_prop.name));
170189
}
171190

172191
// Packs a CudaDevice into a flat byte string.
//
// Field order (raw host representation of each value, no padding in between):
//   int64_t id
//   int64_t major
//   int64_t minor
//   nvinfer1::DeviceType device_type
//   size_t  length of the device name in bytes
//   char[]  device name bytes (no terminator)
//
// The length that is written is the actual size of the device name, so the
// header can never disagree with the number of name bytes that follow it.
//
// Fixes over the previous implementation: the scratch buffer was sized with
// sizeof(cuda_device) (unrelated to the payload size) and leaked, the name
// bytes were copied from the address of the std::string object instead of its
// characters, and the returned length left out the size_t length field.
std::string serialize_device(CudaDevice& cuda_device) {
  const std::string device_name = cuda_device.get_device_name();
  const size_t device_name_len = device_name.size();

  std::string serialized;
  serialized.reserve(3 * sizeof(int64_t) + sizeof(nvinfer1::DeviceType) + sizeof(size_t) + device_name_len);

  // Appends the raw bytes of one field to the output.
  const auto append_bytes = [&serialized](const void* src, size_t count) {
    serialized.append(static_cast<const char*>(src), count);
  };

  const int64_t id = cuda_device.get_id();
  append_bytes(&id, sizeof(id));

  const int64_t major_version = cuda_device.get_major();
  append_bytes(&major_version, sizeof(major_version));

  const int64_t minor_version = cuda_device.get_minor();
  append_bytes(&minor_version, sizeof(minor_version));

  const nvinfer1::DeviceType device_type = cuda_device.get_device_type();
  append_bytes(&device_type, sizeof(device_type));

  append_bytes(&device_name_len, sizeof(device_name_len));
  append_bytes(device_name.data(), device_name_len);

  return serialized;
}
190221

191222
// Rebuilds a CudaDevice from the byte string produced by serialize_device.
//
// Expected field order (raw host representation, no padding in between):
//   int64_t id, int64_t major, int64_t minor, nvinfer1::DeviceType device_type,
//   size_t name length, then that many device-name bytes.
//
// Every read is bounds-checked against the remaining input; a short or
// inconsistent input fails a TRTORCH_CHECK instead of reading out of bounds.
//
// Fixes over the previous implementation: the name length was read from the
// address of the cursor variable rather than from the data, the name bytes
// were memcpy'd over a std::string object, the parsed device type was never
// stored in the result, and the heap copy of the input was leaked.
CudaDevice deserialize_device(std::string device_info) {
  CudaDevice ret;

  // Read position inside device_info and the number of bytes left after it.
  const char* cursor = device_info.data();
  size_t remaining = device_info.size();

  // Copies the next `count` bytes into `dst` and advances the read position.
  const auto read_bytes = [&cursor, &remaining](void* dst, size_t count) {
    TRTORCH_CHECK((count <= remaining), "Serialized device info is truncated");
    memcpy(dst, cursor, count);
    cursor += count;
    remaining -= count;
  };

  int64_t temp = 0;

  read_bytes(&temp, sizeof(temp));
  ret.set_id(temp);

  read_bytes(&temp, sizeof(temp));
  ret.set_major(temp);

  read_bytes(&temp, sizeof(temp));
  ret.set_minor(temp);

  nvinfer1::DeviceType device_type{};
  read_bytes(&device_type, sizeof(device_type));
  ret.set_device_type(device_type);

  size_t size = 0;
  read_bytes(&size, sizeof(size));
  TRTORCH_CHECK((size <= remaining), "Serialized device name length exceeds the available data");
  ret.set_device_name_len(size);

  // Construct the name from the payload bytes; set_device_name takes an lvalue.
  std::string device_name(cursor, size);
  ret.set_device_name(device_name);

  return ret;
}
211258

259+
// Builds a runtime CudaDevice from a conversion-time device spec.
//
// The id and device type are taken from `spec`; the compute capability and the
// device name are queried from the CUDA runtime for `spec.gpu_id`.
// Fails the TRTORCH_CHECK if the properties cannot be queried, so that an
// unfilled cudaDeviceProp is never read.
//
// Also fixes the misspelled `device_nmae` identifier that kept the previous
// implementation from compiling.
CudaDevice spec_to_device(conversion::Device& spec) {
  CudaDevice device;

  // Device ID
  device.set_id(spec.gpu_id);

  // Get Device Properties
  cudaDeviceProp device_prop;
  TRTORCH_CHECK(
      (cudaGetDeviceProperties(&device_prop, spec.gpu_id) == cudaSuccess),
      "Unable to get CUDA properties from device:" << spec.gpu_id);

  // Compute capability major / minor version
  device.set_major(device_prop.major);
  device.set_minor(device_prop.minor);

  // Device name, plus its length for serialization/deserialization
  std::string device_name(device_prop.name);
  device.set_device_name(device_name);
  device.set_device_name_len(device_name.size());

  // Device Type
  device.set_device_type(spec.device_type);

  return device;
}
287+
212288
} // namespace runtime
213289
} // namespace core
214290
} // namespace trtorch

core/runtime/runtime.h

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,36 @@ namespace runtime {
1111

1212
using EngineID = int64_t;
1313

14+
// Positions of the fields inside the std::vector<std::string> form of a
// serialized TRTEngine. Kept as an unscoped enum so the enumerators can be
// used directly as vector indices.
enum SerializedInfoIndex {
  DeviceIdx = 0, // serialized CudaDevice
  EngineIdx = 1 // serialized TensorRT engine
};
18+
1419
struct CudaDevice {
15-
int id; // CUDA device id
16-
int major; // CUDA compute major version
17-
int minor; // CUDA compute minor version
20+
int64_t id; // CUDA device id
21+
int64_t major; // CUDA compute major version
22+
int64_t minor; // CUDA compute minor version
23+
nvinfer1::DeviceType device_type;
24+
size_t device_name_len;
25+
std::string device_name;
26+
27+
nvinfer1::DeviceType get_device_type(void);
28+
void set_device_type(nvinfer1::DeviceType dev_type);
29+
30+
size_t get_device_name_len(void);
31+
void set_device_name_len(size_t size);
1832

19-
int get_id(void);
20-
void set_id(int id);
33+
std::string get_device_name(void);
34+
void set_device_name(std::string& name);
35+
36+
int64_t get_id(void);
37+
void set_id(int64_t id);
2138

22-
int get_major(void);
23-
void set_major(int major);
39+
int64_t get_major(void);
40+
void set_major(int64_t major);
2441

25-
int get_minor(void);
26-
void set_minor(int minor);
42+
int64_t get_minor(void);
43+
void set_minor(int64_t minor);
2744
};
2845

2946
void set_cuda_device(CudaDevice& cuda_device);
@@ -32,6 +49,8 @@ void get_cuda_device(CudaDevice& cuda_device);
3249
std::string serialize_device(CudaDevice& cuda_device);
3350
CudaDevice deserialize_device(std::string device_info);
3451

52+
CudaDevice spec_to_device(conversion::Device& spec);
53+
3554
struct TRTEngine : torch::CustomClassHolder {
3655
// Each engine needs it's own runtime object
3756
nvinfer1::IRuntime* rt;

0 commit comments

Comments
 (0)