Commit c50d65b

add thread control for pytorch backend (#125)
* add pytorch thread control
* use function overloading and update copyright years
1 parent 4fa7daa commit c50d65b

3 files changed: +77 lines, −3 lines


src/libtorch.cc

Lines changed: 55 additions & 1 deletion
@@ -1,4 +1,4 @@
-// Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -56,6 +56,12 @@
 #include <cuda_runtime_api.h>
 #endif  // TRITON_ENABLE_GPU
 
+// for thread control
+// https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html#runtime-api
+// https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133
+#include <ATen/Parallel.h>
+
+
 //
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
 //
@@ -465,6 +471,54 @@ ModelState::ParseParameters()
           " for model instance '" + Name() + "'")
              .c_str());
     }
+
+    // If 'INTRA_OP_THREAD_COUNT' is not present in 'parameters' then no update
+    // is made to 'intra_op_thread_count', which by default will take all
+    // threads
+    int intra_op_thread_count = -1;
+    err = ParseParameter(
+        params, "INTRA_OP_THREAD_COUNT", &intra_op_thread_count);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      if (intra_op_thread_count > 0) {
+        at::set_num_threads(intra_op_thread_count);
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("Intra op thread count is set to ") +
+             std::to_string(intra_op_thread_count) + " for model instance '" +
+             Name() + "'")
+                .c_str());
+      }
+    }
+
+    // If 'INTER_OP_THREAD_COUNT' is not present in 'parameters' then no update
+    // is made to 'inter_op_thread_count', which by default will take all
+    // threads
+    int inter_op_thread_count = -1;
+    err = ParseParameter(
+        params, "INTER_OP_THREAD_COUNT", &inter_op_thread_count);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      if (inter_op_thread_count > 0) {
+        at::set_num_interop_threads(inter_op_thread_count);
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("Inter op thread count is set to ") +
+             std::to_string(inter_op_thread_count) + " for model instance '" +
+             Name() + "'")
+                .c_str());
+      }
+    }
   }
 
   return nullptr;
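The calls above use PyTorch's runtime threading API from <ATen/Parallel.h>, linked in the comments of the diff. As a point of reference only, the standalone sketch below is not part of this commit; it exercises the same setters and their matching getters, the thread counts are illustrative, and it assumes a LibTorch installation to compile and link against.

// Standalone sketch, not from this commit: the ATen runtime thread controls
// the backend now invokes, with illustrative thread counts.
#include <ATen/Parallel.h>

#include <iostream>

int main()
{
  // Intra-op threads parallelize work inside a single operator (e.g. a GEMM).
  at::set_num_threads(4);

  // Inter-op threads run independent operators of a TorchScript graph in
  // parallel; per the PyTorch docs this can only be set once, before any
  // inter-op parallel work has started.
  at::set_num_interop_threads(2);

  std::cout << "intra-op threads: " << at::get_num_threads() << std::endl
            << "inter-op threads: " << at::get_num_interop_threads()
            << std::endl;
  return 0;
}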

src/libtorch_utils.cc

Lines changed: 14 additions & 1 deletion
@@ -1,4 +1,4 @@
-// Copyright (c) 2020-21 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2020-24 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -149,6 +149,19 @@ ParseParameter(
   return nullptr;
 }
 
+TRITONSERVER_Error*
+ParseParameter(
+    triton::common::TritonJson::Value& params, const std::string& mkey,
+    int* value)
+{
+  std::string value_str;
+  RETURN_IF_ERROR(GetParameterValue(params, mkey, &value_str));
+  RETURN_IF_ERROR(ParseIntValue(value_str, value));
+
+  return nullptr;
+}
+
+
 #ifdef TRITON_ENABLE_GPU
 TRITONSERVER_Error*
 ConvertCUDAStatusToTritonError(
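The new int overload delegates to GetParameterValue and ParseIntValue from Triton's common backend utilities, neither of which appears in this diff. Purely as an illustration of the parse-and-validate step involved, the standalone helper below uses an invented name (ParseIntLike) and is not Triton's implementation; it assumes only the C++ standard library.

// Hypothetical helper, not Triton's ParseIntValue: convert a parameter
// string such as "4" into an int, rejecting input like "4abc" or "".
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <string>

bool ParseIntLike(const std::string& value_str, int* value)
{
  try {
    std::size_t pos = 0;
    int parsed = std::stoi(value_str, &pos);
    if (pos != value_str.size()) {
      return false;  // trailing characters after the number
    }
    *value = parsed;
    return true;
  }
  catch (const std::exception&) {
    return false;  // empty, non-numeric, or out of int range
  }
}

int main()
{
  int count = -1;
  std::cout << ParseIntLike("4", &count) << " " << count << std::endl;     // 1 4
  std::cout << ParseIntLike("4abc", &count) << " " << count << std::endl;  // 0 4
  return 0;
}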

src/libtorch_utils.h

Lines changed: 8 additions & 1 deletion
@@ -1,4 +1,4 @@
-// Copyright 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -62,4 +62,11 @@ TRITONSERVER_Error* ParseParameter(
     triton::common::TritonJson::Value& params, const std::string& mkey,
     bool* value);
 
+// If the key 'mkey' is present in 'params' then update 'value' with the
+// integer value associated with that key; if 'mkey' is not present, a
+// TRITONSERVER_ERROR_NOT_FOUND error is returned and 'value' is unchanged.
+TRITONSERVER_Error* ParseParameter(
+    triton::common::TritonJson::Value& params, const std::string& mkey,
+    int* value);
+
 }}}  // namespace triton::backend::pytorch
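The "use function overloading" part of the commit message refers to this header: ParseParameter is now an overload set, and the type of the output pointer (bool* vs. int*) selects the parser at each call site. The toy sketch below illustrates the mechanism with invented names and values that do not come from the Triton codebase.

// Toy illustration of overload resolution on the output-pointer type; the
// names and values are invented, not taken from the backend.
#include <iostream>
#include <string>

void Parse(const std::string& s, bool* out)
{
  *out = (s == "true" || s == "1");
}

void Parse(const std::string& s, int* out)
{
  *out = std::stoi(s);
}

int main()
{
  bool flag = false;
  int thread_count = -1;

  Parse("true", &flag);       // resolves to the bool* overload
  Parse("4", &thread_count);  // resolves to the int* overload

  std::cout << std::boolalpha << flag << " " << thread_count << std::endl;
  // prints: true 4
  return 0;
}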
