@@ -81,7 +81,7 @@ class ModelState : public BackendModel {
   // representing the model.
   TRITONSERVER_Error* LoadModel(
       const std::string& artifact_name, const torch::Device device,
-      std::string* model_path,
+      std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
       std::shared_ptr<torch::jit::script::Module>* torch_model);

   bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
@@ -206,7 +206,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
 TRITONSERVER_Error*
 ModelState::LoadModel(
     const std::string& artifact_name, const torch::Device device,
-    std::string* model_path,
+    std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
     std::shared_ptr<torch::jit::script::Module>* torch_model)
 {
   // Find the TorchScript file that describes the model. If the model
@@ -256,8 +256,14 @@ ModelState::LoadModel(

   try {
     std::istringstream model_stream(model_data_str);
-    torch_model->reset(
-        new torch::jit::Module(torch::jit::load(model_stream, device)));
+    if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      // Don't select the device when loading the model.
+      torch_model->reset(
+          new torch::jit::Module(torch::jit::load(model_stream)));
+    } else {
+      torch_model->reset(
+          new torch::jit::Module(torch::jit::load(model_stream, device)));
+    }
   }
   catch (const std::exception& ex) {
     return TRITONSERVER_ErrorNew(
@@ -607,7 +613,7 @@ ModelInstanceState::ModelInstanceState(
   }

   THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), device_, &model_path_, &torch_model_));
+      ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));

   size_t expected_input_cnt = 0;
   {
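
The change above only affects whether a device is handed to torch::jit::load: for KIND_MODEL instance groups the module is loaded as serialized, otherwise it is pinned to the instance's device. A minimal, self-contained sketch of those two load paths with plain LibTorch calls (outside the backend; the helper name LoadSketch and the keep_serialized_placement flag are illustrative only, not part of this change):

// Sketch of the two torch::jit::load paths used in the diff above,
// assuming model_data_str holds a serialized TorchScript module.
#include <torch/script.h>

#include <memory>
#include <sstream>
#include <string>

std::shared_ptr<torch::jit::script::Module> LoadSketch(
    const std::string& model_data_str, const bool keep_serialized_placement,
    const torch::Device device)
{
  std::istringstream model_stream(model_data_str);
  if (keep_serialized_placement) {
    // No device argument: tensors stay on whatever devices were recorded
    // when the module was serialized (the KIND_MODEL path).
    return std::make_shared<torch::jit::script::Module>(
        torch::jit::load(model_stream));
  }
  // Explicit device: parameters and buffers are moved to 'device'
  // (the per-instance KIND_GPU / KIND_CPU path).
  return std::make_shared<torch::jit::script::Module>(
      torch::jit::load(model_stream, device));
}

Without a device argument, torch::jit::load keeps the parameter and buffer placement recorded in the serialized module, which is what the KIND_MODEL branch relies on; with a device, everything is moved onto that device.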