@@ -2697,12 +2697,28 @@ inline pi_result piMemBufferCreate(pi_context Context, pi_mem_flags Flags,
2697
2697
inline pi_result piextUSMHostAlloc (void **ResultPtr, pi_context Context,
2698
2698
pi_usm_mem_properties *Properties,
2699
2699
size_t Size, pi_uint32 Alignment) {
2700
+ ur_usm_desc_t USMDesc{};
2701
+ USMDesc.align = Alignment;
2702
+
2703
+ ur_usm_alloc_location_desc_t UsmLocationDesc{};
2704
+ UsmLocationDesc.stype = UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC;
2705
+
2706
+ if (Properties) {
2707
+ uint32_t Next = 0 ;
2708
+ while (Properties[Next]) {
2709
+ if (Properties[Next] == PI_MEM_USM_ALLOC_BUFFER_LOCATION) {
2710
+ UsmLocationDesc.location = static_cast <uint32_t >(Properties[Next + 1 ]);
2711
+ USMDesc.pNext = &UsmLocationDesc;
2712
+ } else {
2713
+ return PI_ERROR_INVALID_VALUE;
2714
+ }
2715
+ Next += 2 ;
2716
+ }
2717
+ }
2700
2718
2701
- std::ignore = Properties;
2702
2719
ur_context_handle_t UrContext =
2703
2720
reinterpret_cast <ur_context_handle_t >(Context);
2704
- ur_usm_desc_t USMDesc{};
2705
- USMDesc.align = Alignment;
2721
+
2706
2722
ur_usm_pool_handle_t Pool{};
2707
2723
HANDLE_ERRORS (urUSMHostAlloc (UrContext, &USMDesc, Pool, Size, ResultPtr));
2708
2724
return PI_SUCCESS;
@@ -3131,14 +3147,29 @@ inline pi_result piextUSMDeviceAlloc(void **ResultPtr, pi_context Context,
3131
3147
pi_device Device,
3132
3148
pi_usm_mem_properties *Properties,
3133
3149
size_t Size, pi_uint32 Alignment) {
3134
-
3135
- std::ignore = Properties;
3136
3150
ur_context_handle_t UrContext =
3137
3151
reinterpret_cast <ur_context_handle_t >(Context);
3138
3152
auto UrDevice = reinterpret_cast <ur_device_handle_t >(Device);
3139
3153
3140
3154
ur_usm_desc_t USMDesc{};
3141
3155
USMDesc.align = Alignment;
3156
+
3157
+ ur_usm_alloc_location_desc_t UsmLocDesc{};
3158
+ UsmLocDesc.stype = UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC;
3159
+
3160
+ if (Properties) {
3161
+ uint32_t Next = 0 ;
3162
+ while (Properties[Next]) {
3163
+ if (Properties[Next] == PI_MEM_USM_ALLOC_BUFFER_LOCATION) {
3164
+ UsmLocDesc.location = static_cast <uint32_t >(Properties[Next + 1 ]);
3165
+ USMDesc.pNext = &UsmLocDesc;
3166
+ } else {
3167
+ return PI_ERROR_INVALID_VALUE;
3168
+ }
3169
+ Next += 2 ;
3170
+ }
3171
+ }
3172
+
3142
3173
ur_usm_pool_handle_t Pool{};
3143
3174
HANDLE_ERRORS (
3144
3175
urUSMDeviceAlloc (UrContext, UrDevice, &USMDesc, Pool, Size, ResultPtr));
@@ -3171,42 +3202,58 @@ inline pi_result piextUSMSharedAlloc(void **ResultPtr, pi_context Context,
3171
3202
pi_device Device,
3172
3203
pi_usm_mem_properties *Properties,
3173
3204
size_t Size, pi_uint32 Alignment) {
3174
-
3175
- std::ignore = Properties;
3176
- if (Properties && *Properties != 0 ) {
3177
- PI_ASSERT (*(Properties) == PI_MEM_ALLOC_FLAGS && *(Properties + 2 ) == 0 ,
3178
- PI_ERROR_INVALID_VALUE);
3179
- }
3180
-
3181
3205
ur_context_handle_t UrContext =
3182
3206
reinterpret_cast <ur_context_handle_t >(Context);
3183
3207
auto UrDevice = reinterpret_cast <ur_device_handle_t >(Device);
3184
3208
3185
3209
ur_usm_desc_t USMDesc{};
3210
+ USMDesc.align = Alignment;
3186
3211
ur_usm_device_desc_t UsmDeviceDesc{};
3187
3212
UsmDeviceDesc.stype = UR_STRUCTURE_TYPE_USM_DEVICE_DESC;
3188
3213
ur_usm_host_desc_t UsmHostDesc{};
3189
3214
UsmHostDesc.stype = UR_STRUCTURE_TYPE_USM_HOST_DESC;
3215
+ ur_usm_alloc_location_desc_t UsmLocationDesc{};
3216
+ UsmLocationDesc.stype = UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC;
3217
+
3218
+ // One properties bitfield can correspond to a host_desc and a device_desc
3219
+ // struct, since having `0` values in these is harmless we can set up this
3220
+ // pNext chain in advance.
3221
+ USMDesc.pNext = &UsmDeviceDesc;
3222
+ UsmDeviceDesc.pNext = &UsmHostDesc;
3223
+
3190
3224
if (Properties) {
3191
- if (Properties[0 ] == PI_MEM_ALLOC_FLAGS) {
3192
- if (Properties[1 ] == PI_MEM_ALLOC_WRTITE_COMBINED) {
3193
- UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_WRITE_COMBINED;
3194
- }
3195
- if (Properties[1 ] == PI_MEM_ALLOC_INITIAL_PLACEMENT_DEVICE) {
3196
- UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_INITIAL_PLACEMENT;
3225
+ uint32_t Next = 0 ;
3226
+ while (Properties[Next]) {
3227
+ switch (Properties[Next]) {
3228
+ case PI_MEM_ALLOC_FLAGS: {
3229
+ if (Properties[Next + 1 ] & PI_MEM_ALLOC_WRTITE_COMBINED) {
3230
+ UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_WRITE_COMBINED;
3231
+ }
3232
+ if (Properties[Next + 1 ] & PI_MEM_ALLOC_INITIAL_PLACEMENT_DEVICE) {
3233
+ UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_INITIAL_PLACEMENT;
3234
+ }
3235
+ if (Properties[Next + 1 ] & PI_MEM_ALLOC_INITIAL_PLACEMENT_HOST) {
3236
+ UsmHostDesc.flags |= UR_USM_HOST_MEM_FLAG_INITIAL_PLACEMENT;
3237
+ }
3238
+ if (Properties[Next + 1 ] & PI_MEM_ALLOC_DEVICE_READ_ONLY) {
3239
+ UsmDeviceDesc.flags |= UR_USM_DEVICE_MEM_FLAG_DEVICE_READ_ONLY;
3240
+ }
3241
+ break ;
3197
3242
}
3198
- if (Properties[1 ] == PI_MEM_ALLOC_INITIAL_PLACEMENT_HOST) {
3199
- UsmHostDesc.flags |= UR_USM_HOST_MEM_FLAG_INITIAL_PLACEMENT;
3243
+ case PI_MEM_USM_ALLOC_BUFFER_LOCATION: {
3244
+ UsmLocationDesc.location = static_cast <uint32_t >(Properties[Next + 1 ]);
3245
+ // We wait until we've seen a BUFFER_LOCATION property to tack this
3246
+ // onto the end of the chain, a `0` here might be valid as far as we
3247
+ // know so we must exclude it unless we've been given a value.
3248
+ UsmHostDesc.pNext = &UsmLocationDesc;
3249
+ break ;
3200
3250
}
3201
- if (Properties[ 1 ] == PI_MEM_ALLOC_DEVICE_READ_ONLY) {
3202
- UsmDeviceDesc. flags |= UR_USM_DEVICE_MEM_FLAG_DEVICE_READ_ONLY ;
3251
+ default :
3252
+ return PI_ERROR_INVALID_VALUE ;
3203
3253
}
3254
+ Next += 2 ;
3204
3255
}
3205
3256
}
3206
- UsmDeviceDesc.pNext = &UsmHostDesc;
3207
- USMDesc.pNext = &UsmDeviceDesc;
3208
-
3209
- USMDesc.align = Alignment;
3210
3257
3211
3258
ur_usm_pool_handle_t Pool{};
3212
3259
HANDLE_ERRORS (
0 commit comments