Skip to content

Commit 036fff0

Browse files
authored
[DeviceSanitizer] Ignore load/store from joint matrix AccessChain (#15907)
The load/store is from sycl header. Skip them since sanitizer is mostly interested in user code. This avoids backend compiler from removing the unneeded instrumentation when lowering joint matrix access.
1 parent 20775ab commit 036fff0

File tree

2 files changed

+60
-19
lines changed

2 files changed

+60
-19
lines changed

llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,23 +1506,45 @@ static bool isUnsupportedAMDGPUAddrspace(Value *Addr) {
15061506
return false;
15071507
}
15081508

1509-
static bool containsTargetExtType(const Type *Ty) {
1510-
if (isa<TargetExtType>(Ty))
1511-
return true;
1509+
static TargetExtType *getTargetExtType(Type *Ty) {
1510+
if (auto *TargetTy = dyn_cast<TargetExtType>(Ty))
1511+
return TargetTy;
15121512

15131513
if (Ty->isVectorTy())
1514-
return containsTargetExtType(Ty->getScalarType());
1514+
return getTargetExtType(Ty->getScalarType());
15151515

15161516
if (Ty->isArrayTy())
1517-
return containsTargetExtType(Ty->getArrayElementType());
1517+
return getTargetExtType(Ty->getArrayElementType());
15181518

15191519
if (auto *STy = dyn_cast<StructType>(Ty)) {
15201520
for (unsigned int i = 0; i < STy->getNumElements(); i++)
1521-
if (containsTargetExtType(STy->getElementType(i)))
1522-
return true;
1523-
return false;
1521+
if (auto *TargetTy = getTargetExtType(STy->getElementType(i)))
1522+
return TargetTy;
1523+
return nullptr;
15241524
}
15251525

1526+
return nullptr;
1527+
}
1528+
1529+
// Skip pointer operand that is sycl joint matrix access since it isn't from
1530+
// user code, e.g. %call:
1531+
// clang-format off
1532+
// %a = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
1533+
// %0 = getelementptr inbounds %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", ptr %a, i64 0, i32 0
1534+
// %call = call spir_func ptr
1535+
// @_Z19__spirv_AccessChainIfN4sycl3_V13ext6oneapi12experimental6matrix9precision4tf32ELm8ELm8ELN5__spv9MatrixUseE0ELNS8_5Scope4FlagE3EEPT_PPNS8_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr %0, i64 0)
1536+
// %1 = load float, ptr %call, align 4
1537+
// store float %1, ptr %call, align 4
1538+
// clang-format on
1539+
static bool isJointMatrixAccess(Value *V) {
1540+
if (auto *CI = dyn_cast<CallInst>(V)) {
1541+
for (Value *Op : CI->args()) {
1542+
if (auto *AI = dyn_cast<AllocaInst>(Op->stripInBoundsOffsets()))
1543+
if (auto *TargetTy = getTargetExtType(AI->getAllocatedType()))
1544+
return TargetTy->getName().startswith("spirv.") &&
1545+
TargetTy->getName().contains("Matrix");
1546+
}
1547+
}
15261548
return false;
15271549
}
15281550

@@ -1534,13 +1556,15 @@ static bool isUnsupportedSPIRAccess(Value *Addr, Instruction *Inst) {
15341556

15351557
// Ignore load/store for target ext type since we can't know exactly what size
15361558
// it is.
1537-
if (isa<StoreInst>(Inst) &&
1538-
containsTargetExtType(
1539-
cast<StoreInst>(Inst)->getValueOperand()->getType()))
1540-
return true;
1559+
if (auto *SI = dyn_cast<StoreInst>(Inst))
1560+
if (getTargetExtType(SI->getValueOperand()->getType()) ||
1561+
isJointMatrixAccess(SI->getPointerOperand()))
1562+
return true;
15411563

1542-
if (isa<LoadInst>(Inst) && containsTargetExtType(Inst->getType()))
1543-
return true;
1564+
if (auto *LI = dyn_cast<LoadInst>(Inst))
1565+
if (getTargetExtType(Inst->getType()) ||
1566+
isJointMatrixAccess(LI->getPointerOperand()))
1567+
return true;
15441568

15451569
Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
15461570
switch (PtrTy->getPointerAddressSpace()) {
@@ -1789,7 +1813,7 @@ bool AddressSanitizer::isInterestingAlloca(const AllocaInst &AI) {
17891813
!(SSGI && SSGI->isSafe(AI)) &&
17901814
// ignore alloc contains target ext type since we can't know exactly what
17911815
// size it is.
1792-
!containsTargetExtType(AI.getAllocatedType()));
1816+
!getTargetExtType(AI.getAllocatedType()));
17931817

17941818
ProcessedAllocas[&AI] = IsInteresting;
17951819
return IsInteresting;

llvm/test/Instrumentation/AddressSanitizer/SPIRV/ignore_target_ext_type.ll

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,33 @@ target triple = "spir64-unknown-unknown"
55

66
%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.JointMatrixINTEL", i16, 16, 32, 0, 3, 0, 1) }
77

8-
define spir_kernel void @_ZTS4multIN4sycl3_V13ext6oneapi8bfloat16ELm16ELm16ELm32EE() {
8+
define spir_kernel void @_ZTS4multIN4sycl3_V13ext6oneapi8bfloat16ELm16ELm16ELm32EE() sanitize_address {
99
entry:
10+
; CHECK-LABEL: @_ZTS4multIN4sycl3_V13ext6oneapi8bfloat16ELm16ELm16ELm32EE
1011
; CHECK-NOT: MyAlloc
11-
%sub_a.i = alloca [2 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"], i32 0, align 8
12+
%a = alloca [2 x %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix"], i32 0, align 8
1213
br label %for.cond10.i
1314

1415
for.cond10.i: ; preds = %for.cond10.i, %entry
1516
%0 = load target("spirv.JointMatrixINTEL", i16, 16, 32, 0, 3, 0, 1), ptr null, align 8
1617
store target("spirv.JointMatrixINTEL", float, 16, 16, 3, 3, 2) zeroinitializer, ptr null, align 8
17-
; CHECK-NOT: asan_load
18-
; CHECK-NOT: asan_store
18+
; CHECK-NOT: call void @asan_load
19+
; CHECK-NOT: call void @asan_store
1920
br label %for.cond10.i
2021
}
22+
23+
define spir_kernel void @AccessChain() sanitize_address {
24+
entry:
25+
; CHECK-LABEL: @AccessChain
26+
%a = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
27+
%0 = getelementptr inbounds %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", ptr %a, i64 0, i32 0
28+
%call.i35 = call spir_func ptr @_Z19__spirv_AccessChainIfN4sycl3_V13ext6oneapi12experimental6matrix9precision4tf32ELm8ELm8ELN5__spv9MatrixUseE0ELNS8_5Scope4FlagE3EEPT_PPNS8_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr %0, i64 0)
29+
; CHECK-NOT: call void @__asan_load
30+
; CHECK-NOT: call void @__asan_store
31+
%1 = load float, ptr %call.i35, align 4
32+
%call.i42 = call spir_func ptr @_Z19__spirv_AccessChainIfN4sycl3_V13ext6oneapi12experimental6matrix9precision4tf32ELm8ELm8ELN5__spv9MatrixUseE0ELNS8_5Scope4FlagE3EEPT_PPNS8_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr %0, i64 0)
33+
store float %1, ptr %call.i42, align 4
34+
ret void
35+
}
36+
37+
declare spir_func ptr @_Z19__spirv_AccessChainIfN4sycl3_V13ext6oneapi12experimental6matrix9precision4tf32ELm8ELm8ELN5__spv9MatrixUseE0ELNS8_5Scope4FlagE3EEPT_PPNS8_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr, i64)

0 commit comments

Comments
 (0)