@@ -1,4 +1,4 @@
-// Copyright 2019-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2019-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -56,6 +56,12 @@
 #include <cuda_runtime_api.h>
 #endif  // TRITON_ENABLE_GPU
 
+// for thread control
+// https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html#runtime-api
+// https://github.com/pytorch/pytorch/blob/v2.2.1-rc3/aten/src/ATen/Parallel.h#L133
+#include <ATen/Parallel.h>
+
+
 //
 // PyTorch C++ (LibTorch) Backend that implements the TRITONBACKEND API.
 //
@@ -465,6 +471,54 @@ ModelState::ParseParameters()
            " for model instance '" + Name() + "'")
               .c_str());
     }
+
+    // If 'INTRA_OP_THREAD_COUNT' is not present in 'parameters' then no update
+    // is made to 'intra_op_thread_count', which by default will take all
+    // threads
+    int intra_op_thread_count = -1;
+    err = ParseParameter(
+        params, "INTRA_OP_THREAD_COUNT", &intra_op_thread_count);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      if (intra_op_thread_count > 0) {
+        at::set_num_threads(intra_op_thread_count);
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("Intra op thread count is set to ") +
+             std::to_string(intra_op_thread_count) + " for model instance '" +
+             Name() + "'")
+                .c_str());
+      }
+    }
+
+    // If 'INTER_OP_THREAD_COUNT' is not present in 'parameters' then no update
+    // is made to 'inter_op_thread_count', which by default will take all
+    // threads
+    int inter_op_thread_count = -1;
+    err = ParseParameter(
+        params, "INTER_OP_THREAD_COUNT", &inter_op_thread_count);
+    if (err != nullptr) {
+      if (TRITONSERVER_ErrorCode(err) != TRITONSERVER_ERROR_NOT_FOUND) {
+        return err;
+      } else {
+        TRITONSERVER_ErrorDelete(err);
+      }
+    } else {
+      if (inter_op_thread_count > 0) {
+        at::set_num_interop_threads(inter_op_thread_count);
+        LOG_MESSAGE(
+            TRITONSERVER_LOG_INFO,
+            (std::string("Inter op thread count is set to ") +
+             std::to_string(inter_op_thread_count) + " for model instance '" +
+             Name() + "'")
+                .c_str());
+      }
+    }
   }
 
   return nullptr;
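
Usage note (not part of the diff above): assuming these keys are read as standard Triton model-configuration parameters, a model's config.pbtxt could cap LibTorch threading roughly as sketched below; the thread-count values "2" and "1" are purely illustrative.

parameters: {
  key: "INTRA_OP_THREAD_COUNT"
  value: { string_value: "2" }
}
parameters: {
  key: "INTER_OP_THREAD_COUNT"
  value: { string_value: "1" }
}

With such a configuration, the ParseParameters() change above calls at::set_num_threads() and at::set_num_interop_threads() while the model configuration is parsed; a missing parameter or a value <= 0 leaves LibTorch's defaults (per the comments above, all threads) untouched.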