20
20
#include " ../helpers/kernel_helpers.hpp"
21
21
#include " ../platform.hpp"
22
22
#include " ../program.hpp"
23
+ #include " ../sampler.hpp"
23
24
#include " ../ur_interface_loader.hpp"
24
25
25
26
ur_single_device_kernel_t ::ur_single_device_kernel_t (ur_device_handle_t hDevice,
@@ -378,17 +379,15 @@ ur_result_t urKernelRelease(
378
379
}
379
380
380
381
ur_result_t urKernelSetArgValue (
381
- // / [in] handle of the kernel object
382
- ur_kernel_handle_t hKernel,
383
- // / [in] argument index in range [0, num args - 1]
384
- uint32_t argIndex,
385
- // / [in] size of argument type
386
- size_t argSize,
387
- // / [in][optional] argument properties
388
- const ur_kernel_arg_value_properties_t *pProperties,
389
- // / [in] argument value represented as matching arg type.
390
- const void *pArgValue) try {
391
- TRACK_SCOPE_LATENCY (" ur_kernel_handle_t_::setArgValue" );
382
+ ur_kernel_handle_t hKernel, // /< [in] handle of the kernel object
383
+ uint32_t argIndex, // /< [in] argument index in range [0, num args - 1]
384
+ size_t argSize, // /< [in] size of argument type
385
+ const ur_kernel_arg_value_properties_t
386
+ *pProperties, // /< [in][optional] argument properties
387
+ const void
388
+ *pArgValue // /< [in] argument value represented as matching arg type.
389
+ ) try {
390
+ TRACK_SCOPE_LATENCY (" urKernelSetArgValue" );
392
391
393
392
std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
394
393
return hKernel->setArgValue (argIndex, argSize, pProperties, pArgValue);
@@ -397,15 +396,14 @@ ur_result_t urKernelSetArgValue(
397
396
}
398
397
399
398
ur_result_t urKernelSetArgPointer (
400
- // / [in] handle of the kernel object
401
- ur_kernel_handle_t hKernel,
402
- // / [in] argument index in range [0, num args - 1]
403
- uint32_t argIndex,
404
- // / [in][optional] argument properties
405
- const ur_kernel_arg_pointer_properties_t *pProperties,
406
- // / [in] argument value represented as matching arg type.
407
- const void *pArgValue) try {
408
- TRACK_SCOPE_LATENCY (" ur_kernel_handle_t_::setArgPointer" );
399
+ ur_kernel_handle_t hKernel, // /< [in] handle of the kernel object
400
+ uint32_t argIndex, // /< [in] argument index in range [0, num args - 1]
401
+ const ur_kernel_arg_pointer_properties_t
402
+ *pProperties, // /< [in][optional] argument properties
403
+ const void
404
+ *pArgValue // /< [in] argument value represented as matching arg type.
405
+ ) try {
406
+ TRACK_SCOPE_LATENCY (" urKernelSetArgPointer" );
409
407
410
408
std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
411
409
return hKernel->setArgPointer (argIndex, pProperties, pArgValue);
@@ -434,7 +432,7 @@ ur_result_t
434
432
urKernelSetArgMemObj (ur_kernel_handle_t hKernel, uint32_t argIndex,
435
433
const ur_kernel_arg_mem_obj_properties_t *pProperties,
436
434
ur_mem_handle_t hArgValue) try {
437
- TRACK_SCOPE_LATENCY (" ur_kernel_handle_t_::setArgMemObj " );
435
+ TRACK_SCOPE_LATENCY (" urKernelSetArgMemObj " );
438
436
439
437
std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
440
438
@@ -450,7 +448,7 @@ ur_result_t
450
448
urKernelSetArgLocal (ur_kernel_handle_t hKernel, uint32_t argIndex,
451
449
size_t argSize,
452
450
const ur_kernel_arg_local_properties_t *pProperties) try {
453
- TRACK_SCOPE_LATENCY (" ur_kernel_handle_t_::setArgLocal " );
451
+ TRACK_SCOPE_LATENCY (" urKernelSetArgLocal " );
454
452
455
453
std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
456
454
@@ -697,4 +695,17 @@ ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
697
695
*pGroupCountRet = totalGroupCount;
698
696
return UR_RESULT_SUCCESS;
699
697
}
698
+
699
+ ur_result_t
700
+ urKernelSetArgSampler (ur_kernel_handle_t hKernel, uint32_t argIndex,
701
+ const ur_kernel_arg_sampler_properties_t *pProperties,
702
+ ur_sampler_handle_t hArgValue) try {
703
+ TRACK_SCOPE_LATENCY (" urKernelSetArgSampler" );
704
+ std::scoped_lock<ur_shared_mutex> guard (hKernel->Mutex );
705
+ std::ignore = pProperties;
706
+ return hKernel->setArgValue (argIndex, sizeof (void *), nullptr ,
707
+ &hArgValue->ZeSampler );
708
+ } catch (...) {
709
+ return exceptionToResult (std::current_exception ());
710
+ }
700
711
} // namespace ur::level_zero
0 commit comments