@@ -806,6 +806,8 @@ class MemorySanitizerOnSpirv {
806
806
void initializeKernelCallerMap (Function *F);
807
807
808
808
private:
809
+ friend struct MemorySanitizerVisitor ;
810
+
809
811
Module &M;
810
812
LLVMContext &C;
811
813
const DataLayout &DL;
@@ -833,6 +835,7 @@ class MemorySanitizerOnSpirv {
833
835
FunctionCallee MsanBarrierFunc;
834
836
FunctionCallee MsanUnpoisonStackFunc;
835
837
FunctionCallee MsanSetPrivateBaseFunc;
838
+ FunctionCallee MsanUnpoisonStridedCopyFunc;
836
839
};
837
840
838
841
} // end anonymous namespace
@@ -899,14 +902,14 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
899
902
M.getOrInsertFunction (" __msan_unpoison_shadow_static_local" ,
900
903
IRB.getVoidTy (), IntptrTy, IntptrTy);
901
904
902
- // __asan_poison_shadow_dynamic_local (
905
+ // __msan_poison_shadow_dynamic_local (
903
906
// uptr ptr,
904
907
// uint32_t num_args
905
908
// )
906
909
MsanPoisonShadowDynamicLocalFunc = M.getOrInsertFunction (
907
910
" __msan_poison_shadow_dynamic_local" , IRB.getVoidTy (), IntptrTy, Int32Ty);
908
911
909
- // __asan_unpoison_shadow_dynamic_local (
912
+ // __msan_unpoison_shadow_dynamic_local (
910
913
// uptr ptr,
911
914
// uint32_t num_args
912
915
// )
@@ -930,6 +933,18 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
930
933
MsanSetPrivateBaseFunc =
931
934
M.getOrInsertFunction (" __msan_set_private_base" , IRB.getVoidTy (),
932
935
PointerType::get (C, kSpirOffloadPrivateAS ));
936
+
937
+ // __msan_unpoison_strided_copy(
938
+ // uptr dest, uint32_t dest_as,
939
+ // uptr src, uint32_t src_as,
940
+ // uint32_t element_size,
941
+ // uptr counts,
942
+ // uptr stride
943
+ // )
944
+ MsanUnpoisonStridedCopyFunc = M.getOrInsertFunction (
945
+ " __msan_unpoison_strided_copy" , IRB.getVoidTy (), IntptrTy,
946
+ IRB.getInt32Ty (), IntptrTy, IRB.getInt32Ty (), IRB.getInt32Ty (),
947
+ IRB.getInt64Ty (), IRB.getInt64Ty ());
933
948
}
934
949
935
950
// Handle global variables:
@@ -1833,7 +1848,8 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) {
1833
1848
}
1834
1849
} else {
1835
1850
auto FuncName = Func->getName ();
1836
- if (FuncName.contains (" __spirv_" ))
1851
+ if (FuncName.contains (" __spirv_" ) &&
1852
+ !FuncName.contains (" __spirv_GroupAsyncCopy" ))
1837
1853
I.setNoSanitizeMetadata ();
1838
1854
}
1839
1855
}
@@ -1843,6 +1859,55 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) {
1843
1859
I.setNoSanitizeMetadata ();
1844
1860
}
1845
1861
1862
+ // This is not a general-purpose function, but a helper for demangling
1863
+ // "__spirv_GroupAsyncCopy" function name
1864
+ static int getTypeSizeFromManglingName (StringRef Name) {
1865
+ auto GetTypeSize = [](const char C) {
1866
+ switch (C) {
1867
+ case ' a' : // signed char
1868
+ case ' c' : // char
1869
+ return 1 ;
1870
+ case ' s' : // short
1871
+ return 2 ;
1872
+ case ' f' : // float
1873
+ case ' i' : // int
1874
+ return 4 ;
1875
+ case ' d' : // double
1876
+ case ' l' : // long
1877
+ return 8 ;
1878
+ default :
1879
+ return 0 ;
1880
+ }
1881
+ };
1882
+
1883
+ // Name should always be long enough since it has other unmeaningful chars,
1884
+ // it should have at least 6 chars, such as "Dv16_d"
1885
+ if (Name.size () < 6 )
1886
+ return 0 ;
1887
+
1888
+ // 1. Basic type
1889
+ if (Name[0 ] != ' D' )
1890
+ return GetTypeSize (Name[0 ]);
1891
+
1892
+ // 2. Vector type
1893
+
1894
+ // Drop "Dv"
1895
+ assert (Name[0 ] == ' D' && Name[1 ] == ' v' &&
1896
+ " Invalid mangling name for vector type" );
1897
+ Name = Name.drop_front (2 );
1898
+
1899
+ // Vector length
1900
+ assert (isDigit (Name[0 ]) && " Invalid mangling name for vector type" );
1901
+ int Len = std::stoi (Name.str ());
1902
+ Name = Name.drop_front (Len >= 10 ? 2 : 1 );
1903
+
1904
+ assert (Name[0 ] == ' _' && " Invalid mangling name for vector type" );
1905
+ Name = Name.drop_front (1 );
1906
+
1907
+ int Size = GetTypeSize (Name[0 ]);
1908
+ return Len * Size;
1909
+ }
1910
+
1846
1911
namespace {
1847
1912
1848
1913
// / Helper class to attach debug information of the given instruction onto new
@@ -6395,6 +6460,41 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
6395
6460
VAHelper->visitCallBase (CB, IRB);
6396
6461
}
6397
6462
6463
+ if (SpirOrSpirv) {
6464
+ auto *Func = CB.getCalledFunction ();
6465
+ if (Func) {
6466
+ auto FuncName = Func->getName ();
6467
+ if (FuncName.contains (" __spirv_GroupAsyncCopy" )) {
6468
+ // clang-format off
6469
+ // Handle functions like "_Z22__spirv_GroupAsyncCopyiPU3AS3dPU3AS1dllP13__spirv_Event",
6470
+ // its demangled name is "__spirv_GroupAsyncCopy(int, double AS3* dst, double AS1* src, long, long, __spirv_Event*)"
6471
+ // The type of "src" and "dst" should always be same.
6472
+ // clang-format on
6473
+
6474
+ auto *Dest = CB.getArgOperand (1 );
6475
+ auto *Src = CB.getArgOperand (2 );
6476
+ auto *NumElements = CB.getArgOperand (3 );
6477
+ auto *Stride = CB.getArgOperand (4 );
6478
+
6479
+ // Skip "_Z22__spirv_GroupAsyncCopyiPU3AS3" (33 char), get the size of
6480
+ // parameter type directly
6481
+ const size_t kManglingPrefixLength = 33 ;
6482
+ int ElementSize = getTypeSizeFromManglingName (
6483
+ FuncName.substr (kManglingPrefixLength ));
6484
+ assert (ElementSize != 0 &&
6485
+ " Unsupported __spirv_GroupAsyncCopy element type" );
6486
+
6487
+ IRB.CreateCall (
6488
+ MS.Spirv .MsanUnpoisonStridedCopyFunc ,
6489
+ {IRB.CreatePointerCast (Dest, MS.Spirv .IntptrTy ),
6490
+ IRB.getInt32 (Dest->getType ()->getPointerAddressSpace ()),
6491
+ IRB.CreatePointerCast (Src, MS.Spirv .IntptrTy ),
6492
+ IRB.getInt32 (Src->getType ()->getPointerAddressSpace ()),
6493
+ IRB.getInt32 (ElementSize), NumElements, Stride});
6494
+ }
6495
+ }
6496
+ }
6497
+
6398
6498
// Now, get the shadow for the RetVal.
6399
6499
if (!CB.getType ()->isSized ())
6400
6500
return ;
0 commit comments