17
17
18
18
#include < umf/providers/provider_level_zero.h>
19
19
20
+ static inline void UMF_CALL_THROWS (umf_result_t res) {
21
+ if (res != UMF_RESULT_SUCCESS) {
22
+ throw res;
23
+ }
24
+ }
25
+
20
26
namespace umf {
21
27
ur_result_t getProviderNativeError (const char *providerName,
22
28
int32_t nativeError) {
@@ -99,35 +105,21 @@ descToDisjoinPoolMemType(const usm::pool_descriptor &desc) {
99
105
static umf::provider_unique_handle_t
100
106
makeProvider (usm::pool_descriptor poolDescriptor) {
101
107
umf_level_zero_memory_provider_params_handle_t hParams;
102
- umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (&hParams);
103
- if (umf_ret != UMF_RESULT_SUCCESS) {
104
- throw umf::umf2urResult (umf_ret);
105
- }
106
-
108
+ UMF_CALL_THROWS (umfLevelZeroMemoryProviderParamsCreate (&hParams));
107
109
std::unique_ptr<umf_level_zero_memory_provider_params_t ,
108
110
decltype (&umfLevelZeroMemoryProviderParamsDestroy)>
109
111
params (hParams, &umfLevelZeroMemoryProviderParamsDestroy);
110
112
111
- umf_ret = umfLevelZeroMemoryProviderParamsSetContext (
112
- hParams, poolDescriptor.hContext ->getZeHandle ());
113
- if (umf_ret != UMF_RESULT_SUCCESS) {
114
- throw umf::umf2urResult (umf_ret);
115
- };
113
+ UMF_CALL_THROWS (umfLevelZeroMemoryProviderParamsSetContext (
114
+ hParams, poolDescriptor.hContext ->getZeHandle ()));
116
115
117
116
ze_device_handle_t level_zero_device_handle =
118
117
poolDescriptor.hDevice ? poolDescriptor.hDevice ->ZeDevice : nullptr ;
119
118
120
- umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (hParams,
121
- level_zero_device_handle);
122
- if (umf_ret != UMF_RESULT_SUCCESS) {
123
- throw umf::umf2urResult (umf_ret);
124
- }
125
-
126
- umf_ret = umfLevelZeroMemoryProviderParamsSetMemoryType (
127
- hParams, urToUmfMemoryType (poolDescriptor.type ));
128
- if (umf_ret != UMF_RESULT_SUCCESS) {
129
- throw umf::umf2urResult (umf_ret);
130
- }
119
+ UMF_CALL_THROWS (umfLevelZeroMemoryProviderParamsSetDevice (
120
+ hParams, level_zero_device_handle));
121
+ UMF_CALL_THROWS (umfLevelZeroMemoryProviderParamsSetMemoryType (
122
+ hParams, urToUmfMemoryType (poolDescriptor.type )));
131
123
132
124
std::vector<ze_device_handle_t > residentZeHandles;
133
125
@@ -140,13 +132,13 @@ makeProvider(usm::pool_descriptor poolDescriptor) {
140
132
residentZeHandles.push_back (device->ZeDevice );
141
133
}
142
134
143
- umf_ret = umfLevelZeroMemoryProviderParamsSetResidentDevices (
144
- hParams, residentZeHandles.data (), residentZeHandles.size ());
145
- if (umf_ret != UMF_RESULT_SUCCESS) {
146
- throw umf::umf2urResult (umf_ret);
147
- }
135
+ UMF_CALL_THROWS (umfLevelZeroMemoryProviderParamsSetResidentDevices (
136
+ hParams, residentZeHandles.data (), residentZeHandles.size ()));
148
137
}
149
138
139
+ UMF_CALL_THROWS (umfLevelZeroMemoryProviderParamsSetFreePolicy (
140
+ hParams, UMF_LEVEL_ZERO_MEMORY_PROVIDER_FREE_POLICY_BLOCKING_FREE));
141
+
150
142
auto [ret, provider] =
151
143
umf::providerMakeUniqueFromOps (umfLevelZeroMemoryProviderOps (), hParams);
152
144
if (ret != UMF_RESULT_SUCCESS) {
0 commit comments