18
18
19
19
using namespace mlir ;
20
20
21
+ namespace mlir {
22
+ struct ScfToSPIRVContextImpl {
23
+ // Map between the spirv region control flow operation (spv.loop or
24
+ // spv.selection) to the VariableOp created to store the region results. The
25
+ // order of the VariableOp matches the order of the results.
26
+ DenseMap<Operation *, SmallVector<spirv::VariableOp, 8 >> outputVars;
27
+ };
28
+ } // namespace mlir
29
+
30
+ // / We use ScfToSPIRVContext to store information about the lowering of the scf
31
+ // / region that need to be used later on. When we lower scf.for/scf.if we create
32
+ // / VariableOp to store the results. We need to keep track of the VariableOp
33
+ // / created as we need to insert stores into them when lowering Yield. Those
34
+ // / StoreOp cannot be created earlier as they may use a different type than
35
+ // / yield operands.
36
+ ScfToSPIRVContext::ScfToSPIRVContext () {
37
+ impl = std::make_unique<ScfToSPIRVContextImpl>();
38
+ }
39
+ ScfToSPIRVContext::~ScfToSPIRVContext () = default ;
40
+
21
41
namespace {
42
+ // / Common class for all vector to GPU patterns.
43
+ template <typename OpTy>
44
+ class SCFToSPIRVPattern : public SPIRVOpLowering <OpTy> {
45
+ public:
46
+ SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
47
+ ScfToSPIRVContextImpl *scfToSPIRVContext)
48
+ : SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter),
49
+ scfToSPIRVContext (scfToSPIRVContext) {}
50
+
51
+ protected:
52
+ ScfToSPIRVContextImpl *scfToSPIRVContext;
53
+ };
22
54
23
55
// / Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
24
- class ForOpConversion final : public SPIRVOpLowering <scf::ForOp> {
56
+ class ForOpConversion final : public SCFToSPIRVPattern <scf::ForOp> {
25
57
public:
26
- using SPIRVOpLowering <scf::ForOp>::SPIRVOpLowering ;
58
+ using SCFToSPIRVPattern <scf::ForOp>::SCFToSPIRVPattern ;
27
59
28
60
LogicalResult
29
61
matchAndRewrite (scf::ForOp forOp, ArrayRef<Value> operands,
@@ -32,29 +64,54 @@ class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
32
64
33
65
// / Pattern to convert a scf::IfOp within kernel functions into
34
66
// / spirv::SelectionOp.
35
- class IfOpConversion final : public SPIRVOpLowering <scf::IfOp> {
67
+ class IfOpConversion final : public SCFToSPIRVPattern <scf::IfOp> {
36
68
public:
37
- using SPIRVOpLowering <scf::IfOp>::SPIRVOpLowering ;
69
+ using SCFToSPIRVPattern <scf::IfOp>::SCFToSPIRVPattern ;
38
70
39
71
LogicalResult
40
72
matchAndRewrite (scf::IfOp ifOp, ArrayRef<Value> operands,
41
73
ConversionPatternRewriter &rewriter) const override ;
42
74
};
43
75
44
- // / Pattern to erase a scf::YieldOp.
45
- class TerminatorOpConversion final : public SPIRVOpLowering<scf::YieldOp> {
76
+ class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
46
77
public:
47
- using SPIRVOpLowering <scf::YieldOp>::SPIRVOpLowering ;
78
+ using SCFToSPIRVPattern <scf::YieldOp>::SCFToSPIRVPattern ;
48
79
49
80
LogicalResult
50
81
matchAndRewrite (scf::YieldOp terminatorOp, ArrayRef<Value> operands,
51
- ConversionPatternRewriter &rewriter) const override {
52
- rewriter.eraseOp (terminatorOp);
53
- return success ();
54
- }
82
+ ConversionPatternRewriter &rewriter) const override ;
55
83
};
56
84
} // namespace
57
85
86
+ // / Helper function to replaces SCF op outputs with SPIR-V variable loads.
87
+ // / We create VariableOp to handle the results value of the control flow region.
88
+ // / spv.loop/spv.selection currently don't yield value. Right after the loop
89
+ // / we load the value from the allocation and use it as the SCF op result.
90
+ template <typename ScfOp, typename OpTy>
91
+ static void replaceSCFOutputValue (ScfOp scfOp, OpTy newOp,
92
+ SPIRVTypeConverter &typeConverter,
93
+ ConversionPatternRewriter &rewriter,
94
+ ScfToSPIRVContextImpl *scfToSPIRVContext) {
95
+
96
+ Location loc = scfOp.getLoc ();
97
+ auto &allocas = scfToSPIRVContext->outputVars [newOp];
98
+ SmallVector<Value, 8 > resultValue;
99
+ for (Value result : scfOp.results ()) {
100
+ auto convertedType = typeConverter.convertType (result.getType ());
101
+ auto pointerType =
102
+ spirv::PointerType::get (convertedType, spirv::StorageClass::Function);
103
+ rewriter.setInsertionPoint (newOp);
104
+ auto alloc = rewriter.create <spirv::VariableOp>(
105
+ loc, pointerType, spirv::StorageClass::Function,
106
+ /* initializer=*/ nullptr );
107
+ allocas.push_back (alloc);
108
+ rewriter.setInsertionPointAfter (newOp);
109
+ Value loadResult = rewriter.create <spirv::LoadOp>(loc, alloc);
110
+ resultValue.push_back (loadResult);
111
+ }
112
+ rewriter.replaceOp (scfOp, resultValue);
113
+ }
114
+
58
115
// ===----------------------------------------------------------------------===//
59
116
// scf::ForOp.
60
117
// ===----------------------------------------------------------------------===//
@@ -83,6 +140,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
83
140
// Create the new induction variable to use.
84
141
BlockArgument newIndVar =
85
142
header->addArgument (forOperands.lowerBound ().getType ());
143
+ for (Value arg : forOperands.initArgs ())
144
+ header->addArgument (arg.getType ());
86
145
Block *body = forOp.getBody ();
87
146
88
147
// Apply signature conversion to the body of the forOp. It has a single block,
@@ -91,29 +150,28 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
91
150
TypeConverter::SignatureConversion signatureConverter (
92
151
body->getNumArguments ());
93
152
signatureConverter.remapInput (0 , newIndVar);
94
- FailureOr<Block *> newBody = rewriter.convertRegionTypes (
95
- &forOp.getLoopBody (), typeConverter, &signatureConverter);
96
- if (failed (newBody))
97
- return failure ();
98
- body = *newBody;
99
-
100
- // Delete the loop terminator.
101
- rewriter.eraseOp (body->getTerminator ());
153
+ for (unsigned i = 1 , e = body->getNumArguments (); i < e; i++)
154
+ signatureConverter.remapInput (i, header->getArgument (i));
155
+ body = rewriter.applySignatureConversion (&forOp.getLoopBody (),
156
+ signatureConverter);
102
157
103
158
// Move the blocks from the forOp into the loopOp. This is the body of the
104
159
// loopOp.
105
160
rewriter.inlineRegionBefore (forOp.getOperation ()->getRegion (0 ), loopOp.body (),
106
161
std::next (loopOp.body ().begin (), 2 ));
107
162
163
+ SmallVector<Value, 8 > args (1 , forOperands.lowerBound ());
164
+ args.append (forOperands.initArgs ().begin (), forOperands.initArgs ().end ());
108
165
// Branch into it from the entry.
109
166
rewriter.setInsertionPointToEnd (&(loopOp.body ().front ()));
110
- rewriter.create <spirv::BranchOp>(loc, header, forOperands. lowerBound () );
167
+ rewriter.create <spirv::BranchOp>(loc, header, args );
111
168
112
169
// Generate the rest of the loop header.
113
170
rewriter.setInsertionPointToEnd (header);
114
171
auto *mergeBlock = loopOp.getMergeBlock ();
115
172
auto cmpOp = rewriter.create <spirv::SLessThanOp>(
116
173
loc, rewriter.getI1Type (), newIndVar, forOperands.upperBound ());
174
+
117
175
rewriter.create <spirv::BranchConditionalOp>(
118
176
loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
119
177
@@ -127,7 +185,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
127
185
loc, newIndVar.getType (), newIndVar, forOperands.step ());
128
186
rewriter.create <spirv::BranchOp>(loc, header, updatedIndVar);
129
187
130
- rewriter.eraseOp (forOp);
188
+ replaceSCFOutputValue (forOp, loopOp, typeConverter, rewriter,
189
+ scfToSPIRVContext);
131
190
return success ();
132
191
}
133
192
@@ -179,13 +238,45 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
179
238
thenBlock, ArrayRef<Value>(),
180
239
elseBlock, ArrayRef<Value>());
181
240
182
- rewriter.eraseOp (ifOp);
241
+ replaceSCFOutputValue (ifOp, selectionOp, typeConverter, rewriter,
242
+ scfToSPIRVContext);
243
+ return success ();
244
+ }
245
+
246
+ // / Yield is lowered to stores to the VariableOp created during lowering of the
247
+ // / parent region. For loops we also need to update the branch looping back to
248
+ // / the header with the loop carried values.
249
+ LogicalResult TerminatorOpConversion::matchAndRewrite (
250
+ scf::YieldOp terminatorOp, ArrayRef<Value> operands,
251
+ ConversionPatternRewriter &rewriter) const {
252
+ // If the region is return values, store each value into the associated
253
+ // VariableOp created during lowering of the parent region.
254
+ if (!operands.empty ()) {
255
+ auto loc = terminatorOp.getLoc ();
256
+ auto &allocas = scfToSPIRVContext->outputVars [terminatorOp.getParentOp ()];
257
+ assert (allocas.size () == operands.size ());
258
+ for (unsigned i = 0 , e = operands.size (); i < e; i++)
259
+ rewriter.create <spirv::StoreOp>(loc, allocas[i], operands[i]);
260
+ if (isa<spirv::LoopOp>(terminatorOp.getParentOp ())) {
261
+ // For loops we also need to update the branch jumping back to the header.
262
+ auto br =
263
+ cast<spirv::BranchOp>(rewriter.getInsertionBlock ()->getTerminator ());
264
+ SmallVector<Value, 8 > args (br.getBlockArguments ());
265
+ args.append (operands.begin (), operands.end ());
266
+ rewriter.setInsertionPoint (br);
267
+ rewriter.create <spirv::BranchOp>(terminatorOp.getLoc (), br.getTarget (),
268
+ args);
269
+ rewriter.eraseOp (br);
270
+ }
271
+ }
272
+ rewriter.eraseOp (terminatorOp);
183
273
return success ();
184
274
}
185
275
186
276
void mlir::populateSCFToSPIRVPatterns (MLIRContext *context,
187
277
SPIRVTypeConverter &typeConverter,
278
+ ScfToSPIRVContext &scfToSPIRVContext,
188
279
OwningRewritePatternList &patterns) {
189
280
patterns.insert <ForOpConversion, IfOpConversion, TerminatorOpConversion>(
190
- context, typeConverter);
281
+ context, typeConverter, scfToSPIRVContext. getImpl () );
191
282
}
0 commit comments