
Commit 0b732df: Enable inference mode by default (#105)
Parent: a7a2413

2 files changed: 16 additions, 14 deletions

2 files changed

+16
-14
lines changed

README.md (1 addition, 1 deletion)

@@ -123,7 +123,7 @@ key: "DISABLE_OPTIMIZED_EXECUTION"
 ```
 
 * `INFERENCE_MODE`: Boolean flag to enable the Inference Mode execution
-  of TorchScript models. By default, the inference mode is disabled.
+  of TorchScript models. By default, the inference mode is enabled.
 
 [InferenceMode](https://pytorch.org/cppdocs/notes/inference_mode.html) is a new
 RAII guard analogous to NoGradMode to be used when you are certain your operations
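
As the hunk header above shows (the neighboring `DISABLE_OPTIMIZED_EXECUTION` entry), the README documents these flags as `parameters` entries in the model's `config.pbtxt`. A minimal sketch, assuming `INFERENCE_MODE` follows the same pattern; since this commit flips the default to enabled, the entry is mainly useful for opting out with `string_value: "false"`:

```
parameters: {
  key: "INFERENCE_MODE"
  value: {
    string_value: "true"
  }
}
```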

src/libtorch.cc (15 additions, 13 deletions)
@@ -176,7 +176,7 @@ ModelState::Create(TRITONBACKEND_Model* triton_model, ModelState** state)
 
 ModelState::ModelState(TRITONBACKEND_Model* triton_model)
     : BackendModel(triton_model), enable_optimized_execution_(true),
-      enable_inference_mode_(false), enable_cache_cleaning_(false),
+      enable_inference_mode_(true), enable_cache_cleaning_(false),
       enable_weight_sharing_(false), enable_tensor_fuser_pair_({false, true}),
       enable_jit_profiling_pair_({false, true}),
       enable_jit_executor_pair_({false, true}),
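
With `enable_inference_mode_(true)` as the default, the backend now runs TorchScript execution under PyTorch's InferenceMode guard unless a model config disables it. A minimal standalone sketch of what that guard does in libtorch; the `model.pt` path and tensor shape are illustrative, not from this repo:

```cpp
#include <torch/script.h>

#include <iostream>
#include <vector>

int main() {
  // Load a TorchScript module (path is illustrative).
  torch::jit::Module module = torch::jit::load("model.pt");

  std::vector<torch::jit::IValue> inputs;
  inputs.emplace_back(torch::ones({1, 3}));

  // RAII guard: while it is alive, new tensors are inference tensors, and
  // autograd's version counting and view tracking are skipped -- the
  // overhead that enabling inference mode avoids.
  c10::InferenceMode guard(true);
  at::Tensor output = module.forward(inputs).toTensor();

  std::cout << output << std::endl;
  return 0;
}
```

Tensors created under the guard cannot later participate in autograd recording, which is why the backend keeps the flag configurable rather than hard-coding the guard.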
@@ -1312,12 +1312,12 @@ ModelInstanceState::Execute(
     torch::jit::overrideCanFuseOnCPU(false);
     torch::jit::overrideCanFuseOnGPU(false);
     torch::jit::setTensorExprFuserEnabled(false);
-      torch::jit::fuser::cuda::setEnabled(true);
+    torch::jit::fuser::cuda::setEnabled(true);
   } else {
     torch::jit::overrideCanFuseOnCPU(true);
     torch::jit::overrideCanFuseOnGPU(true);
     torch::jit::setTensorExprFuserEnabled(true);
-      torch::jit::fuser::cuda::setEnabled(false);
+    torch::jit::fuser::cuda::setEnabled(false);
   }
 }
 
@@ -1761,9 +1761,9 @@ ModelInstanceState::SetInputTensors(
 
         batchn_shape[0] += GetElementCount(input_shape, input_dims_count);
       }
-    }
-    else {
-      batchn_shape = std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    } else {
+      batchn_shape =
+          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
       if (supports_batching_) {
         batchn_shape[0] = total_batch_size;
       }
@@ -1887,9 +1887,11 @@ ModelInstanceState::ReadOutputTensors(
 
       // Output tensors may not reside on the same device as model
       torch::Device tensor_device = output_flat.device();
-      const auto memory_type = (tensor_device.type() == torch::kCPU) ? TRITONSERVER_MEMORY_CPU
-                                                                     : TRITONSERVER_MEMORY_GPU;
-      const auto memory_id = (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
+      const auto memory_type = (tensor_device.type() == torch::kCPU)
+                                   ? TRITONSERVER_MEMORY_CPU
+                                   : TRITONSERVER_MEMORY_GPU;
+      const auto memory_id =
+          (tensor_device.type() == torch::kCPU) ? 0 : tensor_device.index();
 
       // Batch output doesn't support string data type yet, as it is not trivial
       // to parse string output
@@ -1906,16 +1908,16 @@ ModelInstanceState::ReadOutputTensors(
         return TRITONSERVER_ErrorNew(
             TRITONSERVER_ERROR_INVALID_ARG,
             (std::string("output '") + name +
-                "' is a scalar which is not supported.")
+             "' is a scalar which is not supported.")
                 .c_str());
       }
 
       responder.ProcessTensor(
-          name, output_dtype, batchn_shape, output_buffer,
-          memory_type, memory_id);
+          name, output_dtype, batchn_shape, output_buffer, memory_type,
+          memory_id);
     } else {
       responder.ProcessBatchOutput(
-        name, *batch_output, output_buffer, memory_type, memory_id);
+          name, *batch_output, output_buffer, memory_type, memory_id);
     }
   } else if (output_tensors[op_index].isList()) {
     // Custom handling for string/bytes tensor...
