Skip to content

Commit 7f103ad

Browse files
authored
[SPIR-V] Add llvm.loop.unroll metadata lowering (#132062)
Lowering rules: `.enable` lowers to the Unroll LoopControl; `.disable` lowers to the DontUnroll LoopControl; `.count` lowers to the PartialCount LoopControl; `.full` lowers to the Unroll LoopControl. TODO in future patches: enable the structurizer for non-Vulkan targets. --------- Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 45c3fe8 commit 7f103ad

File tree

6 files changed

+298
-4
lines changed

6 files changed

+298
-4
lines changed

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2957,10 +2957,11 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
29572957
case Intrinsic::spv_loop_merge: {
29582958
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpLoopMerge));
29592959
for (unsigned i = 1; i < I.getNumExplicitOperands(); ++i) {
2960-
assert(I.getOperand(i).isMBB());
2961-
MIB.addMBB(I.getOperand(i).getMBB());
2960+
if (I.getOperand(i).isMBB())
2961+
MIB.addMBB(I.getOperand(i).getMBB());
2962+
else
2963+
MIB.addImm(foldImm(I.getOperand(i), MRI));
29622964
}
2963-
MIB.addImm(SPIRV::SelectionControl::None);
29642965
return MIB.constrainAllUses(TII, TRI, RBI);
29652966
}
29662967
case Intrinsic::spv_selection_merge: {

llvm/lib/Target/SPIRV/SPIRVStructurizer.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,10 @@ class SPIRVStructurizer : public FunctionPass {
611611
auto MergeAddress = BlockAddress::get(Merge->getParent(), Merge);
612612
auto ContinueAddress = BlockAddress::get(Continue->getParent(), Continue);
613613
SmallVector<Value *, 2> Args = {MergeAddress, ContinueAddress};
614-
614+
SmallVector<unsigned, 1> LoopControlImms =
615+
getSpirvLoopControlOperandsFromLoopMetadata(L);
616+
for (unsigned Imm : LoopControlImms)
617+
Args.emplace_back(llvm::ConstantInt::get(Builder.getInt32Ty(), Imm));
615618
Builder.CreateIntrinsic(Intrinsic::spv_loop_merge, {}, {Args});
616619
Modified = true;
617620
}

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -854,6 +854,33 @@ createContinuedInstructions(MachineIRBuilder &MIRBuilder, unsigned Opcode,
854854
return Instructions;
855855
}
856856

857+
// Translates the llvm.loop.unroll.* attributes of loop L into SPIR-V loop
// control operands. The first element of the returned vector is the
// LoopControl mask; it is followed by the literal value of every control
// that carries one (currently only the PartialCount count).
//
//   llvm.loop.unroll.disable           -> DontUnroll; the enable/full/count
//                                         attributes are not consulted.
//   llvm.loop.unroll.enable or .full   -> Unroll.
//   llvm.loop.unroll.count N (N != 1)  -> PartialCount, with N appended.
//
// When none of these attributes is present the result is the single element
// LoopControl::None.
SmallVector<unsigned, 1> getSpirvLoopControlOperandsFromLoopMetadata(Loop *L) {
858+
unsigned LC = SPIRV::LoopControl::None;
859+
// Currently used only to store PartialCount value. Later when other
860+
// LoopControls are added - this map should be sorted before making
861+
// them loop_merge operands to satisfy 3.23. Loop Control requirements.
862+
std::vector<std::pair<unsigned, unsigned>> MaskToValueMap;
863+
if (getBooleanLoopAttribute(L, "llvm.loop.unroll.disable")) {
864+
LC |= SPIRV::LoopControl::DontUnroll;
865+
} else {
866+
if (getBooleanLoopAttribute(L, "llvm.loop.unroll.enable") ||
867+
getBooleanLoopAttribute(L, "llvm.loop.unroll.full")) {
868+
LC |= SPIRV::LoopControl::Unroll;
869+
}
870+
std::optional<int> Count =
871+
getOptionalIntLoopAttribute(L, "llvm.loop.unroll.count");
872+
if (Count && Count != 1) {
873+
LC |= SPIRV::LoopControl::PartialCount;
874+
MaskToValueMap.emplace_back(
875+
std::make_pair(SPIRV::LoopControl::PartialCount, *Count));
876+
}
877+
}
878+
SmallVector<unsigned, 1> Result = {LC};
879+
for (auto &[Mask, Val] : MaskToValueMap)
880+
Result.push_back(Val);
881+
return Result;
882+
}
883+
857884
const std::set<unsigned> &getTypeFoldingSupportedOpcodes() {
858885
// clang-format off
859886
static const std::set<unsigned> TypeFoldingSupportingOpcs = {

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,9 @@ createContinuedInstructions(MachineIRBuilder &MIRBuilder, unsigned Opcode,
464464
const std::set<unsigned> &getTypeFoldingSupportedOpcodes();
465465
bool isTypeFoldingSupported(unsigned Opcode);
466466

467+
// Get SPIR-V loop control operands (mask, then literal values) from the llvm.loop.unroll.* metadata of a loop.
468+
SmallVector<unsigned, 1> getSpirvLoopControlOperandsFromLoopMetadata(Loop *L);
469+
467470
// Traversing [g]MIR accounting for pseudo-instructions.
468471
MachineInstr *passCopy(MachineInstr *Def, const MachineRegisterInfo *MRI);
469472
MachineInstr *getDef(const MachineOperand &MO, const MachineRegisterInfo *MRI);

llvm/lib/Transforms/Utils/BreakCriticalEdges.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ llvm::SplitKnownCriticalEdge(Instruction *TI, unsigned SuccNum,
175175
// Create our unconditional branch.
176176
BranchInst *NewBI = BranchInst::Create(DestBB, NewBB);
177177
NewBI->setDebugLoc(TI->getDebugLoc());
178+
if (auto *LoopMD = TI->getMetadata(LLVMContext::MD_loop))
179+
NewBI->setMetadata(LLVMContext::MD_loop, LoopMD);
178180

179181
// Insert the block into the function... right after the block TI lives in.
180182
Function &F = *TIBB->getParent();
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
; RUN: llc -mtriple=spirv-unknown-vulkan-compute -O0 -verify-machineinstrs %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan-compute %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: OpName %[[#For:]] "for_loop"
5+
; CHECK-DAG: OpName %[[#While:]] "while_loop"
6+
; CHECK-DAG: OpName %[[#DoWhile:]] "do_while_loop"
7+
; CHECK-DAG: OpName %[[#Disable:]] "unroll_disable"
8+
; CHECK-DAG: OpName %[[#Count:]] "unroll_count"
9+
; CHECK-DAG: OpName %[[#Full:]] "unroll_full"
10+
; CHECK-DAG: OpName %[[#FullCount:]] "unroll_full_count"
11+
; CHECK-DAG: OpName %[[#EnableDisable:]] "unroll_enable_disable"
12+
13+
; CHECK: %[[#For]] = OpFunction
14+
; CHECK: OpLoopMerge %[[#]] %[[#]] Unroll
15+
16+
; CHECK: %[[#While]] = OpFunction
17+
; CHECK: OpLoopMerge %[[#]] %[[#]] Unroll
18+
19+
; CHECK: %[[#DoWhile]] = OpFunction
20+
; CHECK: OpLoopMerge %[[#]] %[[#]] Unroll
21+
22+
; CHECK: %[[#Disable]] = OpFunction
23+
; CHECK: OpLoopMerge %[[#]] %[[#]] DontUnroll
24+
25+
; CHECK: %[[#Count]] = OpFunction
26+
; CHECK: OpLoopMerge %[[#]] %[[#]] PartialCount 4
27+
28+
; CHECK: %[[#Full]] = OpFunction
29+
; CHECK: OpLoopMerge %[[#]] %[[#]] Unroll
30+
31+
; CHECK: %[[#FullCount]] = OpFunction
32+
; CHECK: OpLoopMerge %[[#]] %[[#]] Unroll|PartialCount 4
33+
34+
; CHECK: %[[#EnableDisable]] = OpFunction
35+
; CHECK: OpLoopMerge %[[#]] %[[#]] DontUnroll
36+
; CHECK-NOT: Unroll|DontUnroll
37+
; CHECK-NOT: DontUnroll|Unroll
38+
39+
define dso_local void @for_loop(ptr noundef %0, i32 noundef %1) {
40+
%3 = alloca ptr, align 8
41+
%4 = alloca i32, align 4
42+
%5 = alloca i32, align 4
43+
store ptr %0, ptr %3, align 8
44+
store i32 %1, ptr %4, align 4
45+
store i32 0, ptr %5, align 4
46+
br label %6
47+
48+
6: ; preds = %15, %2
49+
%7 = load i32, ptr %5, align 4
50+
%8 = load i32, ptr %4, align 4
51+
%9 = icmp slt i32 %7, %8
52+
br i1 %9, label %10, label %18
53+
54+
10: ; preds = %6
55+
%11 = load i32, ptr %5, align 4
56+
%12 = load ptr, ptr %3, align 8
57+
%13 = load i32, ptr %12, align 4
58+
%14 = add nsw i32 %13, %11
59+
store i32 %14, ptr %12, align 4
60+
br label %15
61+
62+
15: ; preds = %10
63+
%16 = load i32, ptr %5, align 4
64+
%17 = add nsw i32 %16, 1
65+
store i32 %17, ptr %5, align 4
66+
br label %6, !llvm.loop !1
67+
68+
18: ; preds = %6
69+
ret void
70+
}
71+
72+
define dso_local void @while_loop(ptr noundef %0, i32 noundef %1) {
73+
%3 = alloca ptr, align 8
74+
%4 = alloca i32, align 4
75+
%5 = alloca i32, align 4
76+
store ptr %0, ptr %3, align 8
77+
store i32 %1, ptr %4, align 4
78+
store i32 0, ptr %5, align 4
79+
br label %6
80+
81+
6: ; preds = %10, %2
82+
%7 = load i32, ptr %5, align 4
83+
%8 = load i32, ptr %4, align 4
84+
%9 = icmp slt i32 %7, %8
85+
br i1 %9, label %10, label %17
86+
87+
10: ; preds = %6
88+
%11 = load i32, ptr %5, align 4
89+
%12 = load ptr, ptr %3, align 8
90+
%13 = load i32, ptr %12, align 4
91+
%14 = add nsw i32 %13, %11
92+
store i32 %14, ptr %12, align 4
93+
%15 = load i32, ptr %5, align 4
94+
%16 = add nsw i32 %15, 1
95+
store i32 %16, ptr %5, align 4
96+
br label %6, !llvm.loop !3
97+
98+
17: ; preds = %6
99+
ret void
100+
}
101+
102+
define dso_local void @do_while_loop(ptr noundef %0, i32 noundef %1) {
103+
%3 = alloca ptr, align 8
104+
%4 = alloca i32, align 4
105+
%5 = alloca i32, align 4
106+
store ptr %0, ptr %3, align 8
107+
store i32 %1, ptr %4, align 4
108+
store i32 0, ptr %5, align 4
109+
br label %6
110+
111+
6: ; preds = %13, %2
112+
%7 = load i32, ptr %5, align 4
113+
%8 = load ptr, ptr %3, align 8
114+
%9 = load i32, ptr %8, align 4
115+
%10 = add nsw i32 %9, %7
116+
store i32 %10, ptr %8, align 4
117+
%11 = load i32, ptr %5, align 4
118+
%12 = add nsw i32 %11, 1
119+
store i32 %12, ptr %5, align 4
120+
br label %13
121+
122+
13: ; preds = %6
123+
%14 = load i32, ptr %5, align 4
124+
%15 = load i32, ptr %4, align 4
125+
%16 = icmp slt i32 %14, %15
126+
br i1 %16, label %6, label %17, !llvm.loop !4
127+
128+
17: ; preds = %13
129+
ret void
130+
}
131+
132+
define dso_local void @unroll_disable(i32 noundef %0) {
133+
%2 = alloca i32, align 4
134+
%3 = alloca i32, align 4
135+
store i32 %0, ptr %2, align 4
136+
store i32 0, ptr %3, align 4
137+
br label %4
138+
139+
4: ; preds = %7, %1
140+
%5 = load i32, ptr %3, align 4
141+
%6 = add nsw i32 %5, 1
142+
store i32 %6, ptr %3, align 4
143+
br label %7
144+
145+
7: ; preds = %4
146+
%8 = load i32, ptr %3, align 4
147+
%9 = load i32, ptr %2, align 4
148+
%10 = icmp slt i32 %8, %9
149+
br i1 %10, label %4, label %11, !llvm.loop !5
150+
151+
11: ; preds = %7
152+
ret void
153+
}
154+
155+
define dso_local void @unroll_count(i32 noundef %0) {
156+
%2 = alloca i32, align 4
157+
%3 = alloca i32, align 4
158+
store i32 %0, ptr %2, align 4
159+
store i32 0, ptr %3, align 4
160+
br label %4
161+
162+
4: ; preds = %7, %1
163+
%5 = load i32, ptr %3, align 4
164+
%6 = add nsw i32 %5, 1
165+
store i32 %6, ptr %3, align 4
166+
br label %7
167+
168+
7: ; preds = %4
169+
%8 = load i32, ptr %3, align 4
170+
%9 = load i32, ptr %2, align 4
171+
%10 = icmp slt i32 %8, %9
172+
br i1 %10, label %4, label %11, !llvm.loop !7
173+
174+
11: ; preds = %7
175+
ret void
176+
}
177+
178+
define dso_local void @unroll_full(i32 noundef %0) {
179+
%2 = alloca i32, align 4
180+
%3 = alloca i32, align 4
181+
store i32 %0, ptr %2, align 4
182+
store i32 0, ptr %3, align 4
183+
br label %4
184+
185+
4: ; preds = %7, %1
186+
%5 = load i32, ptr %3, align 4
187+
%6 = add nsw i32 %5, 1
188+
store i32 %6, ptr %3, align 4
189+
br label %7
190+
191+
7: ; preds = %4
192+
%8 = load i32, ptr %3, align 4
193+
%9 = load i32, ptr %2, align 4
194+
%10 = icmp slt i32 %8, %9
195+
br i1 %10, label %4, label %11, !llvm.loop !9
196+
197+
11: ; preds = %7
198+
ret void
199+
}
200+
201+
define dso_local void @unroll_full_count(i32 noundef %0) {
202+
%2 = alloca i32, align 4
203+
%3 = alloca i32, align 4
204+
store i32 %0, ptr %2, align 4
205+
store i32 0, ptr %3, align 4
206+
br label %4
207+
208+
4: ; preds = %7, %1
209+
%5 = load i32, ptr %3, align 4
210+
%6 = add nsw i32 %5, 1
211+
store i32 %6, ptr %3, align 4
212+
br label %7
213+
214+
7: ; preds = %4
215+
%8 = load i32, ptr %3, align 4
216+
%9 = load i32, ptr %2, align 4
217+
%10 = icmp slt i32 %8, %9
218+
br i1 %10, label %4, label %11, !llvm.loop !11
219+
220+
11: ; preds = %7
221+
ret void
222+
}
223+
224+
define dso_local void @unroll_enable_disable(i32 noundef %0) {
225+
%2 = alloca i32, align 4
226+
%3 = alloca i32, align 4
227+
store i32 %0, ptr %2, align 4
228+
store i32 0, ptr %3, align 4
229+
br label %4
230+
231+
4: ; preds = %7, %1
232+
%5 = load i32, ptr %3, align 4
233+
%6 = add nsw i32 %5, 1
234+
store i32 %6, ptr %3, align 4
235+
br label %7
236+
237+
7: ; preds = %4
238+
%8 = load i32, ptr %3, align 4
239+
%9 = load i32, ptr %2, align 4
240+
%10 = icmp slt i32 %8, %9
241+
br i1 %10, label %4, label %11, !llvm.loop !12
242+
243+
11: ; preds = %7
244+
ret void
245+
}
246+
247+
!1 = distinct !{!1, !2}
248+
!2 = !{!"llvm.loop.unroll.enable"}
249+
!3 = distinct !{!3, !2}
250+
!4 = distinct !{!4, !2}
251+
!5 = distinct !{!5, !6}
252+
!6 = !{!"llvm.loop.unroll.disable"}
253+
!7 = distinct !{!7, !8}
254+
!8 = !{!"llvm.loop.unroll.count", i32 4}
255+
!9 = distinct !{!9, !10}
256+
!10 = !{!"llvm.loop.unroll.full"}
257+
!11 = distinct !{!11, !10, !8}
258+
!12 = distinct !{!12, !2, !6}

0 commit comments

Comments
 (0)