Commit 8d14a80

Disable cudnn option (#123)
* add disable cudnn option
* correct comment
* clang format
* add to readme

Co-authored-by: jamied <[email protected]>
1 parent 2704676 commit 8d14a80

File tree

2 files changed: +48 -2 lines changed


README.md

Lines changed: 20 additions & 0 deletions
@@ -144,6 +144,26 @@ key: "INFERENCE_MODE"
 }
 ```
 
+* `DISABLE_CUDNN`: Boolean flag to disable the cuDNN library. By default, cuDNN is enabled.
+
+[cuDNN](https://developer.nvidia.com/cudnn) is a GPU-accelerated library of primitives for
+deep neural networks. cuDNN provides highly tuned implementations for standard routines.
+
+Typically, models run with cuDNN enabled are faster. However, there are some exceptions
+where using cuDNN can be slower, cause higher memory usage, or result in errors.
+
+The section of the model config file specifying this parameter will look like:
+
+```
+parameters: {
+  key: "DISABLE_CUDNN"
+  value: {
+    string_value: "true"
+  }
+}
+```
+
 * `ENABLE_WEIGHT_SHARING`: Boolean flag to enable model instances on the same device to
 share weights. This optimization should not be used with stateful models. If not specified,
 weight sharing is disabled.
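For context, here is a minimal sketch of how the new parameter sits in a complete Triton model configuration (`config.pbtxt`). The model name, inputs, and outputs below are hypothetical placeholders and are not part of this commit; only the `parameters` block mirrors the README addition above:

```
# Hypothetical config.pbtxt for a TorchScript image model; only the
# DISABLE_CUDNN parameters block comes from the README addition above.
name: "resnet50_libtorch"
backend: "pytorch"
max_batch_size: 8
input [
  {
    name: "INPUT__0"
    data_type: TYPE_FP32
    dims: [ 3, 224, 224 ]
  }
]
output [
  {
    name: "OUTPUT__0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
]
parameters: {
  key: "DISABLE_CUDNN"
  value: {
    string_value: "true"
  }
}
```

As with the backend's other boolean parameters, the value is passed as the string "true" or "false" and converted by `ParseParameter()` rather than supplied as a protobuf boolean field.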

src/libtorch.cc

Lines changed: 28 additions & 2 deletions
@@ -104,6 +104,7 @@ class ModelState : public BackendModel {
     return enable_jit_executor_pair_;
   }
   bool EnabledInferenceMode() { return enable_inference_mode_; }
+  bool EnabledCudnn() { return enable_cudnn_; }
   bool EnabledCacheCleaning() { return enable_cache_cleaning_; }
 
   bool EnabledWeightSharing() { return enable_weight_sharing_; }
@@ -125,6 +126,9 @@ class ModelState : public BackendModel {
   // Flag to indicate whether inference mode is enabled. Defaults to false.
   bool enable_inference_mode_;
 
+  // Flag to indicate whether cudnn is enabled. Defaults to true.
+  bool enable_cudnn_;
+
   // Flag to indicate whether cache cleaning after each run is enabled.
   // Defaults to false.
   bool enable_cache_cleaning_;
@@ -227,8 +231,9 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
 
 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
     : BackendModel(triton_model), enable_optimized_execution_(true),
-      enable_inference_mode_(true), enable_cache_cleaning_(false),
-      enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
+      enable_inference_mode_(true), enable_cudnn_(true),
+      enable_cache_cleaning_(false), enable_weight_sharing_(false),
+      enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true})
 {
@@ -393,6 +398,24 @@ ModelState::ParseParameters()
          " for model instance '" + Name() + "'")
            .c_str());
 
+    // If 'DISABLE_CUDNN' is not present in 'parameters' then no update is made
+    // to 'enable_cudnn_'.
+    bool disable_cudnn = false;
+    err = ParseParameter(params, "DISABLE_CUDNN", &disable_cudnn);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        TRITONSERVER_ErrorDelete(err);
+      }
+    }
+    enable_cudnn_ = !disable_cudnn;
+    LOG_MESSAGE(
+        TRITONSERVER_LOG_INFO,
+        (std::string("cuDNN is ") + (enable_cudnn_ ? "enabled" : "disabled") +
+         " for model instance '" + Name() + "'")
+            .c_str());
+
     // If 'ENABLE_TENSOR_FUSER' is not present in 'parameters' then no
     // update is made to 'enable_tensor_fuser'.
     bool enable_tensor_fuser = false;
@@ -1562,6 +1585,9 @@ ModelInstanceState::Execute(
   // enable/disable inference mode - supersedes NoGradGuard
   torch::InferenceMode infer_guard(model_state_->EnabledInferenceMode());
 
+  // enable/disable cudnn
+  at::globalContext().setUserEnabledCuDNN(model_state_->EnabledCudnn());
+
   // JIT. No change is made unless parameter is explicitly set.
   if (std::get<0>(model_state_->EnabledJitProfiling())) {
     torch::jit::getProfilingMode() =
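The Execute() hunk above is where the flag actually takes effect: before running the TorchScript model, the backend flips ATen's global user-enabled-cuDNN switch. As a rough illustration of what that switch controls outside of Triton, here is a minimal standalone LibTorch sketch (not part of this commit; it assumes a LibTorch installation and, for the GPU branch, an available CUDA device):

```
// Minimal standalone sketch: toggle the same global cuDNN switch the backend
// sets from DISABLE_CUDNN, then run a convolution with cuDNN disabled.
#include <torch/torch.h>

#include <iostream>

int main() {
  // Rough equivalent of DISABLE_CUDNN=true: disable cuDNN for subsequent ops.
  at::globalContext().setUserEnabledCuDNN(false);
  std::cout << "cuDNN enabled: " << std::boolalpha
            << at::globalContext().userEnabledCuDNN() << std::endl;

  if (torch::cuda::is_available()) {
    // With the switch off, GPU convolutions fall back to non-cuDNN kernels.
    auto conv = torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 8, 3));
    conv->to(torch::kCUDA);
    auto out = conv->forward(torch::rand({1, 3, 32, 32}, torch::kCUDA));
    std::cout << "output sizes: " << out.sizes() << std::endl;
  }
  return 0;
}
```

Because the setting is process-global rather than scoped like `torch::InferenceMode`, every model instance executed after the call sees the same cuDNN state until it is set again.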

0 commit comments
