Skip to content

Commit c7d5007

Browse files
committed
Add support for instance group of type 'MODEL'
1 parent f405488 commit c7d5007

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/libtorch.cc

100644100755
Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class ModelState : public BackendModel {
8181
// representing the model.
8282
TRITONSERVER_Error* LoadModel(
8383
const std::string& artifact_name, const torch::Device device,
84-
std::string* model_path,
84+
std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
8585
std::shared_ptr<torch::jit::script::Module>* torch_model);
8686

8787
bool EnabledOptimizedExecution() { return enable_optimized_execution_; }
@@ -206,7 +206,7 @@ ModelState::ModelState(TRITONBACKEND_Model* triton_model)
206206
TRITONSERVER_Error*
207207
ModelState::LoadModel(
208208
const std::string& artifact_name, const torch::Device device,
209-
std::string* model_path,
209+
std::string* model_path, const TRITONSERVER_InstanceGroupKind& kind,
210210
std::shared_ptr<torch::jit::script::Module>* torch_model)
211211
{
212212
// Find the TorchScript file that describes the model. If the model
@@ -256,8 +256,14 @@ ModelState::LoadModel(
256256

257257
try {
258258
std::istringstream model_stream(model_data_str);
259-
torch_model->reset(
260-
new torch::jit::Module(torch::jit::load(model_stream, device)));
259+
if (kind == TRITONSERVER_INSTANCEGROUPKIND_MODEL) {
260+
// Don't select the device when loading the model.
261+
torch_model->reset(
262+
new torch::jit::Module(torch::jit::load(model_stream)));
263+
} else {
264+
torch_model->reset(
265+
new torch::jit::Module(torch::jit::load(model_stream, device)));
266+
}
261267
}
262268
catch (const std::exception& ex) {
263269
return TRITONSERVER_ErrorNew(
@@ -607,7 +613,7 @@ ModelInstanceState::ModelInstanceState(
607613
}
608614

609615
THROW_IF_BACKEND_INSTANCE_ERROR(model_state->LoadModel(
610-
ArtifactFilename(), device_, &model_path_, &torch_model_));
616+
ArtifactFilename(), device_, &model_path_, Kind(), &torch_model_));
611617

612618
size_t expected_input_cnt = 0;
613619
{

0 commit comments

Comments
 (0)