Skip to content

Commit 0670f85

Browse files
committed
[mlir][spirv] Add support for lowering scf.for scf/if with return value
This allow lowering to support scf.for and scf.if with results. As right now spv region operations don't have return value the results are demoted to Function memory. We create one allocation per result right before the region and store the yield values in it. Then we can load back the value from allocation to be able to use the results. Differential Revision: https://reviews.llvm.org/D82246
1 parent fbce985 commit 0670f85

File tree

6 files changed

+245
-30
lines changed

6 files changed

+245
-30
lines changed

mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,23 @@ class Pass;
2121
// Owning list of rewriting patterns.
2222
class OwningRewritePatternList;
2323
class SPIRVTypeConverter;
24+
struct ScfToSPIRVContextImpl;
25+
26+
struct ScfToSPIRVContext {
27+
ScfToSPIRVContext();
28+
~ScfToSPIRVContext();
29+
30+
ScfToSPIRVContextImpl *getImpl() { return impl.get(); }
31+
32+
private:
33+
std::unique_ptr<ScfToSPIRVContextImpl> impl;
34+
};
2435

2536
/// Collects a set of patterns to lower from scf.for, scf.if, and
2637
/// loop.terminator to CFG operations within the SPIR-V dialect.
2738
void populateSCFToSPIRVPatterns(MLIRContext *context,
2839
SPIRVTypeConverter &typeConverter,
40+
ScfToSPIRVContext &scfToSPIRVContext,
2941
OwningRewritePatternList &patterns);
3042
} // namespace mlir
3143

mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,10 @@ void GPUToSPIRVPass::runOnOperation() {
5858
spirv::SPIRVConversionTarget::get(targetAttr);
5959

6060
SPIRVTypeConverter typeConverter(targetAttr);
61+
ScfToSPIRVContext scfContext;
6162
OwningRewritePatternList patterns;
6263
populateGPUToSPIRVPatterns(context, typeConverter, patterns);
63-
populateSCFToSPIRVPatterns(context, typeConverter, patterns);
64+
populateSCFToSPIRVPatterns(context, typeConverter,scfContext, patterns);
6465
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
6566

6667
if (failed(applyFullConversion(kernelModules, *target, patterns)))

mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp

Lines changed: 114 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,44 @@
1818

1919
using namespace mlir;
2020

21+
namespace mlir {
22+
struct ScfToSPIRVContextImpl {
23+
// Map between the spirv region control flow operation (spv.loop or
24+
// spv.selection) to the VariableOp created to store the region results. The
25+
// order of the VariableOp matches the order of the results.
26+
DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
27+
};
28+
} // namespace mlir
29+
30+
/// We use ScfToSPIRVContext to store information about the lowering of the scf
31+
/// region that need to be used later on. When we lower scf.for/scf.if we create
32+
/// VariableOp to store the results. We need to keep track of the VariableOp
33+
/// created as we need to insert stores into them when lowering Yield. Those
34+
/// StoreOp cannot be created earlier as they may use a different type than
35+
/// yield operands.
36+
ScfToSPIRVContext::ScfToSPIRVContext() {
37+
impl = std::make_unique<ScfToSPIRVContextImpl>();
38+
}
39+
ScfToSPIRVContext::~ScfToSPIRVContext() = default;
40+
2141
namespace {
42+
/// Common class for all vector to GPU patterns.
43+
template <typename OpTy>
44+
class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> {
45+
public:
46+
SCFToSPIRVPattern<OpTy>(MLIRContext *context, SPIRVTypeConverter &converter,
47+
ScfToSPIRVContextImpl *scfToSPIRVContext)
48+
: SPIRVOpLowering<OpTy>::SPIRVOpLowering(context, converter),
49+
scfToSPIRVContext(scfToSPIRVContext) {}
50+
51+
protected:
52+
ScfToSPIRVContextImpl *scfToSPIRVContext;
53+
};
2254

2355
/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
24-
class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
56+
class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
2557
public:
26-
using SPIRVOpLowering<scf::ForOp>::SPIRVOpLowering;
58+
using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
2759

2860
LogicalResult
2961
matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
@@ -32,29 +64,54 @@ class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
3264

3365
/// Pattern to convert a scf::IfOp within kernel functions into
3466
/// spirv::SelectionOp.
35-
class IfOpConversion final : public SPIRVOpLowering<scf::IfOp> {
67+
class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
3668
public:
37-
using SPIRVOpLowering<scf::IfOp>::SPIRVOpLowering;
69+
using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
3870

3971
LogicalResult
4072
matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
4173
ConversionPatternRewriter &rewriter) const override;
4274
};
4375

44-
/// Pattern to erase a scf::YieldOp.
45-
class TerminatorOpConversion final : public SPIRVOpLowering<scf::YieldOp> {
76+
class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
4677
public:
47-
using SPIRVOpLowering<scf::YieldOp>::SPIRVOpLowering;
78+
using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
4879

4980
LogicalResult
5081
matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
51-
ConversionPatternRewriter &rewriter) const override {
52-
rewriter.eraseOp(terminatorOp);
53-
return success();
54-
}
82+
ConversionPatternRewriter &rewriter) const override;
5583
};
5684
} // namespace
5785

86+
/// Helper function to replaces SCF op outputs with SPIR-V variable loads.
87+
/// We create VariableOp to handle the results value of the control flow region.
88+
/// spv.loop/spv.selection currently don't yield value. Right after the loop
89+
/// we load the value from the allocation and use it as the SCF op result.
90+
template <typename ScfOp, typename OpTy>
91+
static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
92+
SPIRVTypeConverter &typeConverter,
93+
ConversionPatternRewriter &rewriter,
94+
ScfToSPIRVContextImpl *scfToSPIRVContext) {
95+
96+
Location loc = scfOp.getLoc();
97+
auto &allocas = scfToSPIRVContext->outputVars[newOp];
98+
SmallVector<Value, 8> resultValue;
99+
for (Value result : scfOp.results()) {
100+
auto convertedType = typeConverter.convertType(result.getType());
101+
auto pointerType =
102+
spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
103+
rewriter.setInsertionPoint(newOp);
104+
auto alloc = rewriter.create<spirv::VariableOp>(
105+
loc, pointerType, spirv::StorageClass::Function,
106+
/*initializer=*/nullptr);
107+
allocas.push_back(alloc);
108+
rewriter.setInsertionPointAfter(newOp);
109+
Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
110+
resultValue.push_back(loadResult);
111+
}
112+
rewriter.replaceOp(scfOp, resultValue);
113+
}
114+
58115
//===----------------------------------------------------------------------===//
59116
// scf::ForOp.
60117
//===----------------------------------------------------------------------===//
@@ -83,6 +140,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
83140
// Create the new induction variable to use.
84141
BlockArgument newIndVar =
85142
header->addArgument(forOperands.lowerBound().getType());
143+
for (Value arg : forOperands.initArgs())
144+
header->addArgument(arg.getType());
86145
Block *body = forOp.getBody();
87146

88147
// Apply signature conversion to the body of the forOp. It has a single block,
@@ -91,29 +150,28 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
91150
TypeConverter::SignatureConversion signatureConverter(
92151
body->getNumArguments());
93152
signatureConverter.remapInput(0, newIndVar);
94-
FailureOr<Block *> newBody = rewriter.convertRegionTypes(
95-
&forOp.getLoopBody(), typeConverter, &signatureConverter);
96-
if (failed(newBody))
97-
return failure();
98-
body = *newBody;
99-
100-
// Delete the loop terminator.
101-
rewriter.eraseOp(body->getTerminator());
153+
for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
154+
signatureConverter.remapInput(i, header->getArgument(i));
155+
body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
156+
signatureConverter);
102157

103158
// Move the blocks from the forOp into the loopOp. This is the body of the
104159
// loopOp.
105160
rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
106161
std::next(loopOp.body().begin(), 2));
107162

163+
SmallVector<Value, 8> args(1, forOperands.lowerBound());
164+
args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
108165
// Branch into it from the entry.
109166
rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
110-
rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
167+
rewriter.create<spirv::BranchOp>(loc, header, args);
111168

112169
// Generate the rest of the loop header.
113170
rewriter.setInsertionPointToEnd(header);
114171
auto *mergeBlock = loopOp.getMergeBlock();
115172
auto cmpOp = rewriter.create<spirv::SLessThanOp>(
116173
loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
174+
117175
rewriter.create<spirv::BranchConditionalOp>(
118176
loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
119177

@@ -127,7 +185,8 @@ ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
127185
loc, newIndVar.getType(), newIndVar, forOperands.step());
128186
rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
129187

130-
rewriter.eraseOp(forOp);
188+
replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
189+
scfToSPIRVContext);
131190
return success();
132191
}
133192

@@ -179,13 +238,45 @@ IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
179238
thenBlock, ArrayRef<Value>(),
180239
elseBlock, ArrayRef<Value>());
181240

182-
rewriter.eraseOp(ifOp);
241+
replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
242+
scfToSPIRVContext);
243+
return success();
244+
}
245+
246+
/// Yield is lowered to stores to the VariableOp created during lowering of the
247+
/// parent region. For loops we also need to update the branch looping back to
248+
/// the header with the loop carried values.
249+
LogicalResult TerminatorOpConversion::matchAndRewrite(
250+
scf::YieldOp terminatorOp, ArrayRef<Value> operands,
251+
ConversionPatternRewriter &rewriter) const {
252+
// If the region is return values, store each value into the associated
253+
// VariableOp created during lowering of the parent region.
254+
if (!operands.empty()) {
255+
auto loc = terminatorOp.getLoc();
256+
auto &allocas = scfToSPIRVContext->outputVars[terminatorOp.getParentOp()];
257+
assert(allocas.size() == operands.size());
258+
for (unsigned i = 0, e = operands.size(); i < e; i++)
259+
rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
260+
if (isa<spirv::LoopOp>(terminatorOp.getParentOp())) {
261+
// For loops we also need to update the branch jumping back to the header.
262+
auto br =
263+
cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
264+
SmallVector<Value, 8> args(br.getBlockArguments());
265+
args.append(operands.begin(), operands.end());
266+
rewriter.setInsertionPoint(br);
267+
rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
268+
args);
269+
rewriter.eraseOp(br);
270+
}
271+
}
272+
rewriter.eraseOp(terminatorOp);
183273
return success();
184274
}
185275

186276
void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
187277
SPIRVTypeConverter &typeConverter,
278+
ScfToSPIRVContext &scfToSPIRVContext,
188279
OwningRewritePatternList &patterns) {
189280
patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
190-
context, typeConverter);
281+
context, typeConverter, scfToSPIRVContext.getImpl());
191282
}

mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -589,9 +589,6 @@ StorageClass PointerType::getStorageClass() const {
589589

590590
void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
591591
Optional<StorageClass> storage) {
592-
if (storage)
593-
assert(*storage == getStorageClass() && "inconsistent storage class!");
594-
595592
// Use this pointer type's storage class because this pointer indicates we are
596593
// using the pointee type in that specific storage class.
597594
getPointeeType().cast<SPIRVType>().getExtensions(extensions,
@@ -604,9 +601,6 @@ void PointerType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
604601
void PointerType::getCapabilities(
605602
SPIRVType::CapabilityArrayRefVector &capabilities,
606603
Optional<StorageClass> storage) {
607-
if (storage)
608-
assert(*storage == getStorageClass() && "inconsistent storage class!");
609-
610604
// Use this pointer type's storage class because this pointer indicates we are
611605
// using the pointee type in that specific storage class.
612606
getPointeeType().cast<SPIRVType>().getCapabilities(capabilities,

mlir/test/Conversion/GPUToSPIRV/if.mlir

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,5 +89,79 @@ module attributes {
8989
}
9090
gpu.return
9191
}
92+
// CHECK-LABEL: @simple_if_yield
93+
gpu.func @simple_if_yield(%arg2 : memref<10xf32>, %arg3 : i1) kernel
94+
attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
95+
// CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<f32, Function>
96+
// CHECK: %[[VAR2:.*]] = spv.Variable : !spv.ptr<f32, Function>
97+
// CHECK: spv.selection {
98+
// CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
99+
// CHECK-NEXT: [[TRUE]]:
100+
// CHECK: %[[RET1TRUE:.*]] = spv.constant 0.000000e+00 : f32
101+
// CHECK: %[[RET2TRUE:.*]] = spv.constant 1.000000e+00 : f32
102+
// CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[RET1TRUE]] : f32
103+
// CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[RET2TRUE]] : f32
104+
// CHECK: spv.Branch ^[[MERGE:.*]]
105+
// CHECK-NEXT: [[FALSE]]:
106+
// CHECK: %[[RET2FALSE:.*]] = spv.constant 2.000000e+00 : f32
107+
// CHECK: %[[RET1FALSE:.*]] = spv.constant 3.000000e+00 : f32
108+
// CHECK-DAG: spv.Store "Function" %[[VAR1]], %[[RET1FALSE]] : f32
109+
// CHECK-DAG: spv.Store "Function" %[[VAR2]], %[[RET2FALSE]] : f32
110+
// CHECK: spv.Branch ^[[MERGE]]
111+
// CHECK-NEXT: ^[[MERGE]]:
112+
// CHECK: spv._merge
113+
// CHECK-NEXT: }
114+
// CHECK-DAG: %[[OUT1:.*]] = spv.Load "Function" %[[VAR1]] : f32
115+
// CHECK-DAG: %[[OUT2:.*]] = spv.Load "Function" %[[VAR2]] : f32
116+
// CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT1]] : f32
117+
// CHECK: spv.Store "StorageBuffer" {{%.*}}, %[[OUT2]] : f32
118+
// CHECK: spv.Return
119+
%0:2 = scf.if %arg3 -> (f32, f32) {
120+
%c0 = constant 0.0 : f32
121+
%c1 = constant 1.0 : f32
122+
scf.yield %c0, %c1 : f32, f32
123+
} else {
124+
%c0 = constant 2.0 : f32
125+
%c1 = constant 3.0 : f32
126+
scf.yield %c1, %c0 : f32, f32
127+
}
128+
%i = constant 0 : index
129+
%j = constant 1 : index
130+
store %0#0, %arg2[%i] : memref<10xf32>
131+
store %0#1, %arg2[%j] : memref<10xf32>
132+
gpu.return
133+
}
134+
// TODO(thomasraoux): The transformation should only be legal if
135+
// VariablePointer capability is supported. This test is still useful to
136+
// make sure we can handle scf op result with type change.
137+
// CHECK-LABEL: @simple_if_yield_type_change
138+
// CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr<!spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>, Function>
139+
// CHECK: spv.selection {
140+
// CHECK-NEXT: spv.BranchConditional {{%.*}}, [[TRUE:\^.*]], [[FALSE:\^.*]]
141+
// CHECK-NEXT: [[TRUE]]:
142+
// CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
143+
// CHECK: spv.Branch ^[[MERGE:.*]]
144+
// CHECK-NEXT: [[FALSE]]:
145+
// CHECK: spv.Store "Function" %[[VAR]], {{%.*}} : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
146+
// CHECK: spv.Branch ^[[MERGE]]
147+
// CHECK-NEXT: ^[[MERGE]]:
148+
// CHECK: spv._merge
149+
// CHECK-NEXT: }
150+
// CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
151+
// CHECK: %[[ADD:.*]] = spv.AccessChain %[[OUT]][{{%.*}}, {{%.*}}] : !spv.ptr<!spv.struct<!spv.array<10 x f32, stride=4> [0]>, StorageBuffer>
152+
// CHECK: spv.Store "StorageBuffer" %[[ADD]], {{%.*}} : f32
153+
// CHECK: spv.Return
154+
gpu.func @simple_if_yield_type_change(%arg2 : memref<10xf32>, %arg3 : memref<10xf32>, %arg4 : i1) kernel
155+
attributes {spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}} {
156+
%i = constant 0 : index
157+
%value = constant 0.0 : f32
158+
%0 = scf.if %arg4 -> (memref<10xf32>) {
159+
scf.yield %arg2 : memref<10xf32>
160+
} else {
161+
scf.yield %arg3 : memref<10xf32>
162+
}
163+
store %value, %0[%i] : memref<10xf32>
164+
gpu.return
165+
}
92166
}
93167
}

0 commit comments

Comments
 (0)