Skip to content

Commit ae284f1

Browse files
[SYCL] correct sub-device count calculation for numa partitioning (#6005)
Signed-off-by: Sergey V Maslov <[email protected]>
1 parent 7efb3e6 commit ae284f1

File tree

1 file changed

+46
-7
lines changed

1 file changed

+46
-7
lines changed

sycl/source/detail/device_impl.cpp

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,11 @@ device_impl::create_sub_devices(const cl_device_partition_property *Properties,
133133
Plugin.call<sycl::errc::invalid, PiApiKind::piDevicePartition>(
134134
MDevice, Properties, SubDevicesCount, SubDevices.data(),
135135
&ReturnedSubDevices);
136-
// TODO: check that returned number of sub-devices matches what was
137-
// requested, otherwise this walk below is wrong.
138-
//
136+
if (ReturnedSubDevices != SubDevicesCount) {
137+
throw sycl::exception(
138+
errc::invalid,
139+
"Could not partition to the specified number of sub-devices");
140+
}
139141
// TODO: Need to describe the subdevice model. Some sub_device management
140142
// may be necessary. What happens if create_sub_devices is called multiple
141143
// times with the same arguments?
@@ -161,8 +163,14 @@ std::vector<device> device_impl::create_sub_devices(size_t ComputeUnits) const {
161163
if (!is_partition_supported(info::partition_property::partition_equally)) {
162164
throw cl::sycl::feature_not_supported();
163165
}
164-
size_t SubDevicesCount =
165-
get_info<info::device::max_compute_units>() / ComputeUnits;
166+
// If count exceeds the total number of compute units in the device, an
167+
// exception with the errc::invalid error code must be thrown.
168+
auto MaxComputeUnits = get_info<info::device::max_compute_units>();
169+
if (ComputeUnits > MaxComputeUnits)
170+
throw sycl::exception(errc::invalid,
171+
"Total counts exceed max compute units");
172+
173+
size_t SubDevicesCount = MaxComputeUnits / ComputeUnits;
166174
const cl_device_partition_property Properties[3] = {
167175
CL_DEVICE_PARTITION_EQUALLY, (cl_device_partition_property)ComputeUnits,
168176
0};
@@ -184,7 +192,33 @@ device_impl::create_sub_devices(const std::vector<size_t> &Counts) const {
184192
static const cl_device_partition_property P[] = {
185193
CL_DEVICE_PARTITION_BY_COUNTS, CL_DEVICE_PARTITION_BY_COUNTS_LIST_END, 0};
186194
std::vector<cl_device_partition_property> Properties(P, P + 3);
187-
Properties.insert(Properties.begin() + 1, Counts.begin(), Counts.end());
195+
196+
// Fill the properties vector with counts and validate it
197+
auto It = Properties.begin() + 1;
198+
size_t TotalCounts = 0;
199+
size_t NonZeroCounts = 0;
200+
for (auto Count : Counts) {
201+
TotalCounts += Count;
202+
NonZeroCounts += (Count != 0) ? 1 : 0;
203+
It = Properties.insert(It, Count);
204+
}
205+
206+
// If the number of non-zero values in counts exceeds the device’s maximum
207+
// number of sub devices (as returned by info::device::
208+
// partition_max_sub_devices) an exception with the errc::invalid
209+
// error code must be thrown.
210+
if (NonZeroCounts > get_info<info::device::partition_max_sub_devices>())
211+
throw sycl::exception(errc::invalid,
212+
"Total non-zero counts exceed max sub-devices");
213+
214+
// If the total of all the values in the counts vector exceeds the total
215+
// number of compute units in the device (as returned by
216+
// info::device::max_compute_units), an exception with the errc::invalid
217+
// error code must be thrown.
218+
if (TotalCounts > get_info<info::device::max_compute_units>())
219+
throw sycl::exception(errc::invalid,
220+
"Total counts exceed max compute units");
221+
188222
return create_sub_devices(Properties.data(), Counts.size());
189223
}
190224

@@ -205,7 +239,12 @@ std::vector<device> device_impl::create_sub_devices(
205239
const pi_device_partition_property Properties[3] = {
206240
PI_DEVICE_PARTITION_BY_AFFINITY_DOMAIN,
207241
(pi_device_partition_property)AffinityDomain, 0};
208-
size_t SubDevicesCount = get_info<info::device::partition_max_sub_devices>();
242+
243+
pi_uint32 SubDevicesCount = 0;
244+
const detail::plugin &Plugin = getPlugin();
245+
Plugin.call<sycl::errc::invalid, PiApiKind::piDevicePartition>(
246+
MDevice, Properties, 0, nullptr, &SubDevicesCount);
247+
209248
return create_sub_devices(Properties, SubDevicesCount);
210249
}
211250

0 commit comments

Comments
 (0)