|
19 | 19 | #include "helpers/memory_helpers.hpp"
|
20 | 20 | #include "image_common.hpp"
|
21 | 21 | #include "logger/ur_logger.hpp"
|
| 22 | +#include "platform.hpp" |
22 | 23 | #include "sampler.hpp"
|
23 | 24 | #include "ur_interface_loader.hpp"
|
24 | 25 |
|
25 |
| -typedef ze_result_t(ZE_APICALL *zeMemGetPitchFor2dImage_pfn)( |
26 |
| - ze_context_handle_t hContext, ze_device_handle_t hDevice, size_t imageWidth, |
27 |
| - size_t imageHeight, unsigned int elementSizeInBytes, size_t *rowPitch); |
28 |
| - |
29 |
| -typedef ze_result_t(ZE_APICALL *zeImageGetDeviceOffsetExp_pfn)( |
30 |
| - ze_image_handle_t hImage, uint64_t *pDeviceOffset); |
31 |
| - |
32 |
| -zeMemGetPitchFor2dImage_pfn zeMemGetPitchFor2dImageFunctionPtr = nullptr; |
33 |
| -zeImageGetDeviceOffsetExp_pfn zeImageGetDeviceOffsetExpFunctionPtr = nullptr; |
34 |
| - |
35 | 26 | namespace {
|
36 | 27 |
|
37 | 28 | /// Construct UR image format from ZE image desc.
|
@@ -370,26 +361,16 @@ ur_result_t bindlessImagesCreateImpl(ur_context_handle_t hContext,
|
370 | 361 | return UR_RESULT_ERROR_INVALID_VALUE;
|
371 | 362 | }
|
372 | 363 |
|
373 |
| - static std::once_flag InitFlag; |
374 |
| - std::call_once(InitFlag, [&]() { |
375 |
| - ze_driver_handle_t DriverHandle = hContext->getPlatform()->ZeDriver; |
376 |
| - auto Result = zeDriverGetExtensionFunctionAddress( |
377 |
| - DriverHandle, "zeImageGetDeviceOffsetExp", |
378 |
| - (void **)&zeImageGetDeviceOffsetExpFunctionPtr); |
379 |
| - if (Result != ZE_RESULT_SUCCESS) |
380 |
| - UR_LOG(ERR, |
381 |
| - "zeDriverGetExtensionFunctionAddress " |
382 |
| - "zeImageGetDeviceOffsetExpv failed, err = {}", |
383 |
| - Result); |
384 |
| - }); |
385 |
| - if (!zeImageGetDeviceOffsetExpFunctionPtr) |
| 364 | + if (!hDevice->Platform->ZeImageGetDeviceOffsetExt.Supported) |
386 | 365 | return UR_RESULT_ERROR_INVALID_OPERATION;
|
| 366 | + |
387 | 367 | uint64_t DeviceOffset{};
|
388 | 368 | ze_image_handle_t ZeImageTranslated;
|
389 | 369 | ZE2UR_CALL(zelLoaderTranslateHandle,
|
390 | 370 | (ZEL_HANDLE_IMAGE, ZeImage.get(), (void **)&ZeImageTranslated));
|
391 |
| - ZE2UR_CALL(zeImageGetDeviceOffsetExpFunctionPtr, |
392 |
| - (ZeImageTranslated, &DeviceOffset)); |
| 371 | + ZE2UR_CALL( |
| 372 | + hDevice->Platform->ZeImageGetDeviceOffsetExt.zeImageGetDeviceOffsetExp, |
| 373 | + (ZeImageTranslated, &DeviceOffset)); |
393 | 374 | *phImage = DeviceOffset;
|
394 | 375 |
|
395 | 376 | std::shared_lock<ur_shared_mutex> Lock(hDevice->Mutex);
|
@@ -1078,29 +1059,19 @@ ur_result_t urUSMPitchedAllocExp(ur_context_handle_t hContext,
|
1078 | 1059 | UR_ASSERT(widthInBytes != 0, UR_RESULT_ERROR_INVALID_USM_SIZE);
|
1079 | 1060 | UR_ASSERT(ppMem && pResultPitch, UR_RESULT_ERROR_INVALID_NULL_POINTER);
|
1080 | 1061 |
|
1081 |
| - static std::once_flag InitFlag; |
1082 |
| - std::call_once(InitFlag, [&]() { |
1083 |
| - ze_driver_handle_t DriverHandle = hContext->getPlatform()->ZeDriver; |
1084 |
| - auto Result = zeDriverGetExtensionFunctionAddress( |
1085 |
| - DriverHandle, "zeMemGetPitchFor2dImage", |
1086 |
| - (void **)&zeMemGetPitchFor2dImageFunctionPtr); |
1087 |
| - if (Result != ZE_RESULT_SUCCESS) |
1088 |
| - UR_LOG(ERR, |
1089 |
| - "zeDriverGetExtensionFunctionAddress zeMemGetPitchFor2dImage " |
1090 |
| - "failed, err = {}", |
1091 |
| - Result); |
1092 |
| - }); |
1093 |
| - if (!zeMemGetPitchFor2dImageFunctionPtr) |
| 1062 | + if (!hDevice->Platform->ZeMemGetPitchFor2dImageExt.Supported) { |
1094 | 1063 | return UR_RESULT_ERROR_INVALID_OPERATION;
|
| 1064 | + } |
1095 | 1065 |
|
1096 | 1066 | size_t Width = widthInBytes / elementSizeBytes;
|
1097 | 1067 | size_t RowPitch;
|
1098 | 1068 | ze_device_handle_t ZeDeviceTranslated;
|
1099 | 1069 | ZE2UR_CALL(zelLoaderTranslateHandle, (ZEL_HANDLE_DEVICE, hDevice->ZeDevice,
|
1100 | 1070 | (void **)&ZeDeviceTranslated));
|
1101 |
| - ZE2UR_CALL(zeMemGetPitchFor2dImageFunctionPtr, |
1102 |
| - (hContext->getZeHandle(), ZeDeviceTranslated, Width, height, |
1103 |
| - elementSizeBytes, &RowPitch)); |
| 1071 | + ZE2UR_CALL( |
| 1072 | + hDevice->Platform->ZeMemGetPitchFor2dImageExt.zeMemGetPitchFor2dImage, |
| 1073 | + (hContext->getZeHandle(), ZeDeviceTranslated, Width, height, |
| 1074 | + elementSizeBytes, &RowPitch)); |
1104 | 1075 | *pResultPitch = RowPitch;
|
1105 | 1076 |
|
1106 | 1077 | size_t Size = height * RowPitch;
|
|
0 commit comments