@@ -770,17 +770,21 @@ pi_device _pi_context::getRootDevice() const {
770
770
pi_result _pi_context::initialize () {
771
771
772
772
// Helper lambda to create various USM allocators for a device.
773
+ // Note that the CCS devices and their respective subdevices share a
774
+ // common ze_device_handle and therefore, also share USM allocators.
773
775
auto createUSMAllocators = [this ](pi_device Device) {
774
776
SharedMemAllocContexts.emplace (
775
- std::piecewise_construct, std::make_tuple (Device),
777
+ std::piecewise_construct, std::make_tuple (Device-> ZeDevice ),
776
778
std::make_tuple (std::unique_ptr<SystemMemory>(
777
779
new USMSharedMemoryAlloc (this , Device))));
780
+
778
781
SharedReadOnlyMemAllocContexts.emplace (
779
- std::piecewise_construct, std::make_tuple (Device),
782
+ std::piecewise_construct, std::make_tuple (Device-> ZeDevice ),
780
783
std::make_tuple (std::unique_ptr<SystemMemory>(
781
784
new USMSharedReadOnlyMemoryAlloc (this , Device))));
785
+
782
786
DeviceMemAllocContexts.emplace (
783
- std::piecewise_construct, std::make_tuple (Device),
787
+ std::piecewise_construct, std::make_tuple (Device-> ZeDevice ),
784
788
std::make_tuple (std::unique_ptr<SystemMemory>(
785
789
new USMDeviceMemoryAlloc (this , Device))));
786
790
};
@@ -807,8 +811,9 @@ pi_result _pi_context::initialize() {
807
811
std::unique_ptr<SystemMemory>(new USMHostMemoryAlloc (this )));
808
812
809
813
// We may allocate memory to this root device so create allocators.
810
- if (SingleRootDevice && DeviceMemAllocContexts.find (SingleRootDevice) ==
811
- DeviceMemAllocContexts.end ()) {
814
+ if (SingleRootDevice &&
815
+ DeviceMemAllocContexts.find (SingleRootDevice->ZeDevice ) ==
816
+ DeviceMemAllocContexts.end ()) {
812
817
createUSMAllocators (SingleRootDevice);
813
818
}
814
819
@@ -8191,7 +8196,7 @@ pi_result piextUSMDeviceAlloc(void **ResultPtr, pi_context Context,
8191
8196
}
8192
8197
8193
8198
try {
8194
- auto It = Context->DeviceMemAllocContexts .find (Device);
8199
+ auto It = Context->DeviceMemAllocContexts .find (Device-> ZeDevice );
8195
8200
if (It == Context->DeviceMemAllocContexts .end ())
8196
8201
return PI_ERROR_INVALID_VALUE;
8197
8202
@@ -8269,7 +8274,7 @@ pi_result piextUSMSharedAlloc(void **ResultPtr, pi_context Context,
8269
8274
try {
8270
8275
auto &Allocator = (DeviceReadOnly ? Context->SharedReadOnlyMemAllocContexts
8271
8276
: Context->SharedMemAllocContexts );
8272
- auto It = Allocator.find (Device);
8277
+ auto It = Allocator.find (Device-> ZeDevice );
8273
8278
if (It == Allocator.end ())
8274
8279
return PI_ERROR_INVALID_VALUE;
8275
8280
@@ -8432,10 +8437,11 @@ static pi_result USMFreeHelper(pi_context Context, void *Ptr,
8432
8437
PI_ASSERT (Device, PI_ERROR_INVALID_DEVICE);
8433
8438
8434
8439
auto DeallocationHelper =
8435
- [Context, Device, Ptr, OwnZeMemHandle](
8436
- std::unordered_map<pi_device, USMAllocContext> &AllocContextMap) {
8440
+ [Context, Device, Ptr,
8441
+ OwnZeMemHandle](std::unordered_map<ze_device_handle_t , USMAllocContext>
8442
+ &AllocContextMap) {
8437
8443
try {
8438
- auto It = AllocContextMap.find (Device);
8444
+ auto It = AllocContextMap.find (Device-> ZeDevice );
8439
8445
if (It == AllocContextMap.end ())
8440
8446
return PI_ERROR_INVALID_VALUE;
8441
8447
0 commit comments