@@ -61,104 +61,12 @@ struct pool_descriptor {
61
61
bool operator ==(const pool_descriptor &other) const ;
62
62
friend std::ostream &operator <<(std::ostream &os,
63
63
const pool_descriptor &desc);
64
- static std::pair<ur_result_t , std::vector<pool_descriptor>>
65
- create (ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext);
64
+ static std::vector<pool_descriptor>
65
+ createFromDevices (ur_usm_pool_handle_t poolHandle,
66
+ ur_context_handle_t hContext,
67
+ const std::vector<ur_device_handle_t > &devices);
66
68
};
67
69
68
- static inline std::pair<ur_result_t , std::vector<ur_device_handle_t >>
69
- urGetSubDevices (ur_device_handle_t hDevice) {
70
- static detail::ddiTables ddi;
71
-
72
- uint32_t nComputeUnits;
73
- auto ret = ddi.deviceDdiTable .pfnGetInfo (
74
- hDevice, UR_DEVICE_INFO_MAX_COMPUTE_UNITS, sizeof (nComputeUnits),
75
- &nComputeUnits, nullptr );
76
- if (ret != UR_RESULT_SUCCESS) {
77
- return {ret, {}};
78
- }
79
-
80
- ur_device_partition_property_t prop;
81
- prop.type = UR_DEVICE_PARTITION_BY_CSLICE;
82
- prop.value .affinity_domain = 0 ;
83
-
84
- ur_device_partition_properties_t properties{
85
- UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES,
86
- nullptr ,
87
- &prop,
88
- 1 ,
89
- };
90
-
91
- // Get the number of devices that will be created
92
- uint32_t deviceCount;
93
- ret = ddi.deviceDdiTable .pfnPartition (hDevice, &properties, 0 , nullptr ,
94
- &deviceCount);
95
- if (ret != UR_RESULT_SUCCESS) {
96
- return {ret, {}};
97
- }
98
-
99
- std::vector<ur_device_handle_t > sub_devices (deviceCount);
100
- ret = ddi.deviceDdiTable .pfnPartition (
101
- hDevice, &properties, static_cast <uint32_t >(sub_devices.size ()),
102
- sub_devices.data (), nullptr );
103
- if (ret != UR_RESULT_SUCCESS) {
104
- return {ret, {}};
105
- }
106
-
107
- return {UR_RESULT_SUCCESS, sub_devices};
108
- }
109
-
110
- inline std::pair<ur_result_t , std::vector<ur_device_handle_t >>
111
- urGetAllDevicesAndSubDevices (ur_context_handle_t hContext) {
112
- static detail::ddiTables ddi;
113
-
114
- size_t deviceCount = 0 ;
115
- auto ret = ddi.contextDdiTable .pfnGetInfo (
116
- hContext, UR_CONTEXT_INFO_NUM_DEVICES, sizeof (deviceCount), &deviceCount,
117
- nullptr );
118
- if (ret != UR_RESULT_SUCCESS || deviceCount == 0 ) {
119
- return {ret, {}};
120
- }
121
-
122
- std::vector<ur_device_handle_t > devices (deviceCount);
123
- ret = ddi.contextDdiTable .pfnGetInfo (hContext, UR_CONTEXT_INFO_DEVICES,
124
- sizeof (ur_device_handle_t ) * deviceCount,
125
- devices.data (), nullptr );
126
- if (ret != UR_RESULT_SUCCESS) {
127
- return {ret, {}};
128
- }
129
-
130
- std::vector<ur_device_handle_t > devicesAndSubDevices;
131
- std::function<ur_result_t (ur_device_handle_t )> addPoolsForDevicesRec =
132
- [&](ur_device_handle_t hDevice) {
133
- devicesAndSubDevices.push_back (hDevice);
134
- auto [ret, subDevices] = urGetSubDevices (hDevice);
135
- if (ret != UR_RESULT_SUCCESS) {
136
- return ret;
137
- }
138
- for (auto &subDevice : subDevices) {
139
- ret = addPoolsForDevicesRec (subDevice);
140
- if (ret != UR_RESULT_SUCCESS) {
141
- return ret;
142
- }
143
- }
144
- return UR_RESULT_SUCCESS;
145
- };
146
-
147
- for (size_t i = 0 ; i < deviceCount; i++) {
148
- ret = addPoolsForDevicesRec (devices[i]);
149
- if (ret != UR_RESULT_SUCCESS) {
150
- if (ret == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
151
- // Return main devices when sub-devices are unsupported.
152
- return {ret, std::move (devices)};
153
- }
154
-
155
- return {ret, {}};
156
- }
157
- }
158
-
159
- return {UR_RESULT_SUCCESS, devicesAndSubDevices};
160
- }
161
-
162
70
static inline bool
163
71
isSharedAllocationReadOnlyOnDevice (const pool_descriptor &desc) {
164
72
return desc.type == UR_USM_TYPE_SHARED && desc.deviceReadOnly ;
@@ -205,14 +113,9 @@ inline std::ostream &operator<<(std::ostream &os, const pool_descriptor &desc) {
205
113
return os;
206
114
}
207
115
208
- inline std::pair<ur_result_t , std::vector<pool_descriptor>>
209
- pool_descriptor::create (ur_usm_pool_handle_t poolHandle,
210
- ur_context_handle_t hContext) {
211
- auto [ret, devices] = urGetAllDevicesAndSubDevices (hContext);
212
- if (ret != UR_RESULT_SUCCESS) {
213
- return {ret, {}};
214
- }
215
-
116
+ inline std::vector<pool_descriptor> pool_descriptor::createFromDevices (
117
+ ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext,
118
+ const std::vector<ur_device_handle_t > &devices) {
216
119
std::vector<pool_descriptor> descriptors;
217
120
pool_descriptor &desc = descriptors.emplace_back ();
218
121
desc.poolHandle = poolHandle;
@@ -245,7 +148,7 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
245
148
}
246
149
}
247
150
248
- return {ret, descriptors} ;
151
+ return descriptors;
249
152
}
250
153
251
154
template <typename D> struct pool_manager {
0 commit comments