10
10
11
11
#include < cassert>
12
12
#include < cstring>
13
+ #include < map>
13
14
#include < string>
14
15
#include < vector>
15
16
@@ -32,6 +33,14 @@ template <class To, class From> To cast(From value) {
32
33
// USM helper function to get an extension function pointer
33
34
template <typename T>
34
35
pi_result getExtFuncFromContext (pi_context context, const char *func, T *fptr) {
36
+ thread_local static std::map<pi_context, T> FuncPtrs;
37
+
38
+ // if cached, return cached FuncPtr
39
+ if (auto F = FuncPtrs[context]) {
40
+ *fptr = F;
41
+ return PI_SUCCESS;
42
+ }
43
+
35
44
size_t deviceCount;
36
45
cl_int ret_err = clGetContextInfo (
37
46
cast<cl_context>(context), CL_CONTEXT_DEVICES, 0 , nullptr , &deviceCount);
@@ -64,6 +73,8 @@ pi_result getExtFuncFromContext(pi_context context, const char *func, T *fptr) {
64
73
}
65
74
66
75
*fptr = FuncPtr;
76
+ FuncPtrs[context] = FuncPtr;
77
+
67
78
return cast<pi_result>(ret_err);
68
79
}
69
80
@@ -511,16 +522,14 @@ pi_result OCL(piextUSMHostAlloc)(void **result_ptr, pi_context context,
511
522
pi_result RetVal = PI_INVALID_OPERATION;
512
523
513
524
// First we need to look up the function pointer
514
- // It would be good if we could store it for future reuse
515
- // Are statics ok?
516
- clHostMemAllocINTEL_fn FuncPtr;
525
+ clHostMemAllocINTEL_fn FuncPtr = nullptr ;
517
526
RetVal = getExtFuncFromContext<clHostMemAllocINTEL_fn>(
518
527
context, " clHostMemAllocINTEL" , &FuncPtr);
519
528
520
- if (RetVal == PI_SUCCESS ) {
529
+ if (FuncPtr ) {
521
530
Ptr = FuncPtr (cast<cl_context>(context),
522
- cast<cl_mem_properties_intel *>(properties), size, alignment,
523
- cast<cl_int *>(&RetVal));
531
+ cast<cl_mem_properties_intel *>(properties), size, alignment,
532
+ cast<cl_int *>(&RetVal));
524
533
}
525
534
526
535
*result_ptr = Ptr;
@@ -545,13 +554,11 @@ pi_result OCL(piextUSMDeviceAlloc)(void **result_ptr, pi_context context,
545
554
pi_result RetVal = PI_INVALID_OPERATION;
546
555
547
556
// First we need to look up the function pointer
548
- // It would be good if we could store it for future reuse
549
- // Are statics ok?
550
- clDeviceMemAllocINTEL_fn FuncPtr;
557
+ clDeviceMemAllocINTEL_fn FuncPtr = nullptr ;
551
558
RetVal = getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
552
559
context, " clDeviceMemAllocINTEL" , &FuncPtr);
553
560
554
- if (RetVal == PI_SUCCESS ) {
561
+ if (FuncPtr ) {
555
562
Ptr = FuncPtr (cast<cl_context>(context), cast<cl_device_id>(device),
556
563
cast<cl_mem_properties_intel *>(properties), size, alignment,
557
564
cast<cl_int *>(&RetVal));
@@ -579,13 +586,11 @@ pi_result OCL(piextUSMSharedAlloc)(void **result_ptr, pi_context context,
579
586
pi_result RetVal = PI_INVALID_OPERATION;
580
587
581
588
// First we need to look up the function pointer
582
- // It would be good if we could store it for future reuse
583
- // Are statics ok?
584
- clSharedMemAllocINTEL_fn FuncPtr;
589
+ clSharedMemAllocINTEL_fn FuncPtr = nullptr ;
585
590
RetVal = getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
586
591
context, " clSharedMemAllocINTEL" , &FuncPtr);
587
592
588
- if (RetVal == PI_SUCCESS ) {
593
+ if (FuncPtr ) {
589
594
Ptr = FuncPtr (cast<cl_context>(context), cast<cl_device_id>(device),
590
595
cast<cl_mem_properties_intel *>(properties), size, alignment,
591
596
cast<cl_int *>(&RetVal));
@@ -602,14 +607,16 @@ pi_result OCL(piextUSMSharedAlloc)(void **result_ptr, pi_context context,
602
607
/// @param ptr is the memory to be freed
603
608
pi_result OCL (piextUSMFree)(pi_context context, void *ptr) {
604
609
605
- clMemFreeINTEL_fn FuncPtr;
606
- pi_result RetVal = getExtFuncFromContext<clMemFreeINTEL_fn>(
607
- context, " clMemFreeINTEL" , &FuncPtr);
610
+ clMemFreeINTEL_fn FuncPtr = nullptr ;
611
+ pi_result RetVal = PI_INVALID_OPERATION;
612
+ RetVal = getExtFuncFromContext<clMemFreeINTEL_fn>(context, " clMemFreeINTEL" ,
613
+ &FuncPtr);
608
614
609
- if (RetVal != PI_SUCCESS)
610
- return RetVal;
615
+ if (FuncPtr) {
616
+ RetVal = cast<pi_result>(FuncPtr (cast<cl_context>(context), ptr));
617
+ }
611
618
612
- return cast<pi_result>( FuncPtr (cast<cl_context>(context), ptr)) ;
619
+ return RetVal ;
613
620
}
614
621
615
622
/// Sets up pointer arguments for CL kernels. An extra indirection
@@ -634,18 +641,19 @@ pi_result OCL(piextUSMKernelSetArgMemPointer)(pi_kernel kernel,
634
641
return cast<pi_result>(CLErr);
635
642
}
636
643
637
- clSetKernelArgMemPointerINTEL_fn FuncPtr;
644
+ clSetKernelArgMemPointerINTEL_fn FuncPtr = nullptr ;
638
645
pi_result RetVal = getExtFuncFromContext<clSetKernelArgMemPointerINTEL_fn>(
639
646
cast<pi_context>(CLContext), " clSetKernelArgMemPointerINTEL" , &FuncPtr);
640
647
641
- if (RetVal != PI_SUCCESS)
642
- return RetVal;
648
+ if (FuncPtr) {
649
+ // OpenCL passes pointers by value not by reference
650
+ // This means we need to deref the arg to get the pointer value
651
+ auto PtrToPtr = reinterpret_cast <const intptr_t *>(arg_value);
652
+ auto DerefPtr = reinterpret_cast <void *>(*PtrToPtr);
653
+ RetVal = cast<pi_result>(FuncPtr (cast<cl_kernel>(kernel), arg_index, DerefPtr));
654
+ }
643
655
644
- // OpenCL passes pointers by value not by reference
645
- // This means we need to deref the arg to get the pointer value
646
- auto PtrToPtr = reinterpret_cast <const intptr_t *>(arg_value);
647
- auto DerefPtr = reinterpret_cast <void *>(*PtrToPtr);
648
- return cast<pi_result>(FuncPtr (cast<cl_kernel>(kernel), arg_index, DerefPtr));
656
+ return RetVal;
649
657
}
650
658
651
659
/// Enables indirect access of pointers in kernels.
@@ -657,35 +665,35 @@ pi_result OCL(piextUSMKernelSetIndirectAccess)(pi_kernel kernel) {
657
665
// We test that each alloc type is supported before we actually try to
658
666
// set KernelExecInfo.
659
667
cl_bool TrueVal = CL_TRUE;
660
- pi_result RetVal;
661
- clHostMemAllocINTEL_fn HFunc;
662
- clSharedMemAllocINTEL_fn SFunc;
663
- clDeviceMemAllocINTEL_fn DFunc;
668
+ clHostMemAllocINTEL_fn HFunc = nullptr ;
669
+ clSharedMemAllocINTEL_fn SFunc = nullptr ;
670
+ clDeviceMemAllocINTEL_fn DFunc = nullptr ;
664
671
cl_context CLContext;
665
672
cl_int CLErr = clGetKernelInfo (cast<cl_kernel>(kernel), CL_KERNEL_CONTEXT,
666
673
sizeof (cl_context), &CLContext, nullptr );
667
674
if (CLErr != CL_SUCCESS) {
668
675
return cast<pi_result>(CLErr);
669
676
}
670
677
671
- // This would be really good to cache
672
- RetVal = getExtFuncFromContext<clHostMemAllocINTEL_fn>(
673
- cast<pi_context>(CLContext), " clHostMemAllocINTEL" , &HFunc);
674
- if (RetVal == PI_SUCCESS) {
678
+ getExtFuncFromContext<clHostMemAllocINTEL_fn>(cast<pi_context>(CLContext),
679
+ " clHostMemAllocINTEL" , &HFunc);
680
+ if (HFunc) {
675
681
clSetKernelExecInfo (cast<cl_kernel>(kernel),
676
682
CL_KERNEL_EXEC_INFO_INDIRECT_HOST_ACCESS_INTEL,
677
683
sizeof (cl_bool), &TrueVal);
678
684
}
679
- RetVal = getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
685
+
686
+ getExtFuncFromContext<clDeviceMemAllocINTEL_fn>(
680
687
cast<pi_context>(CLContext), " clDeviceMemAllocINTEL" , &DFunc);
681
- if (RetVal == PI_SUCCESS ) {
688
+ if (DFunc ) {
682
689
clSetKernelExecInfo (cast<cl_kernel>(kernel),
683
690
CL_KERNEL_EXEC_INFO_INDIRECT_DEVICE_ACCESS_INTEL,
684
691
sizeof (cl_bool), &TrueVal);
685
692
}
686
- RetVal = getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
693
+
694
+ getExtFuncFromContext<clSharedMemAllocINTEL_fn>(
687
695
cast<pi_context>(CLContext), " clSharedMemAllocINTEL" , &SFunc);
688
- if (RetVal == PI_SUCCESS ) {
696
+ if (SFunc ) {
689
697
clSetKernelExecInfo (cast<cl_kernel>(kernel),
690
698
CL_KERNEL_EXEC_INFO_INDIRECT_SHARED_ACCESS_INTEL,
691
699
sizeof (cl_bool), &TrueVal);
@@ -718,16 +726,18 @@ pi_result OCL(piextUSMEnqueueMemset)(pi_queue queue, void *ptr, pi_int32 value,
718
726
return cast<pi_result>(CLErr);
719
727
}
720
728
721
- clEnqueueMemsetINTEL_fn FuncPtr;
729
+ clEnqueueMemsetINTEL_fn FuncPtr = nullptr ;
722
730
pi_result RetVal = getExtFuncFromContext<clEnqueueMemsetINTEL_fn>(
723
731
cast<pi_context>(CLContext), " clEnqueueMemsetINTEL" , &FuncPtr);
724
732
725
- if (RetVal != PI_SUCCESS)
726
- return RetVal;
733
+ if (FuncPtr) {
734
+ RetVal = cast<pi_result>(FuncPtr (cast<cl_command_queue>(queue), ptr, value,
735
+ count, num_events_in_waitlist,
736
+ cast<const cl_event *>(events_waitlist),
737
+ cast<cl_event *>(event)));
738
+ }
727
739
728
- return cast<pi_result>(FuncPtr (
729
- cast<cl_command_queue>(queue), ptr, value, count, num_events_in_waitlist,
730
- cast<const cl_event *>(events_waitlist), cast<cl_event *>(event)));
740
+ return RetVal;
731
741
}
732
742
733
743
/// USM Memcpy API
@@ -756,17 +766,18 @@ pi_result OCL(piextUSMEnqueueMemcpy)(pi_queue queue, pi_bool blocking,
756
766
return cast<pi_result>(CLErr);
757
767
}
758
768
759
- clEnqueueMemcpyINTEL_fn FuncPtr;
769
+ clEnqueueMemcpyINTEL_fn FuncPtr = nullptr ;
760
770
pi_result RetVal = getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
761
771
cast<pi_context>(CLContext), " clEnqueueMemcpyINTEL" , &FuncPtr);
762
772
763
- if (RetVal != PI_SUCCESS)
764
- return RetVal;
773
+ if (FuncPtr) {
774
+ RetVal = cast<pi_result>(
775
+ FuncPtr (cast<cl_command_queue>(queue), blocking, dst_ptr, src_ptr, size,
776
+ num_events_in_waitlist, cast<const cl_event *>(events_waitlist),
777
+ cast<cl_event *>(event)));
778
+ }
765
779
766
- return cast<pi_result>(FuncPtr (cast<cl_command_queue>(queue), blocking,
767
- dst_ptr, src_ptr, size, num_events_in_waitlist,
768
- cast<const cl_event *>(events_waitlist),
769
- cast<cl_event *>(event)));
780
+ return RetVal;
770
781
}
771
782
772
783
/// Hint to migrate memory to the device
@@ -879,16 +890,17 @@ pi_result OCL(piextUSMGetMemAllocInfo)(pi_context context, const void *ptr,
879
890
void *param_value,
880
891
size_t *param_value_size_ret) {
881
892
882
- clGetMemAllocInfoINTEL_fn FuncPtr;
893
+ clGetMemAllocInfoINTEL_fn FuncPtr = nullptr ;
883
894
pi_result RetVal = getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
884
895
context, " clGetMemAllocInfoINTEL" , &FuncPtr);
885
896
886
- if (RetVal != PI_SUCCESS)
887
- return RetVal;
897
+ if (FuncPtr) {
898
+ RetVal = cast<pi_result>(FuncPtr (cast<cl_context>(context), ptr, param_name,
899
+ param_value_size, param_value,
900
+ param_value_size_ret));
901
+ }
888
902
889
- return cast<pi_result>(FuncPtr (cast<cl_context>(context), ptr, param_name,
890
- param_value_size, param_value,
891
- param_value_size_ret));
903
+ return RetVal;
892
904
}
893
905
894
906
pi_result piPluginInit (pi_plugin *PluginInit) {
0 commit comments