@@ -80,7 +80,7 @@ class ModelState : public BackendModel {
   // representing the model.
   TRITONSERVER_Error* LoadModel(
       const std::string& artifact_name, const torch::Device device,
-      std::string* model_path,
+      std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
       std::shared_ptr<torch::jit::script::Module>* torch_model);
 
   bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
@@ -205,7 +205,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
 TRITONSERVER_Error*
 ModelState::LoadModel(
     const std::string& artifact_name, const torch::Device device,
-    std::string* model_path,
+    std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
     std::shared_ptr<torch::jit::script::Module>* torch_model)
 {
   // Find the TorchScript file that describes the model. If the model
@@ -255,8 +255,14 @@ ModelState::LoadModel(
 
   try {
     std::istringstream model_stream(model_data_str);
-    torch_model->reset(
-        new torch::jit::Module(torch::jit::load(model_stream, device)));
+    if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
+      // Don't select the device when loading the model.
+      torch_model->reset(
+          new torch::jit::Module(torch::jit::load(model_stream)));
+    } else {
+      torch_model->reset(
+          new torch::jit::Module(torch::jit::load(model_stream, device)));
+    }
   }
   catch (const std::exception& ex) {
     return TRITONSERVER_ErrorNew(
@@ -606,7 +612,7 @@ ModelInstanceState::ModelInstanceState(
   }
 
   THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
-      ArtifactFilename(), device_, &model_path_, &torch_model_));
+      ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));
 
   size_t expected_input_cnt = 0;
   {
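For context, a minimal standalone sketch of the two `torch::jit::load` paths the patch selects between; the `model.pt` path and the `main` wrapper are illustrative only, not part of the backend. Loading without an explicit `torch::Device` keeps whatever device placement is recorded in the TorchScript file, which is what the `TRITONSERVER_INSTANCEGROUPKIND_MODEL` branch relies on, while passing a device moves the module to that device at load time.

```cpp
#include <torch/script.h>

int main() {
  // Hypothetical TorchScript artifact, for illustration only.
  const std::string path = "model.pt";

  // KIND_MODEL branch: no device argument, so the module keeps the
  // device placement it was saved with.
  torch::jit::script::Module as_saved = torch::jit::load(path);

  // Other instance-group kinds: load directly onto the instance's device.
  torch::jit::script::Module on_gpu0 =
      torch::jit::load(path, torch::Device(torch::kCUDA, 0));

  return 0;
}
```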