Skip to content

Commit df4bd05

Browse files
authored
[DeviceMSAN] Fix false negative report due to __spirv_GroupAsyncCopy (#18216)
"__spirv_GroupAsyncCopy" is used to copy data between local and global buffer, we need to sync the shadow value of them. I add a new function "__msan_unpoison_strided_copy" for this sync.
1 parent e7ab07d commit df4bd05

File tree

3 files changed

+173
-3
lines changed

3 files changed

+173
-3
lines changed

libdevice/sanitizer/msan_rtl.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,16 @@ inline void __msan_exit() {
219219
__devicelib_exit();
220220
}
221221

222+
// This function is only used for shadow propagation
223+
template <typename T>
224+
void GroupAsyncCopy(uptr Dest, uptr Src, size_t NumElements, size_t Stride) {
225+
auto DestPtr = (__SYCL_GLOBAL__ T *)Dest;
226+
auto SrcPtr = (const __SYCL_GLOBAL__ T *)Src;
227+
for (size_t i = 0; i < NumElements; i++) {
228+
DestPtr[i] = SrcPtr[i * Stride];
229+
}
230+
}
231+
222232
} // namespace
223233

224234
#define MSAN_MAYBE_WARNING(type, size) \
@@ -590,4 +600,41 @@ __msan_set_private_base(__SYCL_PRIVATE__ void *ptr) {
590600
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_private_base, sid, ptr));
591601
}
592602

603+
static __SYCL_CONSTANT__ const char __msan_print_strided_copy_unsupport_type[] =
604+
"[kernel] __msan_unpoison_strided_copy: unsupported type(%d)\n";
605+
606+
DEVICE_EXTERN_C_NOINLINE void
607+
__msan_unpoison_strided_copy(uptr dest, uint32_t dest_as, uptr src,
608+
uint32_t src_as, uint32_t element_size,
609+
uptr counts, uptr stride) {
610+
if (!GetMsanLaunchInfo)
611+
return;
612+
613+
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg,
614+
"__msan_unpoison_strided_copy"));
615+
616+
uptr shadow_dest = (uptr)__msan_get_shadow(dest, dest_as);
617+
uptr shadow_src = (uptr)__msan_get_shadow(src, src_as);
618+
619+
switch (element_size) {
620+
case 1:
621+
GroupAsyncCopy<int8_t>(shadow_dest, shadow_src, counts, stride);
622+
break;
623+
case 2:
624+
GroupAsyncCopy<int16_t>(shadow_dest, shadow_src, counts, stride);
625+
break;
626+
case 4:
627+
GroupAsyncCopy<int32_t>(shadow_dest, shadow_src, counts, stride);
628+
break;
629+
case 8:
630+
GroupAsyncCopy<int64_t>(shadow_dest, shadow_src, counts, stride);
631+
break;
632+
default:
633+
__spirv_ocl_printf(__msan_print_strided_copy_unsupport_type, element_size);
634+
}
635+
636+
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end,
637+
"__msan_unpoison_strided_copy"));
638+
}
639+
593640
#endif // __SPIR__ || __SPIRV__

llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,8 @@ class MemorySanitizerOnSpirv {
806806
void initializeKernelCallerMap(Function *F);
807807

808808
private:
809+
friend struct MemorySanitizerVisitor;
810+
809811
Module &M;
810812
LLVMContext &C;
811813
const DataLayout &DL;
@@ -833,6 +835,7 @@ class MemorySanitizerOnSpirv {
833835
FunctionCallee MsanBarrierFunc;
834836
FunctionCallee MsanUnpoisonStackFunc;
835837
FunctionCallee MsanSetPrivateBaseFunc;
838+
FunctionCallee MsanUnpoisonStridedCopyFunc;
836839
};
837840

838841
} // end anonymous namespace
@@ -899,14 +902,14 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
899902
M.getOrInsertFunction("__msan_unpoison_shadow_static_local",
900903
IRB.getVoidTy(), IntptrTy, IntptrTy);
901904

902-
// __asan_poison_shadow_dynamic_local(
905+
// __msan_poison_shadow_dynamic_local(
903906
// uptr ptr,
904907
// uint32_t num_args
905908
// )
906909
MsanPoisonShadowDynamicLocalFunc = M.getOrInsertFunction(
907910
"__msan_poison_shadow_dynamic_local", IRB.getVoidTy(), IntptrTy, Int32Ty);
908911

909-
// __asan_unpoison_shadow_dynamic_local(
912+
// __msan_unpoison_shadow_dynamic_local(
910913
// uptr ptr,
911914
// uint32_t num_args
912915
// )
@@ -930,6 +933,18 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
930933
MsanSetPrivateBaseFunc =
931934
M.getOrInsertFunction("__msan_set_private_base", IRB.getVoidTy(),
932935
PointerType::get(C, kSpirOffloadPrivateAS));
936+
937+
// __msan_unpoison_strided_copy(
938+
// uptr dest, uint32_t dest_as,
939+
// uptr src, uint32_t src_as,
940+
// uint32_t element_size,
941+
// uptr counts,
942+
// uptr stride
943+
// )
944+
MsanUnpoisonStridedCopyFunc = M.getOrInsertFunction(
945+
"__msan_unpoison_strided_copy", IRB.getVoidTy(), IntptrTy,
946+
IRB.getInt32Ty(), IntptrTy, IRB.getInt32Ty(), IRB.getInt32Ty(),
947+
IRB.getInt64Ty(), IRB.getInt64Ty());
933948
}
934949

935950
// Handle global variables:
@@ -1833,7 +1848,8 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) {
18331848
}
18341849
} else {
18351850
auto FuncName = Func->getName();
1836-
if (FuncName.contains("__spirv_"))
1851+
if (FuncName.contains("__spirv_") &&
1852+
!FuncName.contains("__spirv_GroupAsyncCopy"))
18371853
I.setNoSanitizeMetadata();
18381854
}
18391855
}
@@ -1843,6 +1859,55 @@ static void setNoSanitizedMetadataSPIR(Instruction &I) {
18431859
I.setNoSanitizeMetadata();
18441860
}
18451861

1862+
// This is not a general-purpose function, but a helper for demangling
1863+
// "__spirv_GroupAsyncCopy" function name
1864+
static int getTypeSizeFromManglingName(StringRef Name) {
1865+
auto GetTypeSize = [](const char C) {
1866+
switch (C) {
1867+
case 'a': // signed char
1868+
case 'c': // char
1869+
return 1;
1870+
case 's': // short
1871+
return 2;
1872+
case 'f': // float
1873+
case 'i': // int
1874+
return 4;
1875+
case 'd': // double
1876+
case 'l': // long
1877+
return 8;
1878+
default:
1879+
return 0;
1880+
}
1881+
};
1882+
1883+
// Name should always be long enough since it has other unmeaningful chars,
1884+
// it should have at least 6 chars, such as "Dv16_d"
1885+
if (Name.size() < 6)
1886+
return 0;
1887+
1888+
// 1. Basic type
1889+
if (Name[0] != 'D')
1890+
return GetTypeSize(Name[0]);
1891+
1892+
// 2. Vector type
1893+
1894+
// Drop "Dv"
1895+
assert(Name[0] == 'D' && Name[1] == 'v' &&
1896+
"Invalid mangling name for vector type");
1897+
Name = Name.drop_front(2);
1898+
1899+
// Vector length
1900+
assert(isDigit(Name[0]) && "Invalid mangling name for vector type");
1901+
int Len = std::stoi(Name.str());
1902+
Name = Name.drop_front(Len >= 10 ? 2 : 1);
1903+
1904+
assert(Name[0] == '_' && "Invalid mangling name for vector type");
1905+
Name = Name.drop_front(1);
1906+
1907+
int Size = GetTypeSize(Name[0]);
1908+
return Len * Size;
1909+
}
1910+
18461911
namespace {
18471912

18481913
/// Helper class to attach debug information of the given instruction onto new
@@ -6395,6 +6460,41 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
63956460
VAHelper->visitCallBase(CB, IRB);
63966461
}
63976462

6463+
if (SpirOrSpirv) {
6464+
auto *Func = CB.getCalledFunction();
6465+
if (Func) {
6466+
auto FuncName = Func->getName();
6467+
if (FuncName.contains("__spirv_GroupAsyncCopy")) {
6468+
// clang-format off
6469+
// Handle functions like "_Z22__spirv_GroupAsyncCopyiPU3AS3dPU3AS1dllP13__spirv_Event",
6470+
// its demangled name is "__spirv_GroupAsyncCopy(int, double AS3* dst, double AS1* src, long, long, __spirv_Event*)"
6471+
// The type of "src" and "dst" should always be same.
6472+
// clang-format on
6473+
6474+
auto *Dest = CB.getArgOperand(1);
6475+
auto *Src = CB.getArgOperand(2);
6476+
auto *NumElements = CB.getArgOperand(3);
6477+
auto *Stride = CB.getArgOperand(4);
6478+
6479+
// Skip "_Z22__spirv_GroupAsyncCopyiPU3AS3" (33 char), get the size of
6480+
// parameter type directly
6481+
const size_t kManglingPrefixLength = 33;
6482+
int ElementSize = getTypeSizeFromManglingName(
6483+
FuncName.substr(kManglingPrefixLength));
6484+
assert(ElementSize != 0 &&
6485+
"Unsupported __spirv_GroupAsyncCopy element type");
6486+
6487+
IRB.CreateCall(
6488+
MS.Spirv.MsanUnpoisonStridedCopyFunc,
6489+
{IRB.CreatePointerCast(Dest, MS.Spirv.IntptrTy),
6490+
IRB.getInt32(Dest->getType()->getPointerAddressSpace()),
6491+
IRB.CreatePointerCast(Src, MS.Spirv.IntptrTy),
6492+
IRB.getInt32(Src->getType()->getPointerAddressSpace()),
6493+
IRB.getInt32(ElementSize), NumElements, Stride});
6494+
}
6495+
}
6496+
}
6497+
63986498
// Now, get the shadow for the RetVal.
63996499
if (!CB.getType()->isSized())
64006500
return;
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: opt < %s -passes=msan -msan-instrumentation-with-call-threshold=0 -msan-eager-checks=1 -msan-poison-stack-with-call=1 -S | FileCheck %s
2+
3+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
4+
target triple = "spir64-unknown-unknown"
5+
6+
declare spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyiPU3AS3iPU3AS1immP13__spirv_Event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event")) nounwind
7+
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
8+
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))
9+
10+
define spir_kernel void @kernel(ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc) sanitize_memory {
11+
entry:
12+
; CHECK: @__msan_barrier()
13+
; CHECK: [[REG1:%[0-9]+]] = ptrtoint ptr addrspace(3) %_arg_localAcc to i64
14+
; CHECK-NEXT: [[REG2:%[0-9]+]] = ptrtoint ptr addrspace(1) %_arg_globalAcc to i64
15+
; CHECK-NEXT: call void @__msan_unpoison_strided_copy(i64 [[REG1]], i32 3, i64 [[REG2]], i32 1, i32 4, i64 512, i64 1)
16+
%copy = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyiPU3AS3iPU3AS1immP13__spirv_Event(i32 2, ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer)
17+
18+
; CHECK: __msan_unpoison_strided_copy
19+
%copy2 = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %_arg_globalAcc, ptr addrspace(3) %_arg_localAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer)
20+
; CHECK: __msan_unpoison_strided_copy
21+
%copy3 = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer)
22+
ret void
23+
}

0 commit comments

Comments
 (0)