11
11
// ===----------------------------------------------------------------------===//
12
12
#include " mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
13
13
#include " mlir/Dialect/GPU/GPUDialect.h"
14
- #include " mlir/Dialect/SCF/SCF.h"
15
14
#include " mlir/Dialect/SPIRV/SPIRVDialect.h"
16
15
#include " mlir/Dialect/SPIRV/SPIRVLowering.h"
17
16
#include " mlir/Dialect/SPIRV/SPIRVOps.h"
20
19
using namespace mlir ;
21
20
22
21
namespace {
23
-
24
- // / Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
25
- class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
26
- public:
27
- using SPIRVOpLowering<scf::ForOp>::SPIRVOpLowering;
28
-
29
- LogicalResult
30
- matchAndRewrite (scf::ForOp forOp, ArrayRef<Value> operands,
31
- ConversionPatternRewriter &rewriter) const override ;
32
- };
33
-
34
- // / Pattern to convert a scf::IfOp within kernel functions into
35
- // / spirv::SelectionOp.
36
- class IfOpConversion final : public SPIRVOpLowering<scf::IfOp> {
37
- public:
38
- using SPIRVOpLowering<scf::IfOp>::SPIRVOpLowering;
39
-
40
- LogicalResult
41
- matchAndRewrite (scf::IfOp IfOp, ArrayRef<Value> operands,
42
- ConversionPatternRewriter &rewriter) const override ;
43
- };
44
-
45
- // / Pattern to erase a scf::YieldOp.
46
- class TerminatorOpConversion final : public SPIRVOpLowering<scf::YieldOp> {
47
- public:
48
- using SPIRVOpLowering<scf::YieldOp>::SPIRVOpLowering;
49
-
50
- LogicalResult
51
- matchAndRewrite (scf::YieldOp terminatorOp, ArrayRef<Value> operands,
52
- ConversionPatternRewriter &rewriter) const override {
53
- rewriter.eraseOp (terminatorOp);
54
- return success ();
55
- }
56
- };
57
-
58
22
// / Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
59
23
// / builtin variables.
60
24
template <typename SourceOp, spirv::BuiltIn builtin>
@@ -128,134 +92,6 @@ class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
128
92
129
93
} // namespace
130
94
131
- // ===----------------------------------------------------------------------===//
132
- // scf::ForOp.
133
- // ===----------------------------------------------------------------------===//
134
-
135
- LogicalResult
136
- ForOpConversion::matchAndRewrite (scf::ForOp forOp, ArrayRef<Value> operands,
137
- ConversionPatternRewriter &rewriter) const {
138
- // scf::ForOp can be lowered to the structured control flow represented by
139
- // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
140
- // latch and the merge block the exit block. The resulting spirv::LoopOp has a
141
- // single back edge from the continue to header block, and a single exit from
142
- // header to merge.
143
- scf::ForOpAdaptor forOperands (operands);
144
- auto loc = forOp.getLoc ();
145
- auto loopControl = rewriter.getI32IntegerAttr (
146
- static_cast <uint32_t >(spirv::LoopControl::None));
147
- auto loopOp = rewriter.create <spirv::LoopOp>(loc, loopControl);
148
- loopOp.addEntryAndMergeBlock ();
149
-
150
- OpBuilder::InsertionGuard guard (rewriter);
151
- // Create the block for the header.
152
- auto header = new Block ();
153
- // Insert the header.
154
- loopOp.body ().getBlocks ().insert (std::next (loopOp.body ().begin (), 1 ), header);
155
-
156
- // Create the new induction variable to use.
157
- BlockArgument newIndVar =
158
- header->addArgument (forOperands.lowerBound ().getType ());
159
- Block *body = forOp.getBody ();
160
-
161
- // Apply signature conversion to the body of the forOp. It has a single block,
162
- // with argument which is the induction variable. That has to be replaced with
163
- // the new induction variable.
164
- TypeConverter::SignatureConversion signatureConverter (
165
- body->getNumArguments ());
166
- signatureConverter.remapInput (0 , newIndVar);
167
- FailureOr<Block *> newBody = rewriter.convertRegionTypes (
168
- &forOp.getLoopBody (), typeConverter, &signatureConverter);
169
- if (failed (newBody))
170
- return failure ();
171
- body = *newBody;
172
-
173
- // Delete the loop terminator.
174
- rewriter.eraseOp (body->getTerminator ());
175
-
176
- // Move the blocks from the forOp into the loopOp. This is the body of the
177
- // loopOp.
178
- rewriter.inlineRegionBefore (forOp.getOperation ()->getRegion (0 ), loopOp.body (),
179
- std::next (loopOp.body ().begin (), 2 ));
180
-
181
- // Branch into it from the entry.
182
- rewriter.setInsertionPointToEnd (&(loopOp.body ().front ()));
183
- rewriter.create <spirv::BranchOp>(loc, header, forOperands.lowerBound ());
184
-
185
- // Generate the rest of the loop header.
186
- rewriter.setInsertionPointToEnd (header);
187
- auto mergeBlock = loopOp.getMergeBlock ();
188
- auto cmpOp = rewriter.create <spirv::SLessThanOp>(
189
- loc, rewriter.getI1Type (), newIndVar, forOperands.upperBound ());
190
- rewriter.create <spirv::BranchConditionalOp>(
191
- loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
192
-
193
- // Generate instructions to increment the step of the induction variable and
194
- // branch to the header.
195
- Block *continueBlock = loopOp.getContinueBlock ();
196
- rewriter.setInsertionPointToEnd (continueBlock);
197
-
198
- // Add the step to the induction variable and branch to the header.
199
- Value updatedIndVar = rewriter.create <spirv::IAddOp>(
200
- loc, newIndVar.getType (), newIndVar, forOperands.step ());
201
- rewriter.create <spirv::BranchOp>(loc, header, updatedIndVar);
202
-
203
- rewriter.eraseOp (forOp);
204
- return success ();
205
- }
206
-
207
- // ===----------------------------------------------------------------------===//
208
- // scf::IfOp.
209
- // ===----------------------------------------------------------------------===//
210
-
211
- LogicalResult
212
- IfOpConversion::matchAndRewrite (scf::IfOp ifOp, ArrayRef<Value> operands,
213
- ConversionPatternRewriter &rewriter) const {
214
- // When lowering `scf::IfOp` we explicitly create a selection header block
215
- // before the control flow diverges and a merge block where control flow
216
- // subsequently converges.
217
- scf::IfOpAdaptor ifOperands (operands);
218
- auto loc = ifOp.getLoc ();
219
-
220
- // Create `spv.selection` operation, selection header block and merge block.
221
- auto selectionControl = rewriter.getI32IntegerAttr (
222
- static_cast <uint32_t >(spirv::SelectionControl::None));
223
- auto selectionOp = rewriter.create <spirv::SelectionOp>(loc, selectionControl);
224
- selectionOp.addMergeBlock ();
225
- auto *mergeBlock = selectionOp.getMergeBlock ();
226
-
227
- OpBuilder::InsertionGuard guard (rewriter);
228
- auto *selectionHeaderBlock = new Block ();
229
- selectionOp.body ().getBlocks ().push_front (selectionHeaderBlock);
230
-
231
- // Inline `then` region before the merge block and branch to it.
232
- auto &thenRegion = ifOp.thenRegion ();
233
- auto *thenBlock = &thenRegion.front ();
234
- rewriter.setInsertionPointToEnd (&thenRegion.back ());
235
- rewriter.create <spirv::BranchOp>(loc, mergeBlock);
236
- rewriter.inlineRegionBefore (thenRegion, mergeBlock);
237
-
238
- auto *elseBlock = mergeBlock;
239
- // If `else` region is not empty, inline that region before the merge block
240
- // and branch to it.
241
- if (!ifOp.elseRegion ().empty ()) {
242
- auto &elseRegion = ifOp.elseRegion ();
243
- elseBlock = &elseRegion.front ();
244
- rewriter.setInsertionPointToEnd (&elseRegion.back ());
245
- rewriter.create <spirv::BranchOp>(loc, mergeBlock);
246
- rewriter.inlineRegionBefore (elseRegion, mergeBlock);
247
- }
248
-
249
- // Create a `spv.BranchConditional` operation for selection header block.
250
- rewriter.setInsertionPointToEnd (selectionHeaderBlock);
251
- rewriter.create <spirv::BranchConditionalOp>(loc, ifOperands.condition (),
252
- thenBlock, ArrayRef<Value>(),
253
- elseBlock, ArrayRef<Value>());
254
-
255
- rewriter.eraseOp (ifOp);
256
- return success ();
257
- }
258
-
259
95
// ===----------------------------------------------------------------------===//
260
96
// Builtins.
261
97
// ===----------------------------------------------------------------------===//
@@ -479,8 +315,7 @@ void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
479
315
OwningRewritePatternList &patterns) {
480
316
populateWithGenerated (context, &patterns);
481
317
patterns.insert <
482
- ForOpConversion, GPUFuncOpConversion, GPUModuleConversion,
483
- GPUReturnOpConversion, IfOpConversion,
318
+ GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
484
319
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
485
320
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
486
321
LaunchConfigConversion<gpu::ThreadIdOp,
@@ -491,5 +326,5 @@ void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
491
326
spirv::BuiltIn::NumSubgroups>,
492
327
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
493
328
spirv::BuiltIn::SubgroupSize>,
494
- TerminatorOpConversion, WorkGroupSizeConversion>(context, typeConverter);
329
+ WorkGroupSizeConversion>(context, typeConverter);
495
330
}
0 commit comments