@@ -94,6 +94,7 @@ bool _validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngi
94
94
void setup_input_tensors (
95
95
std::vector<at::Tensor> inputs,
96
96
c10::intrusive_ptr<TRTEngine> compiled_engine,
97
+ bool cudagraphs_enabled,
97
98
bool need_cudagraphs_record) {
98
99
// this is a buffer to store shape tensor input addresses throughout the runtime scope
99
100
std::list<std::vector<int64_t >> inputShapeTensorValues;
@@ -127,7 +128,7 @@ void setup_input_tensors(
127
128
compiled_engine->exec_ctx ->setTensorAddress (name.c_str (), inputShapeTensorValues.back ().data ()),
128
129
" Error while setting the tensor address for shape inputs" );
129
130
130
- if (CUDAGRAPHS_MODE ) {
131
+ if (cudagraphs_enabled ) {
131
132
// @peri044 I dont know if this makes sense since they are supposed to be GPU buffers
132
133
compiled_engine->input_buffers [i] = input_cpu;
133
134
}
@@ -147,7 +148,7 @@ void setup_input_tensors(
147
148
TORCHTRT_CHECK (
148
149
compiled_engine->exec_ctx ->setInputShape (name.c_str (), dims), " Error while setting the input shape" );
149
150
150
- if (CUDAGRAPHS_MODE ) {
151
+ if (cudagraphs_enabled ) {
151
152
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
152
153
compiled_engine->input_buffers [i].copy_ (formatted_inputs.back (), true );
153
154
TORCHTRT_CHECK (
@@ -202,15 +203,16 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
202
203
compiled_engine->cudagraph .enable_debug_mode ();
203
204
}
204
205
bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
206
+ bool shape_changed = _validate_shapes (inputs, compiled_engine);
205
207
206
208
// Whether cudagraphs needs to record the graph on this pass
207
- // Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
208
- bool need_cudagraphs_record = cudagraphs_enabled &&
209
- ((!compiled_engine->prev_cudagraphs_enabled ) || (!_cudagraphs_validate_shapes (inputs, compiled_engine)));
209
+ auto result = compiled_engine->runtime_states .set_runtime_states (
210
+ cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs , shape_changed);
210
211
211
- compiled_engine->prev_cudagraphs_enabled = cudagraphs_enabled;
212
+ bool need_cudagraphs_record = std::get<0 >(result);
213
+ bool can_use_pre_allocated_outputs = std::get<1 >(result);
212
214
213
- if (!cudagraphs_enabled) {
215
+ if (!cudagraphs_enabled || shape_changed ) {
214
216
compiled_engine->cudagraph .reset ();
215
217
}
216
218
@@ -272,69 +274,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
272
274
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path );
273
275
}
274
276
275
- for (size_t i = 0 ; i < inputs.size (); i++) {
276
- std::string name = compiled_engine->in_binding_names [i];
277
-
278
- TORCHTRT_CHECK (
279
- inputs[i].is_cuda (), " Expected input tensors to have device cuda, found device " << inputs[i].device ());
280
-
281
- auto expected_type =
282
- util::TRTDataTypeToScalarType (compiled_engine->exec_ctx ->getEngine ().getTensorDataType (name.c_str ()));
283
- TORCHTRT_CHECK (
284
- inputs[i].dtype () == expected_type,
285
- " Expected input tensors to have type " << expected_type << " , found type " << inputs[i].dtype ());
286
-
287
- auto dims = core::util::toDims (inputs[i].sizes ());
288
- auto shape = core::util::toVec (dims);
289
- LOG_DEBUG (" Input Name: " << name << " Shape: " << dims);
290
-
291
- if (compiled_engine->cuda_engine ->isShapeInferenceIO (name.c_str ())) {
292
- // Shape tensor inputs are casted to int64 explicitly.
293
- // Refer to
294
- // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
295
- auto input_cpu = inputs[i].clone ().contiguous ().cpu ().to (torch::kInt64 );
296
- std::vector<int64_t > inputs_cpu_vec (
297
- input_cpu.data_ptr <int64_t >(), input_cpu.data_ptr <int64_t >() + input_cpu.numel ());
298
- inputShapeTensorValues.emplace_back (inputs_cpu_vec);
299
- TORCHTRT_CHECK (
300
- compiled_engine->exec_ctx ->setTensorAddress (name.c_str (), inputShapeTensorValues.back ().data ()),
301
- " Error while setting the tensor address for shape inputs" );
302
-
303
- if (cudagraphs_enabled) {
304
- // @peri044 I dont know if this makes sense since they are supposed to be GPU buffers
305
- compiled_engine->input_buffers [i] = input_cpu;
306
- }
307
- TORCHTRT_CHECK (
308
- compiled_engine->exec_ctx ->setTensorAddress (name.c_str (), inputShapeTensorValues.back ().data ()),
309
- " Error while setting the tensor address for shape inputs" );
310
-
311
- } else {
312
- at::Tensor contig_input = inputs[i].view (shape).contiguous ();
313
- formatted_inputs.emplace_back (std::move (contig_input));
314
-
315
- if (need_cudagraphs_record) {
316
- // Create a new persistent input buffer
317
- compiled_engine->input_buffers [i] = std::move (formatted_inputs.back ().clone ());
318
- }
319
-
320
- TORCHTRT_CHECK (
321
- compiled_engine->exec_ctx ->setInputShape (name.c_str (), dims), " Error while setting the input shape" );
322
-
323
- if (cudagraphs_enabled) {
324
- // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
325
- compiled_engine->input_buffers [i].copy_ (formatted_inputs.back (), true );
326
- TORCHTRT_CHECK (
327
- compiled_engine->exec_ctx ->setTensorAddress (name.c_str (), compiled_engine->input_buffers [i].data_ptr ()),
328
- " Error while setting the input tensor address for inputs" );
329
- } else {
330
- // Otherwise use the formatted buffer directly
331
- TORCHTRT_CHECK (
332
- compiled_engine->exec_ctx ->setTensorAddress (name.c_str (), formatted_inputs.back ().data_ptr ()),
333
- " Error while setting the input tensor address for inputs" );
334
- }
335
- }
336
- }
337
-
277
+ setup_input_tensors (inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
338
278
// Check if input shapes can be inferred.
339
279
int32_t const io_size{compiled_engine->cuda_engine ->getNbIOTensors ()};
340
280
std::vector<char const *> names (io_size);
0 commit comments