Skip to content

Commit 0b86d4f

Browse files
committed
Add support for instance group of type 'MODEL'
1 parent 550cf62 commit 0b86d4f

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/libtorch.cc

100644100755
Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class ModelState : public BackendModel {
8080
// representing the model.
8181
TRITONSERVER_Error* LoadModel(
8282
const std::string& artifact_name, const torch::Device device,
83-
std::string* model_path,
83+
std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
8484
std::shared_ptr<torch::jit::script::Module>* torch_model);
8585

8686
bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
@@ -205,7 +205,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
205205
TRITONSERVER_Error*
206206
ModelState::LoadModel(
207207
const std::string& artifact_name, const torch::Device device,
208-
std::string* model_path,
208+
std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
209209
std::shared_ptr<torch::jit::script::Module>* torch_model)
210210
{
211211
// Find the TorchScript file that describes the model. If the model
@@ -255,8 +255,14 @@ ModelState::LoadModel(
255255

256256
try {
257257
std::istringstream model_stream(model_data_str);
258-
torch_model->reset(
259-
new torch::jit::Module(torch::jit::load(model_stream, device)));
258+
if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
259+
// Don't select the device when loading the model.
260+
torch_model->reset(
261+
new torch::jit::Module(torch::jit::load(model_stream)));
262+
} else {
263+
torch_model->reset(
264+
new torch::jit::Module(torch::jit::load(model_stream, device)));
265+
}
260266
}
261267
catch (const std::exception& ex) {
262268
return TRITONSERVER_ErrorNew(
@@ -606,7 +612,7 @@ ModelInstanceState::ModelInstanceState(
606612
}
607613

608614
THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
609-
ArtifactFilename(), device_, &model_path_, &torch_model_));
615+
ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));
610616

611617
size_t expected_input_cnt = 0;
612618
{

0 commit comments

Comments
 (0)