|
31 | 31 | #include <mutex>
|
32 | 32 | #include <queue>
|
33 | 33 | #include <chrono>
|
| 34 | +#include <unordered_set> |
| 35 | +#include <optional> |
34 | 36 |
|
35 | 37 | #include "ggml-impl.h"
|
36 | 38 | #include "ggml-backend-impl.h"
|
@@ -93,6 +95,26 @@ int32_t ggml_cann_get_device() {
|
93 | 95 | return id;
|
94 | 96 | }
|
95 | 97 |
|
| 98 | +/** |
| 99 | + * @brief Get the value of the specified environment variable (name). |
| 100 | + * if not empty, return a std::string object |
| 101 | + */ |
| 102 | +std::optional<std::string> get_env(const std::string& name) { |
| 103 | + const char* val = std::getenv(name.c_str()); |
| 104 | + if (!val) return std::nullopt; |
| 105 | + std::string res = std::string(val); |
| 106 | + std::transform(res.begin(), res.end(), res.begin(), ::tolower); |
| 107 | + return res; |
| 108 | +} |
| 109 | + |
| 110 | +/** |
| 111 | + * @brief Verify whether the environment variable is a valid value. |
| 112 | + */ |
| 113 | +bool parse_bool(const std::string& value) { |
| 114 | + std::unordered_set<std::string> valid_values = {"on", "1", "yes", "y", "enable", "true"}; |
| 115 | + return valid_values.find(value) != valid_values.end(); |
| 116 | +} |
| 117 | + |
96 | 118 | /**
|
97 | 119 | * @brief Initialize the CANN device information.
|
98 | 120 | *
|
@@ -214,7 +236,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
|
214 | 236 | * @param device The device ID to associate with this buffer pool.
|
215 | 237 | */
|
216 | 238 | explicit ggml_cann_pool_buf_prio(int device) : device(device) {
|
217 |
| - disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; |
| 239 | + disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); |
218 | 240 | }
|
219 | 241 |
|
220 | 242 | /**
|
@@ -410,7 +432,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
|
410 | 432 | * @param device The device ID to associate with this buffer pool.
|
411 | 433 | */
|
412 | 434 | explicit ggml_cann_pool_buf(int device) : device(device) {
|
413 |
| - disable_clean = getenv("GGML_CANN_DISABLE_BUF_POOL_CLEAN") != nullptr; |
| 435 | + disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); |
414 | 436 | }
|
415 | 437 |
|
416 | 438 | /**
|
@@ -731,16 +753,18 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
|
731 | 753 | */
|
732 | 754 | std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
|
733 | 755 | int device) {
|
734 |
| - bool disable_vmm = (getenv("GGML_CANN_DISABLE_VMM_POOL") != nullptr); |
735 |
| - if (!disable_vmm && ggml_cann_info().devices[device].vmm) { |
736 |
| - GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); |
737 |
| - return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device)); |
738 |
| - } |
739 |
| - bool enable_buf_prio = (getenv("GGML_CANN_ENABLE_BUF_PRIO_POOL") != nullptr); |
740 |
| - if (enable_buf_prio) { |
| 756 | + std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or(""); |
| 757 | + |
| 758 | + if (mem_pool_type == "prio") { |
741 | 759 | GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
|
742 | 760 | return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));
|
743 | 761 | }
|
| 762 | + |
| 763 | + if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") { |
| 764 | + GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device); |
| 765 | + return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device)); |
| 766 | + } |
| 767 | + |
744 | 768 | GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device);
|
745 | 769 | return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));
|
746 | 770 | }
|
|
0 commit comments