Skip to content

Commit aeddb8d

Browse files
authored
[SYCL][Matrix] Extend W/A for more corner cases of AccessChain usage (#16370)
The new corner case is: AccessChain is used on arrays of Joint Matrices. Fix for CMPLRLLVM-64465.
1 parent a813b55 commit aeddb8d

File tree

2 files changed

+168
-29
lines changed

2 files changed

+168
-29
lines changed

llvm/lib/SYCLLowerIR/SYCLJointMatrixTransform.cpp

Lines changed: 113 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,74 @@ namespace {
2222
static constexpr char ACCESS_CHAIN[] = "_Z19__spirv_AccessChain";
2323
static constexpr char MATRIX_TYPE[] = "spirv.CooperativeMatrixKHR";
2424

25+
// Peels every enclosing ArrayType layer off Ty and returns the first
// non-array element type found. A non-array Ty is returned unchanged.
Type *getInnermostType(Type *Ty) {
  auto *ArrayTy = dyn_cast<ArrayType>(Ty);
  return ArrayTy ? getInnermostType(ArrayTy->getElementType()) : Ty;
}
30+
31+
// Rebuilds the (possibly nested) array type Ty with the same element counts
// at every level, but with NewInnermostTy substituted for the innermost
// non-array element type. When Ty is not an array, NewInnermostTy itself is
// returned.
Type *replaceInnermostType(Type *Ty, Type *NewInnermostTy) {
  auto *ArrayTy = dyn_cast<ArrayType>(Ty);
  if (!ArrayTy)
    return NewInnermostTy;

  Type *NewElementTy =
      replaceInnermostType(ArrayTy->getElementType(), NewInnermostTy);
  return ArrayType::get(NewElementTy, ArrayTy->getNumElements());
}
38+
39+
// This function is a copy of stripPointerCastsAndOffsets from Value.cpp,
40+
// simplified and modified to strip non-zero GEP indices as well and also
41+
// find nearest GEP instruction.
42+
Value *stripPointerCastsAndOffsets(Value *V, bool StopOnGEP = false) {
43+
if (!V->getType()->isPointerTy())
44+
return V;
45+
46+
// Even though we don't look through PHI nodes, we could be called on an
47+
// instruction in an unreachable block, which may be on a cycle.
48+
SmallPtrSet<Value *, 4> Visited;
49+
50+
Visited.insert(V);
51+
do {
52+
if (auto *GEP = dyn_cast<GEPOperator>(V)) {
53+
if (StopOnGEP && isa<GetElementPtrInst>(GEP))
54+
return V;
55+
V = GEP->getPointerOperand();
56+
} else if (auto *BC = dyn_cast<BitCastOperator>(V)) {
57+
Value *NewV = BC->getOperand(0);
58+
if (!NewV->getType()->isPointerTy())
59+
return V;
60+
V = NewV;
61+
} else if (auto *ASC = dyn_cast<AddrSpaceCastOperator>(V)) {
62+
V = ASC->getOperand(0);
63+
} else {
64+
if (auto *Call = dyn_cast<CallBase>(V)) {
65+
if (Value *RV = Call->getReturnedArgOperand()) {
66+
V = RV;
67+
// Strip the call instruction, since callee returns its RV
68+
// argument as return value. So, we need to continue stripping.
69+
continue;
70+
}
71+
}
72+
return V;
73+
}
74+
assert(V->getType()->isPointerTy() && "Unexpected operand type!");
75+
} while (Visited.insert(V).second);
76+
77+
return V;
78+
}
79+
80+
// Returns the target extension type stored as the first element of
// WrapperMatrixTy when that element is a TargetExtType whose name equals
// MATRIX_TYPE. Returns nullptr otherwise, including when WrapperMatrixTy is
// null or has no elements.
TargetExtType *extractMatrixType(StructType *WrapperMatrixTy) {
  // Opaque and empty structs have no element 0; asking for it would trip the
  // index assertion inside StructType::getElementType.
  if (!WrapperMatrixTy || WrapperMatrixTy->getNumElements() == 0)
    return nullptr;

  auto *MatrixTy = dyn_cast<TargetExtType>(WrapperMatrixTy->getElementType(0));
  if (!MatrixTy || MatrixTy->getName() != MATRIX_TYPE)
    return nullptr;
  return MatrixTy;
}
92+
2593
// This function finds all calls to __spirv_AccessChain function and transforms
2694
// its users and operands to make LLVM IR more SPIR-V friendly.
2795
bool transformAccessChain(Function *F) {
@@ -60,33 +128,59 @@ bool transformAccessChain(Function *F) {
60128
// from sycl::joint_matrix class object if it's used in __spirv_AccessChain
61129
// function call. It's necessary because otherwise OpAccessChain indices
62130
// would be wrong.
63-
Instruction *Ptr =
64-
dyn_cast<Instruction>(CI->getArgOperand(0)->stripPointerCasts());
131+
Instruction *Ptr = dyn_cast<Instruction>(
132+
stripPointerCastsAndOffsets(CI->getArgOperand(0)));
65133
if (!Ptr || !isa<AllocaInst>(Ptr))
66134
continue;
67-
StructType *WrapperMatrixTy =
68-
dyn_cast<StructType>(cast<AllocaInst>(Ptr)->getAllocatedType());
69-
if (!WrapperMatrixTy)
70-
continue;
71-
TargetExtType *MatrixTy =
72-
dyn_cast<TargetExtType>(WrapperMatrixTy->getElementType(0));
73-
if (!MatrixTy)
135+
136+
Type *AllocaTy = cast<AllocaInst>(Ptr)->getAllocatedType();
137+
// It may happen that sycl::joint_matrix class object is wrapped into
138+
// nested arrays. We need to find the innermost type to extract the matrix.
139+
if (StructType *WrapperMatrixTy =
140+
dyn_cast<StructType>(getInnermostType(AllocaTy))) {
141+
TargetExtType *MatrixTy = extractMatrixType(WrapperMatrixTy);
142+
if (!MatrixTy)
143+
continue;
144+
145+
AllocaInst *Alloca = nullptr;
146+
{
147+
IRBuilder Builder(CI);
148+
IRBuilderBase::InsertPointGuard IG(Builder);
149+
Builder.SetInsertPointPastAllocas(CI->getFunction());
150+
Alloca = Builder.CreateAlloca(replaceInnermostType(AllocaTy, MatrixTy));
151+
Alloca->takeName(Ptr);
152+
}
153+
Ptr->replaceAllUsesWith(Alloca);
154+
Ptr->dropAllReferences();
155+
Ptr->eraseFromParent();
156+
ModuleChanged = true;
157+
}
158+
159+
// In case spirv.CooperativeMatrixKHR is used in arrays, we also need to
160+
// insert GEP to get pointer to target extension type and use it instead of
161+
// pointer to sycl::joint_matrix class object when it is passed to
162+
// __spirv_AccessChain
163+
// First we check if the argument came from a GEP instruction
164+
GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(
165+
stripPointerCastsAndOffsets(CI->getArgOperand(0), /*StopOnGEP=*/true));
166+
if (!GEP)
74167
continue;
75-
StringRef Name = MatrixTy->getName();
76-
if (Name != MATRIX_TYPE)
168+
169+
// Check if GEP return type is a pointer to sycl::joint_matrix class object
170+
StructType *WrapperMatrixTy =
171+
dyn_cast<StructType>(GEP->getResultElementType());
172+
if (!extractMatrixType(WrapperMatrixTy))
77173
continue;
78174

79-
AllocaInst *Alloca = nullptr;
175+
// Insert GEP right before the __spirv_AccessChain call
80176
{
81177
IRBuilder Builder(CI);
82-
IRBuilderBase::InsertPointGuard IG(Builder);
83-
Builder.SetInsertPointPastAllocas(CI->getFunction());
84-
Alloca = Builder.CreateAlloca(MatrixTy);
178+
Value *NewGEP =
179+
Builder.CreateInBoundsGEP(WrapperMatrixTy, CI->getArgOperand(0),
180+
{Builder.getInt64(0), Builder.getInt32(0)});
181+
CI->setArgOperand(0, NewGEP);
182+
ModuleChanged = true;
85183
}
86-
Ptr->replaceAllUsesWith(Alloca);
87-
Ptr->dropAllReferences();
88-
Ptr->eraseFromParent();
89-
ModuleChanged = true;
90184
}
91185
return ModuleChanged;
92186
}

llvm/test/SYCLLowerIR/JointMatrixTransform/access_chain.ll

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,69 @@
33

44
; RUN: opt -passes=sycl-joint-matrix-transform < %s -S | FileCheck %s
55

6-
; CHECK: %[[#Alloca:]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
7-
; CHECK: %[[#Cast:]] = addrspacecast ptr %[[#Alloca]] to ptr addrspace(4)
8-
; CHECK: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef %[[#Cast]], i64 noundef 0)
9-
106
; ModuleID = 'test.bc'
117
source_filename = "test.cpp"
128
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
139
target triple = "spir64-unknown-unknown"
1410

15-
%"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }
11+
%"struct.sycl::joint_matrix" = type { target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0) }
12+
%"struct.sycl::_V1::long" = type { i64 }
13+
14+
define weak_odr dso_local spir_kernel void @test(i64 %ind) {
15+
; CHECK-LABEL: define weak_odr dso_local spir_kernel void @test(
16+
; CHECK-SAME: i64 [[IND:%.*]]) {
17+
18+
; non-matrix alloca not touched
19+
; CHECK: [[NOT_MATR:%.*]] = alloca [2 x [4 x %"struct.sycl::_V1::long"]]
20+
; both matrix-related allocas updated to use target extension types
21+
; CHECK-NEXT: [[MATR:%.*]] = alloca target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)
22+
; CHECK-NEXT: [[MATR_ARR:%.*]] = alloca [2 x [4 x target("spirv.CooperativeMatrixKHR", i8, 3, 16, 64, 0)]]
23+
24+
; CHECK-NEXT: [[ASCAST:%.*]] = addrspacecast ptr [[MATR]] to ptr addrspace(4)
25+
; no gep inserted, since not needed
26+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST]], i64 noundef 0)
27+
28+
; CHECK: [[GEP:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr [[MATR_ARR]], i64 0, i64 [[IND]], i64 [[IND]]
29+
; CHECK-NEXT: [[ASCAST_1:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4)
30+
; CHECK-NEXT: [[ASCAST_2:%.*]] = addrspacecast ptr [[GEP]] to ptr addrspace(4)
31+
; gep is inserted for each of the accesschain calls to extract target extension type
32+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_1]], i64 0, i32 0
33+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP2]], i64 noundef 0)
34+
; CHECK: [[TMP5:%.*]] = getelementptr inbounds %"struct.sycl::joint_matrix", ptr addrspace(4) [[ASCAST_2]], i64 0, i32 0
35+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[TMP5]], i64 noundef 0)
36+
37+
; negative test - not touching non-matrix code
38+
; CHECK: [[GEP_1:%.*]] = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr [[NOT_MATR]], i64 0, i64 [[IND]], i64 [[IND]]
39+
; CHECK-NEXT: [[ASCAST_3:%.*]] = addrspacecast ptr [[GEP_1]] to ptr addrspace(4)
40+
; CHECK-NEXT: call spir_func ptr addrspace(4) @_Z19__spirv_AccessChain{{.*}}(ptr addrspace(4) noundef [[ASCAST_3]], i64 noundef 0)
1641

17-
define weak_odr dso_local spir_kernel void @test() {
1842
entry:
19-
%0 = alloca %"struct.sycl::_V1::ext::oneapi::experimental::matrix::joint_matrix", align 8
20-
%1 = addrspacecast ptr %0 to ptr addrspace(4)
21-
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef %1, i64 noundef 0)
43+
; allocas
44+
%matr = alloca %"struct.sycl::joint_matrix", align 8
45+
%matr.arr = alloca [2 x [4 x %"struct.sycl::joint_matrix"]], align 8
46+
%not.matr = alloca [2 x [4 x %"struct.sycl::_V1::long"]], align 8
47+
48+
; simple case
49+
%ascast = addrspacecast ptr %matr to ptr addrspace(4)
50+
%0 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast, i64 noundef 0)
51+
%1 = load i8, ptr addrspace(4) %0
52+
53+
; gep with non-zero indices and multiple access chains per 1 alloca
54+
%gep = getelementptr inbounds [2 x [4 x %"struct.sycl::joint_matrix"]], ptr %matr.arr, i64 0, i64 %ind, i64 %ind
55+
%ascast.1 = addrspacecast ptr %gep to ptr addrspace(4)
56+
%ascast.2 = addrspacecast ptr %gep to ptr addrspace(4)
57+
%2 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.1, i64 noundef 0)
2258
%3 = load i8, ptr addrspace(4) %2
59+
%4 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.2, i64 noundef 0)
60+
%5 = load i8, ptr addrspace(4) %4
61+
62+
; negative test - not touching non-matrix code
63+
%gep.1 = getelementptr inbounds [2 x [4 x %"struct.sycl::_V1::long"]], ptr %not.matr, i64 0, i64 %ind, i64 %ind
64+
%ascast.3 = addrspacecast ptr %gep.1 to ptr addrspace(4)
65+
%6 = call spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef %ascast.3, i64 noundef 0)
66+
%7 = load i8, ptr addrspace(4) %6
67+
2368
ret void
2469
}
2570

26-
declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0_XT4_EXT1_EXT2_EXT3_EEEm(ptr addrspace(4) noundef, i64 noundef)
71+
declare dso_local spir_func ptr addrspace(4) @_Z19__spirv_AccessChainIiiLm16ELm16ELN5__spv9MatrixUseE2ELNS0_5Scope4FlagE3EEPT_PPNS0_28__spirv_CooperativeMatrixKHRIT0(ptr addrspace(4) noundef, i64 noundef)

0 commit comments

Comments
 (0)