Skip to content

Commit afc0c91

Browse files
authored
[DevTSAN] Ignore check for joint matrix access (#18773)
we can't know exactly what size of joint matrix type is, so we need to ignore it.
1 parent d26cdad commit afc0c91

File tree

2 files changed

+71
-32
lines changed

2 files changed

+71
-32
lines changed

llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include "llvm/Support/Debug.h"
4444
#include "llvm/Support/Path.h"
4545
#include "llvm/Support/raw_ostream.h"
46+
#include "llvm/Transforms/Instrumentation/SPIRVSanitizerCommonUtils.h"
4647
#include "llvm/Transforms/Utils/EscapeEnumerator.h"
4748
#include "llvm/Transforms/Utils/Instrumentation.h"
4849
#include "llvm/Transforms/Utils/Local.h"
@@ -52,13 +53,6 @@ using namespace llvm;
5253

5354
#define DEBUG_TYPE "tsan"
5455

55-
// Spir memory address space
56-
static constexpr unsigned kSpirOffloadPrivateAS = 0;
57-
static constexpr unsigned kSpirOffloadGlobalAS = 1;
58-
static constexpr unsigned kSpirOffloadConstantAS = 2;
59-
static constexpr unsigned kSpirOffloadLocalAS = 3;
60-
static constexpr unsigned kSpirOffloadGenericAS = 4;
61-
6256
static cl::opt<bool> ClInstrumentMemoryAccesses(
6357
"tsan-instrument-memory-accesses", cl::init(true),
6458
cl::desc("Instrument memory accesses"), cl::Hidden);
@@ -127,6 +121,8 @@ struct ThreadSanitizerOnSpirv {
127121

128122
void appendDebugInfoToArgs(Instruction *I, SmallVectorImpl<Value *> &Args);
129123

124+
bool isUnsupportedSPIRAccess(Value *Addr, Instruction *Inst);
125+
130126
private:
131127
void instrumentGlobalVariables();
132128

@@ -383,6 +379,38 @@ void ThreadSanitizerOnSpirv::appendDebugInfoToArgs(
383379
Args.push_back(ConstantExpr::getPointerCast(FuncNameGV, ConstASPtrTy));
384380
}
385381

382+
bool ThreadSanitizerOnSpirv::isUnsupportedSPIRAccess(Value *Addr,
383+
Instruction *Inst) {
384+
auto *OrigValue = getUnderlyingObject(Addr);
385+
if (OrigValue->getName().starts_with("__spirv_BuiltIn"))
386+
return true;
387+
388+
// Ignore load/store for target ext type since we can't know exactly what size
389+
// it is.
390+
if (auto *SI = dyn_cast<StoreInst>(Inst))
391+
if (getTargetExtType(SI->getValueOperand()->getType()) ||
392+
isJointMatrixAccess(SI->getPointerOperand()))
393+
return true;
394+
395+
if (auto *LI = dyn_cast<LoadInst>(Inst))
396+
if (getTargetExtType(Inst->getType()) ||
397+
isJointMatrixAccess(LI->getPointerOperand()))
398+
return true;
399+
400+
auto AddrAS = cast<PointerType>(Addr->getType()->getScalarType())
401+
->getPointerAddressSpace();
402+
switch (AddrAS) {
403+
case kSpirOffloadPrivateAS:
404+
case kSpirOffloadLocalAS:
405+
case kSpirOffloadConstantAS:
406+
return true;
407+
case kSpirOffloadGlobalAS:
408+
case kSpirOffloadGenericAS:
409+
return false;
410+
}
411+
return false;
412+
}
413+
386414
bool ThreadSanitizerOnSpirv::isSupportedSPIRKernel(Function &F) {
387415

388416
if (!F.hasFnAttribute(Attribute::SanitizeThread) ||
@@ -709,30 +737,12 @@ static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
709737
}
710738
}
711739

712-
if (Triple(M->getTargetTriple()).isSPIROrSPIRV()) {
713-
auto *OrigValue = getUnderlyingObject(Addr);
714-
if (OrigValue->getName().starts_with("__spirv_BuiltIn"))
740+
// Do not instrument accesses from different address spaces; we cannot deal
741+
// with them.
742+
if (Addr) {
743+
Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
744+
if (PtrTy->getPointerAddressSpace() != 0)
715745
return false;
716-
717-
auto AddrAS = cast<PointerType>(Addr->getType()->getScalarType())
718-
->getPointerAddressSpace();
719-
switch (AddrAS) {
720-
case kSpirOffloadPrivateAS:
721-
case kSpirOffloadLocalAS:
722-
case kSpirOffloadConstantAS:
723-
return false;
724-
case kSpirOffloadGlobalAS:
725-
case kSpirOffloadGenericAS:
726-
return true;
727-
}
728-
} else {
729-
// Do not instrument accesses from different address spaces; we cannot deal
730-
// with them.
731-
if (Addr) {
732-
Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
733-
if (PtrTy->getPointerAddressSpace() != 0)
734-
return false;
735-
}
736746
}
737747

738748
return true;
@@ -781,7 +791,10 @@ void ThreadSanitizer::chooseInstructionsToInstrument(
781791
Value *Addr = IsWrite ? cast<StoreInst>(I)->getPointerOperand()
782792
: cast<LoadInst>(I)->getPointerOperand();
783793

784-
if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
794+
if (Spirv) {
795+
if (Spirv->isUnsupportedSPIRAccess(Addr, I))
796+
continue;
797+
} else if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
785798
continue;
786799

787800
if (!IsWrite) {
@@ -890,7 +903,8 @@ bool ThreadSanitizer::sanitizeFunction(Function &F,
890903
else if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst))
891904
LocalLoadsAndStores.push_back(&Inst);
892905
else if (Spirv && isa<AllocaInst>(Inst) &&
893-
cast<AllocaInst>(Inst).getAllocatedType()->isSized())
906+
cast<AllocaInst>(Inst).getAllocatedType()->isSized() &&
907+
!getTargetExtType(cast<AllocaInst>(Inst).getAllocatedType()))
894908
Allocas.push_back(&Inst);
895909
else if ((isa<CallInst>(Inst) && !isa<DbgInfoIntrinsic>(Inst)) ||
896910
isa<InvokeInst>(Inst)) {
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
; RUN: opt < %s -passes='function(tsan),module(tsan-module)' -tsan-instrument-func-entry-exit=0 -tsan-instrument-memintrinsics=0 -S | FileCheck %s
2+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
3+
target triple = "spir64-unknown-unknown"
4+
5+
%"class.sycl::_V1::ext::oneapi::bfloat16" = type { i16 }
6+
%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i16, 3, 16, 32, 0) }
7+
8+
declare dso_local spir_func noundef ptr addrspace(4) @_Z19__spirv_AccessChainIN4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm32ELN5__spv9MatrixUseE0ELNS5_5Scope4FlagE3EEPT_PPNS5_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef)
9+
10+
define weak_odr dso_local spir_kernel void @test() {
11+
; CHECK-LABEL: void @test
12+
; CHECK-NOT: call void @__tsan_write
13+
; CHECK-NOT: ptrtoint ptr %sub_a.i to i64
14+
entry:
15+
%sub_a.i = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
16+
%element.i = alloca %"class.sycl::_V1::ext::oneapi::bfloat16", align 2
17+
%0 = getelementptr inbounds { i16 }, ptr %element.i, i64 0, i32 0
18+
%spvm.i = getelementptr inbounds %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", ptr %sub_a.i, i64 0, i32 0
19+
%addrcast = addrspacecast ptr %spvm.i to ptr addrspace(4)
20+
%call.i67 = call spir_func noundef ptr addrspace(4) @_Z19__spirv_AccessChainIN4sycl3_V13ext6oneapi8bfloat16ES4_Lm16ELm32ELN5__spv9MatrixUseE0ELNS5_5Scope4FlagE3EEPT_PPNS5_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %addrcast, i64 1)
21+
%gep = getelementptr inbounds nuw { i16 }, ptr addrspace(4) %call.i67, i64 0, i32 0
22+
%val = load i16, ptr %0, align 2
23+
store i16 %val, ptr addrspace(4) %gep, align 2
24+
ret void
25+
}

0 commit comments

Comments
 (0)