@@ -104,6 +104,7 @@ class ModelState : public BackendModel {
     return enable_jit_executor_pair_;
   }
   bool EnabledInferenceMode() { return enable_inference_mode_; }
+  bool EnabledCudnn() { return enable_cudnn_; }
   bool EnabledCacheCleaning() { return enable_cache_cleaning_; }
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
@@ -125,6 +126,9 @@ class ModelState : public BackendModel {
   // Flag to indicate whether inference mode is enabled. Defaults to false.
   bool enable_inference_mode_;
 
+  // Flag to indicate whether cudnn is enabled. Defaults to true.
+  bool enable_cudnn_;
+
   // Flag to indicate whether cache cleaning after each run is enabled.
   // Defaults to false.
   bool enable_cache_cleaning_;
@@ -227,8 +231,9 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
 
 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
     : BackendModel(triton_model), enable_optimized_execution_(true),
-      enable_inference_mode_(true), enable_cache_cleaning_(false),
-      enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
+      enable_inference_mode_(true), enable_cudnn_(true),
+      enable_cache_cleaning_(false), enable_weight_sharing_(false),
+      enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true})
 {
@@ -393,6 +398,24 @@ ModelState::ParseParameters()
          " for model instance '" + Name() + "'")
             .c_str());
 
+    // If 'DISABLE_CUDNN' is not present in 'parameters' then no update is
+    // made to 'enable_cudnn_'.
+    bool disable_cudnn = false;
+    err = ParseParameter(params, "DISABLE_CUDNN", &disable_cudnn);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        TRITONSERVER_ErrorDelete(err);
+      }
+    }
+    enable_cudnn_ = !disable_cudnn;
+    LOG_MESSAGE(
+        TRITONSERVER_LOG_INFO,
+        (std::string("cuDNN is ") + (enable_cudnn_ ? "enabled" : "disabled") +
+         " for model instance '" + Name() + "'")
+            .c_str());
+
     // If 'ENABLE_TENSOR_FUSER' is not present in 'parameters' then no
     // update is made to 'enable_tensor_fuser'.
     bool enable_tensor_fuser = false;
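
The new parameter follows the backend's existing convention: it is read from the model's `config.pbtxt` `parameters` map under the key `DISABLE_CUDNN`, and when the key is absent the default (`enable_cudnn_ = true`) is kept. The stand-alone sketch below mirrors that tolerant-lookup logic with a plain `std::map` standing in for the Triton parameter object; the `ReadFlag` helper, the map, and the sample values are illustrative only and are not part of the backend API.

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical helper mirroring the lookup in ParseParameters(): an absent
// key leaves the caller's default untouched, a present key overrides it.
bool
ReadFlag(
    const std::map<std::string, std::string>& params, const std::string& key,
    bool* value)
{
  auto it = params.find(key);
  if (it == params.end()) {
    return false;  // key not found -> keep the default
  }
  *value = (it->second == "true");
  return true;
}

int
main()
{
  // Sample parameters as they might appear in a model's config.pbtxt.
  std::map<std::string, std::string> params{{"DISABLE_CUDNN", "true"}};

  bool disable_cudnn = false;  // default: cuDNN stays enabled
  ReadFlag(params, "DISABLE_CUDNN", &disable_cudnn);
  bool enable_cudnn = !disable_cudnn;

  std::cout << "cuDNN " << (enable_cudnn ? "enabled" : "disabled") << "\n";
  return 0;
}
```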
@@ -1562,6 +1585,9 @@ ModelInstanceState::Execute(
   // enable/disable inference mode - supersedes NoGradGuard
   torch::InferenceMode infer_guard(model_state_->EnabledInferenceMode());
 
+  // enable/disable cudnn
+  at::globalContext().setUserEnabledCuDNN(model_state_->EnabledCudnn());
+
   // JIT. No change is made unless parameter is explicitly set.
   if (std::get<0>(model_state_->EnabledJitProfiling())) {
     torch::jit::getProfilingMode() =
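
The call added in `Execute()` uses libtorch's global ATen context to turn cuDNN kernels on or off. A minimal stand-alone sketch of that toggle, assuming libtorch is installed and linked, is shown below; the `main()` wrapper and the printed messages are illustrative only.

```cpp
#include <torch/torch.h>

#include <iostream>

int
main()
{
  // Query the current user-level cuDNN setting (defaults to true).
  std::cout << "cuDNN enabled: " << std::boolalpha
            << at::globalContext().userEnabledCuDNN() << "\n";

  // Disable cuDNN; subsequent CUDA ops (e.g. convolutions) fall back to
  // non-cuDNN implementations until the flag is flipped back.
  at::globalContext().setUserEnabledCuDNN(false);
  std::cout << "cuDNN enabled: " << std::boolalpha
            << at::globalContext().userEnabledCuDNN() << "\n";

  return 0;
}
```

Worth noting: the flag lives on the process-wide ATen context rather than on a per-model guard, so setting it in `Execute()` affects any model instance sharing the process.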