@@ -3017,7 +3017,7 @@ pi_result ContextReleaseHelper(pi_context Context) {
3017
3017
3018
3018
if (--(Context->RefCount ) == 0 ) {
3019
3019
if (IndirectAccessTrackingEnabled) {
3020
- pi_platform Plt = Context->Devices [ 0 ]-> Platform ;
3020
+ pi_platform Plt = Context->getPlatform () ;
3021
3021
auto &Contexts = Plt->Contexts ;
3022
3022
auto It = std::find (Contexts.begin (), Contexts.end (), Context);
3023
3023
if (It != Contexts.end ())
@@ -3047,7 +3047,7 @@ pi_result ContextReleaseHelper(pi_context Context) {
3047
3047
}
3048
3048
3049
3049
pi_result piContextRelease (pi_context Context) {
3050
- pi_platform Plt = Context->Devices [ 0 ]-> Platform ;
3050
+ pi_platform Plt = Context->getPlatform () ;
3051
3051
std::unique_lock<std::mutex> ContextsLock (Plt->ContextsMutex ,
3052
3052
std::defer_lock);
3053
3053
if (IndirectAccessTrackingEnabled)
@@ -3364,7 +3364,7 @@ static pi_result ZeDeviceMemAllocHelper(void **ResultPtr, pi_context Context,
3364
3364
// otherwise just calls zeMemAllocHost.
3365
3365
static pi_result ZeHostMemAllocHelper (void **ResultPtr, pi_context Context,
3366
3366
size_t Size) {
3367
- pi_platform Plt = Context->Devices [ 0 ]-> Platform ;
3367
+ pi_platform Plt = Context->getPlatform () ;
3368
3368
std::unique_lock<std::mutex> ContextsLock (Plt->ContextsMutex ,
3369
3369
std::defer_lock);
3370
3370
if (IndirectAccessTrackingEnabled) {
@@ -3439,7 +3439,7 @@ pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
3439
3439
// If not shared of any type, we can import the ptr
3440
3440
if (ZeMemoryAllocationProperties.type == ZE_MEMORY_TYPE_UNKNOWN) {
3441
3441
// Promote the host ptr to USM host memory
3442
- ze_driver_handle_t driverHandle = Context->Devices [ 0 ]-> Platform ->ZeDriver ;
3442
+ ze_driver_handle_t driverHandle = Context->getPlatform () ->ZeDriver ;
3443
3443
ZeUSMImport.doZeUSMImport (driverHandle, HostPtr, Size);
3444
3444
HostPtrImported = true ;
3445
3445
}
@@ -3449,8 +3449,7 @@ pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags, size_t Size,
3449
3449
auto HostPtrOrNull =
3450
3450
(Flags & PI_MEM_FLAGS_HOST_PTR_USE) ? pi_cast<char *>(HostPtr) : nullptr ;
3451
3451
try {
3452
- Buffer = new _pi_buffer (Context, Size, HostPtrOrNull,
3453
- true /* OwnZeMemHandle */ , HostPtrImported);
3452
+ Buffer = new _pi_buffer (Context, Size, HostPtrOrNull, HostPtrImported);
3454
3453
} catch (const std::bad_alloc &) {
3455
3454
return PI_OUT_OF_HOST_MEMORY;
3456
3455
} catch (...) {
@@ -3499,12 +3498,8 @@ pi_result piMemGetInfo(pi_mem Mem, pi_mem_info ParamName, size_t ParamValueSize,
3499
3498
return ReturnValue (Mem->Context );
3500
3499
case PI_MEM_SIZE: {
3501
3500
// Get size of the allocation
3502
- size_t Size;
3503
- ZE_CALL (zeMemGetAddressRange,
3504
- (Mem->Context ->ZeContext , pi_cast<void *>(Mem->getZeHandle ()),
3505
- nullptr , &Size));
3506
-
3507
- return ReturnValue (Size);
3501
+ auto Buffer = pi_cast<pi_buffer>(Mem);
3502
+ return ReturnValue (size_t {Buffer->Size });
3508
3503
}
3509
3504
default :
3510
3505
die (" piMemGetInfo: Parameter is not implemented" );
@@ -3526,7 +3521,7 @@ pi_result piMemRetain(pi_mem Mem) {
3526
3521
// performed.
3527
3522
static pi_result ZeMemFreeHelper (pi_context Context, void *Ptr,
3528
3523
bool OwnZeMemHandle = true ) {
3529
- pi_platform Plt = Context->Devices [ 0 ]-> Platform ;
3524
+ pi_platform Plt = Context->getPlatform () ;
3530
3525
std::unique_lock<std::mutex> ContextsLock (Plt->ContextsMutex ,
3531
3526
std::defer_lock);
3532
3527
if (IndirectAccessTrackingEnabled) {
@@ -3555,7 +3550,7 @@ static pi_result ZeMemFreeHelper(pi_context Context, void *Ptr,
3555
3550
}
3556
3551
3557
3552
static pi_result USMFreeHelper (pi_context Context, void *Ptr,
3558
- bool OwnZeMemHandle = true );
3553
+ bool OwnZeMemHandle);
3559
3554
3560
3555
pi_result piMemRelease (pi_mem Mem) {
3561
3556
PI_ASSERT (Mem, PI_INVALID_MEM_OBJECT);
@@ -3729,7 +3724,7 @@ pi_result piMemImageCreate(pi_context Context, pi_mem_flags Flags,
3729
3724
(Flags & PI_MEM_FLAGS_HOST_PTR_USE) ? pi_cast<char *>(HostPtr) : nullptr ;
3730
3725
3731
3726
try {
3732
- auto ZePIImage = new _pi_image (Context, ZeHImage, HostPtrOrNull, true );
3727
+ auto ZePIImage = new _pi_image (Context, ZeHImage, HostPtrOrNull);
3733
3728
3734
3729
#ifndef NDEBUG
3735
3730
ZePIImage->ZeImageDesc = ZeImageDesc;
@@ -3767,25 +3762,24 @@ pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,
3767
3762
PI_ASSERT (Mem, PI_INVALID_VALUE);
3768
3763
PI_ASSERT (NativeHandle, PI_INVALID_VALUE);
3769
3764
PI_ASSERT (Context, PI_INVALID_CONTEXT);
3770
- PI_ASSERT (Context->Devices .size () == 1 , PI_INVALID_CONTEXT);
3771
3765
3772
3766
// Get base of the allocation
3773
3767
void *Base;
3774
3768
size_t Size;
3775
3769
void *Ptr = pi_cast<void *>(NativeHandle);
3776
3770
ZE_CALL (zeMemGetAddressRange, (Context->ZeContext , Ptr, &Base, &Size));
3777
-
3778
3771
PI_ASSERT (Ptr == Base, PI_INVALID_VALUE);
3779
3772
3780
- // Check type of the allocation
3781
3773
ZeStruct<ze_memory_allocation_properties_t > ZeMemProps;
3782
- ze_device_handle_t ZeDevice;
3774
+ ze_device_handle_t ZeDevice = nullptr ;
3783
3775
ZE_CALL (zeMemGetAllocProperties,
3784
3776
(Context->ZeContext , Ptr, &ZeMemProps, &ZeDevice));
3785
- bool OnHost = false ;
3777
+
3778
+ // Check type of the allocation
3779
+ bool HostAllocation = false ;
3786
3780
switch (ZeMemProps.type ) {
3787
3781
case ZE_MEMORY_TYPE_HOST:
3788
- OnHost = true ;
3782
+ HostAllocation = true ;
3789
3783
break ;
3790
3784
case ZE_MEMORY_TYPE_SHARED:
3791
3785
case ZE_MEMORY_TYPE_DEVICE:
@@ -3797,12 +3791,21 @@ pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,
3797
3791
die (" Unexpected memory type" );
3798
3792
}
3799
3793
3794
+ pi_device Device = nullptr ;
3795
+ if (ZeDevice) {
3796
+ Device = Context->getPlatform ()->getDeviceFromNativeHandle (ZeDevice);
3797
+ // Check that the device is present in this context.
3798
+ if (std::find (Context->Devices .begin (), Context->Devices .end (), Device) ==
3799
+ Context->Devices .end ()) {
3800
+ return PI_INVALID_CONTEXT;
3801
+ }
3802
+ }
3803
+
3800
3804
try {
3801
- *Mem = new _pi_buffer (Context, pi_cast<char *>(NativeHandle),
3802
- nullptr /* HostPtr */ , ownNativeHandle, nullptr , 0 , 0 ,
3803
- OnHost);
3805
+ *Mem = new _pi_buffer (Context, Size, Device, HostAllocation,
3806
+ pi_cast<char *>(NativeHandle), ownNativeHandle);
3804
3807
3805
- pi_platform Plt = Context->Devices [ 0 ]-> Platform ;
3808
+ pi_platform Plt = Context->getPlatform () ;
3806
3809
std::unique_lock<std::mutex> ContextsLock (Plt->ContextsMutex ,
3807
3810
std::defer_lock);
3808
3811
if (IndirectAccessTrackingEnabled) {
@@ -3821,6 +3824,26 @@ pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,
3821
3824
} catch (...) {
3822
3825
return PI_ERROR_UNKNOWN;
3823
3826
}
3827
+
3828
+ // Initialize the buffer as necessary
3829
+ auto Buffer = pi_cast<pi_buffer>(*Mem);
3830
+ if (Device) {
3831
+ // If this allocation is on a device, then we re-use it for the buffer.
3832
+ // Nothing to do.
3833
+ } else if (HostAllocation && Buffer->OnHost ) {
3834
+ // If this is host allocation and buffer always stays on host there is
3835
+ // nothing more to do.
3836
+ } else {
3837
+ // In all other cases (shared allocation, or host allocation that cannot
3838
+ // represent the buffer in this context) copy the data to a newly
3839
+ // created device allocation.
3840
+ char *ZeHandleDst;
3841
+ PI_CALL (Buffer->getZeHandle (ZeHandleDst, _pi_mem::write_only));
3842
+ ZE_CALL (zeCommandListAppendMemoryCopy,
3843
+ (Context->ZeCommandListInit , ZeHandleDst, Ptr, Size, nullptr , 0 ,
3844
+ nullptr ));
3845
+ }
3846
+
3824
3847
return PI_SUCCESS;
3825
3848
}
3826
3849
@@ -4730,7 +4753,7 @@ pi_result piKernelRelease(pi_kernel Kernel) {
4730
4753
// then release referenced memory allocations. As a result, memory can be
4731
4754
// deallocated and context can be removed from container in the platform.
4732
4755
// That's why we need to lock a mutex here.
4733
- pi_platform Plt = Kernel->Program ->Context ->Devices [ 0 ]-> Platform ;
4756
+ pi_platform Plt = Kernel->Program ->Context ->getPlatform () ;
4734
4757
std::lock_guard<std::mutex> ContextsLock (Plt->ContextsMutex );
4735
4758
4736
4759
if (--Kernel->SubmissionsCount == 0 ) {
@@ -5518,8 +5541,7 @@ pi_result piSamplerCreate(pi_context Context,
5518
5541
// border", i.e. logic is flipped. Starting from API version 1.3 this
5519
5542
// problem is going to be fixed. That's why check for API version to set
5520
5543
// an address mode.
5521
- ze_api_version_t ZeApiVersion =
5522
- Context->Devices [0 ]->Platform ->ZeApiVersion ;
5544
+ ze_api_version_t ZeApiVersion = Context->getPlatform ()->ZeApiVersion ;
5523
5545
// TODO: add support for PI_SAMPLER_ADDRESSING_MODE_CLAMP_TO_EDGE
5524
5546
switch (CurValueAddressingMode) {
5525
5547
case PI_SAMPLER_ADDRESSING_MODE_NONE:
@@ -6746,7 +6768,7 @@ pi_result piMemBufferPartition(pi_mem Buffer, pi_mem_flags Flags,
6746
6768
void *BufferCreateInfo, pi_mem *RetMem) {
6747
6769
6748
6770
PI_ASSERT (Buffer && !Buffer->isImage () &&
6749
- !(static_cast <_pi_buffer * >(Buffer))->isSubBuffer (),
6771
+ !(static_cast <pi_buffer >(Buffer))->isSubBuffer (),
6750
6772
PI_INVALID_MEM_OBJECT);
6751
6773
6752
6774
PI_ASSERT (BufferCreateType == PI_BUFFER_CREATE_TYPE_REGION &&
@@ -7211,7 +7233,7 @@ pi_result piextUSMHostAlloc(void **ResultPtr, pi_context Context,
7211
7233
if (Alignment > 65536 )
7212
7234
return PI_INVALID_VALUE;
7213
7235
7214
- pi_platform Plt = Context->Devices [ 0 ]-> Platform ;
7236
+ pi_platform Plt = Context->getPlatform () ;
7215
7237
std::unique_lock<std::mutex> ContextsLock (Plt->ContextsMutex ,
7216
7238
std::defer_lock);
7217
7239
if (IndirectAccessTrackingEnabled) {
@@ -7332,7 +7354,7 @@ static pi_result USMFreeHelper(pi_context Context, void *Ptr,
7332
7354
PI_ASSERT (Device->ZeDevice == ZeDeviceHandle, PI_INVALID_DEVICE);
7333
7355
} else {
7334
7356
// All devices in the context are of the same platform.
7335
- auto Platform = Context->Devices [ 0 ]-> Platform ;
7357
+ auto Platform = Context->getPlatform () ;
7336
7358
Device = Platform->getDeviceFromNativeHandle (ZeDeviceHandle);
7337
7359
PI_ASSERT (Device, PI_INVALID_DEVICE);
7338
7360
}
@@ -7382,12 +7404,12 @@ static pi_result USMFreeHelper(pi_context Context, void *Ptr,
7382
7404
}
7383
7405
7384
7406
pi_result piextUSMFree (pi_context Context, void *Ptr) {
7385
- pi_platform Plt = Context->Devices [ 0 ]-> Platform ;
7407
+ pi_platform Plt = Context->getPlatform () ;
7386
7408
std::unique_lock<std::mutex> ContextsLock (Plt->ContextsMutex ,
7387
7409
std::defer_lock);
7388
7410
if (IndirectAccessTrackingEnabled)
7389
7411
ContextsLock.lock ();
7390
- return USMFreeHelper (Context, Ptr);
7412
+ return USMFreeHelper (Context, Ptr, true /* OwnZeMemHandle */ );
7391
7413
}
7392
7414
7393
7415
pi_result piextKernelSetArgPointer (pi_kernel Kernel, pi_uint32 ArgIndex,
@@ -7651,8 +7673,7 @@ pi_result piextUSMGetMemAllocInfo(pi_context Context, const void *Ptr,
7651
7673
}
7652
7674
case PI_MEM_ALLOC_DEVICE:
7653
7675
if (ZeDeviceHandle) {
7654
- // All devices in the context are of the same platform.
7655
- auto Platform = Context->Devices [0 ]->Platform ;
7676
+ auto Platform = Context->getPlatform ();
7656
7677
auto Device = Platform->getDeviceFromNativeHandle (ZeDeviceHandle);
7657
7678
return Device ? ReturnValue (Device) : PI_INVALID_VALUE;
7658
7679
} else {
@@ -7961,23 +7982,31 @@ pi_result _pi_buffer::free() {
7961
7982
return PI_SUCCESS;
7962
7983
}
7963
7984
if (HostPtrImported) {
7964
- ze_driver_handle_t DriverHandle = Context->Devices [ 0 ]-> Platform ->ZeDriver ;
7985
+ ze_driver_handle_t DriverHandle = Context->getPlatform () ->ZeDriver ;
7965
7986
ZeUSMImport.doZeUSMRelease (DriverHandle, MapHostPtr);
7966
7987
if (OnHost) {
7967
7988
// We were using imported host pointer, so nothing to free.
7968
7989
return PI_SUCCESS;
7969
7990
}
7970
7991
}
7971
7992
7993
+ pi_platform Plt = Context->getPlatform ();
7994
+ std::unique_lock<std::mutex> ContextsLock (Plt->ContextsMutex ,
7995
+ std::defer_lock);
7996
+ if (IndirectAccessTrackingEnabled)
7997
+ ContextsLock.lock ();
7998
+
7972
7999
for (auto &Alloc : Allocations) {
7973
8000
auto Device = Alloc.first ;
7974
8001
if (Context->SingleRootDevice && Context->SingleRootDevice != Device) {
7975
8002
// These were re-using root-device allocations
7976
8003
}
7977
8004
// It is possible that the real allocation wasn't made if the buffer
7978
8005
// wasn't really used on this device.
7979
- if (Alloc.second .ZeHandle )
7980
- PI_CALL (piextUSMFree (Context, Alloc.second .ZeHandle ));
8006
+ if (Alloc.second .ZeHandle ) {
8007
+ PI_CALL (USMFreeHelper (Context, Alloc.second .ZeHandle ,
8008
+ Alloc.second .ZeHandle != NotOwnZeMemHandle));
8009
+ }
7981
8010
}
7982
8011
return PI_SUCCESS;
7983
8012
}
0 commit comments