@@ -660,36 +660,65 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueUSMMemcpy2D(
660
660
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
661
661
}
662
662
663
+ static void *getGlobalPointerFromModule (ze_module_handle_t hModule,
664
+ size_t offset, size_t count,
665
+ const char *name) {
666
+ // Find global variable pointer
667
+ size_t globalVarSize = 0 ;
668
+ void *globalVarPtr = nullptr ;
669
+ ZE2UR_CALL_THROWS (zeModuleGetGlobalPointer,
670
+ (hModule, name, &globalVarSize, &globalVarPtr));
671
+ if (globalVarSize < offset + count) {
672
+ setErrorMessage (" Write device global variable is out of range." ,
673
+ UR_RESULT_ERROR_INVALID_VALUE,
674
+ static_cast <int32_t >(ZE_RESULT_ERROR_INVALID_ARGUMENT));
675
+ throw UR_RESULT_ERROR_ADAPTER_SPECIFIC;
676
+ }
677
+ return globalVarPtr;
678
+ }
679
+
663
680
ur_result_t ur_queue_immediate_in_order_t::enqueueDeviceGlobalVariableWrite (
664
681
ur_program_handle_t hProgram, const char *name, bool blockingWrite,
665
682
size_t count, size_t offset, const void *pSrc, uint32_t numEventsInWaitList,
666
683
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
667
- std::ignore = hProgram;
668
- std::ignore = name;
669
- std::ignore = blockingWrite;
670
- std::ignore = count;
671
- std::ignore = offset;
672
- std::ignore = pSrc;
673
- std::ignore = numEventsInWaitList;
674
- std::ignore = phEventWaitList;
675
- std::ignore = phEvent;
676
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
684
+ // TODO: implement program->getZeModuleMap() to be sure that
685
+ // it's thread-safe
686
+ ze_module_handle_t zeModule{};
687
+ auto It = hProgram->ZeModuleMap .find (this ->hDevice ->ZeDevice );
688
+ if (It != hProgram->ZeModuleMap .end ()) {
689
+ zeModule = It->second ;
690
+ } else {
691
+ zeModule = hProgram->ZeModule ;
692
+ }
693
+
694
+ // Find global variable pointer
695
+ auto globalVarPtr = getGlobalPointerFromModule (zeModule, offset, count, name);
696
+
697
+ return enqueueUSMMemcpy (blockingWrite, ur_cast<char *>(globalVarPtr) + offset,
698
+ pSrc, count, numEventsInWaitList, phEventWaitList,
699
+ phEvent);
677
700
}
678
701
679
702
ur_result_t ur_queue_immediate_in_order_t::enqueueDeviceGlobalVariableRead (
680
703
ur_program_handle_t hProgram, const char *name, bool blockingRead,
681
704
size_t count, size_t offset, void *pDst, uint32_t numEventsInWaitList,
682
705
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
683
- std::ignore = hProgram;
684
- std::ignore = name;
685
- std::ignore = blockingRead;
686
- std::ignore = count;
687
- std::ignore = offset;
688
- std::ignore = pDst;
689
- std::ignore = numEventsInWaitList;
690
- std::ignore = phEventWaitList;
691
- std::ignore = phEvent;
692
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
706
+ // TODO: implement program->getZeModule() to be sure that
707
+ // it's thread-safe
708
+ ze_module_handle_t zeModule{};
709
+ auto It = hProgram->ZeModuleMap .find (this ->hDevice ->ZeDevice );
710
+ if (It != hProgram->ZeModuleMap .end ()) {
711
+ zeModule = It->second ;
712
+ } else {
713
+ zeModule = hProgram->ZeModule ;
714
+ }
715
+
716
+ // Find global variable pointer
717
+ auto globalVarPtr = getGlobalPointerFromModule (zeModule, offset, count, name);
718
+
719
+ return enqueueUSMMemcpy (blockingRead, pDst,
720
+ ur_cast<char *>(globalVarPtr) + offset, count,
721
+ numEventsInWaitList, phEventWaitList, phEvent);
693
722
}
694
723
695
724
ur_result_t ur_queue_immediate_in_order_t::enqueueReadHostPipe (
0 commit comments