Skip to content

Commit fbce985

Browse files
committed
[mlir][NFC] Move conversion of scf to spir-v ops in their own file
Move patterns for scf to spir-v ops in their own file/folder. Differential Revision: https://reviews.llvm.org/D82914
1 parent dd90408 commit fbce985

File tree

7 files changed

+249
-167
lines changed

7 files changed

+249
-167
lines changed
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
//===------------ SCFToSPIRV.h - Pass entrypoint ----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Provides patterns for lowering SCF ops to SPIR-V dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
#ifndef MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_
13+
#define MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_
14+
15+
#include <memory>
16+
17+
namespace mlir {
18+
class MLIRContext;
19+
class Pass;
20+
21+
// Owning list of rewriting patterns.
22+
class OwningRewritePatternList;
23+
class SPIRVTypeConverter;
24+
25+
/// Collects a set of patterns to lower from scf.for, scf.if, and
26+
/// loop.terminator to CFG operations within the SPIR-V dialect.
27+
void populateSCFToSPIRVPatterns(MLIRContext *context,
28+
SPIRVTypeConverter &typeConverter,
29+
OwningRewritePatternList &patterns);
30+
} // namespace mlir
31+
32+
#endif // MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_subdirectory(LinalgToLLVM)
99
add_subdirectory(LinalgToSPIRV)
1010
add_subdirectory(LinalgToStandard)
1111
add_subdirectory(SCFToGPU)
12+
add_subdirectory(SCFToSPIRV)
1213
add_subdirectory(SCFToStandard)
1314
add_subdirectory(ShapeToSCF)
1415
add_subdirectory(ShapeToStandard)

mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_mlir_conversion_library(MLIRGPUToSPIRVTransforms
1414
MLIRGPU
1515
MLIRIR
1616
MLIRPass
17+
MLIRSCFToSPIRV
1718
MLIRSPIRV
1819
MLIRStandardOps
1920
MLIRStandardToSPIRVTransforms

mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp

Lines changed: 2 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
//===----------------------------------------------------------------------===//
1212
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
1313
#include "mlir/Dialect/GPU/GPUDialect.h"
14-
#include "mlir/Dialect/SCF/SCF.h"
1514
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
1615
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
1716
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
@@ -20,41 +19,6 @@
2019
using namespace mlir;
2120

2221
namespace {
23-
24-
/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
25-
class ForOpConversion final : public SPIRVOpLowering<scf::ForOp> {
26-
public:
27-
using SPIRVOpLowering<scf::ForOp>::SPIRVOpLowering;
28-
29-
LogicalResult
30-
matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
31-
ConversionPatternRewriter &rewriter) const override;
32-
};
33-
34-
/// Pattern to convert a scf::IfOp within kernel functions into
35-
/// spirv::SelectionOp.
36-
class IfOpConversion final : public SPIRVOpLowering<scf::IfOp> {
37-
public:
38-
using SPIRVOpLowering<scf::IfOp>::SPIRVOpLowering;
39-
40-
LogicalResult
41-
matchAndRewrite(scf::IfOp IfOp, ArrayRef<Value> operands,
42-
ConversionPatternRewriter &rewriter) const override;
43-
};
44-
45-
/// Pattern to erase a scf::YieldOp.
46-
class TerminatorOpConversion final : public SPIRVOpLowering<scf::YieldOp> {
47-
public:
48-
using SPIRVOpLowering<scf::YieldOp>::SPIRVOpLowering;
49-
50-
LogicalResult
51-
matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
52-
ConversionPatternRewriter &rewriter) const override {
53-
rewriter.eraseOp(terminatorOp);
54-
return success();
55-
}
56-
};
57-
5822
/// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation
5923
/// builtin variables.
6024
template <typename SourceOp, spirv::BuiltIn builtin>
@@ -128,134 +92,6 @@ class GPUReturnOpConversion final : public SPIRVOpLowering<gpu::ReturnOp> {
12892

12993
} // namespace
13094

131-
//===----------------------------------------------------------------------===//
132-
// scf::ForOp.
133-
//===----------------------------------------------------------------------===//
134-
135-
LogicalResult
136-
ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
137-
ConversionPatternRewriter &rewriter) const {
138-
// scf::ForOp can be lowered to the structured control flow represented by
139-
// spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
140-
// latch and the merge block the exit block. The resulting spirv::LoopOp has a
141-
// single back edge from the continue to header block, and a single exit from
142-
// header to merge.
143-
scf::ForOpAdaptor forOperands(operands);
144-
auto loc = forOp.getLoc();
145-
auto loopControl = rewriter.getI32IntegerAttr(
146-
static_cast<uint32_t>(spirv::LoopControl::None));
147-
auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
148-
loopOp.addEntryAndMergeBlock();
149-
150-
OpBuilder::InsertionGuard guard(rewriter);
151-
// Create the block for the header.
152-
auto header = new Block();
153-
// Insert the header.
154-
loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
155-
156-
// Create the new induction variable to use.
157-
BlockArgument newIndVar =
158-
header->addArgument(forOperands.lowerBound().getType());
159-
Block *body = forOp.getBody();
160-
161-
// Apply signature conversion to the body of the forOp. It has a single block,
162-
// with argument which is the induction variable. That has to be replaced with
163-
// the new induction variable.
164-
TypeConverter::SignatureConversion signatureConverter(
165-
body->getNumArguments());
166-
signatureConverter.remapInput(0, newIndVar);
167-
FailureOr<Block *> newBody = rewriter.convertRegionTypes(
168-
&forOp.getLoopBody(), typeConverter, &signatureConverter);
169-
if (failed(newBody))
170-
return failure();
171-
body = *newBody;
172-
173-
// Delete the loop terminator.
174-
rewriter.eraseOp(body->getTerminator());
175-
176-
// Move the blocks from the forOp into the loopOp. This is the body of the
177-
// loopOp.
178-
rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
179-
std::next(loopOp.body().begin(), 2));
180-
181-
// Branch into it from the entry.
182-
rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
183-
rewriter.create<spirv::BranchOp>(loc, header, forOperands.lowerBound());
184-
185-
// Generate the rest of the loop header.
186-
rewriter.setInsertionPointToEnd(header);
187-
auto mergeBlock = loopOp.getMergeBlock();
188-
auto cmpOp = rewriter.create<spirv::SLessThanOp>(
189-
loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
190-
rewriter.create<spirv::BranchConditionalOp>(
191-
loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
192-
193-
// Generate instructions to increment the step of the induction variable and
194-
// branch to the header.
195-
Block *continueBlock = loopOp.getContinueBlock();
196-
rewriter.setInsertionPointToEnd(continueBlock);
197-
198-
// Add the step to the induction variable and branch to the header.
199-
Value updatedIndVar = rewriter.create<spirv::IAddOp>(
200-
loc, newIndVar.getType(), newIndVar, forOperands.step());
201-
rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
202-
203-
rewriter.eraseOp(forOp);
204-
return success();
205-
}
206-
207-
//===----------------------------------------------------------------------===//
208-
// scf::IfOp.
209-
//===----------------------------------------------------------------------===//
210-
211-
LogicalResult
212-
IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
213-
ConversionPatternRewriter &rewriter) const {
214-
// When lowering `scf::IfOp` we explicitly create a selection header block
215-
// before the control flow diverges and a merge block where control flow
216-
// subsequently converges.
217-
scf::IfOpAdaptor ifOperands(operands);
218-
auto loc = ifOp.getLoc();
219-
220-
// Create `spv.selection` operation, selection header block and merge block.
221-
auto selectionControl = rewriter.getI32IntegerAttr(
222-
static_cast<uint32_t>(spirv::SelectionControl::None));
223-
auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
224-
selectionOp.addMergeBlock();
225-
auto *mergeBlock = selectionOp.getMergeBlock();
226-
227-
OpBuilder::InsertionGuard guard(rewriter);
228-
auto *selectionHeaderBlock = new Block();
229-
selectionOp.body().getBlocks().push_front(selectionHeaderBlock);
230-
231-
// Inline `then` region before the merge block and branch to it.
232-
auto &thenRegion = ifOp.thenRegion();
233-
auto *thenBlock = &thenRegion.front();
234-
rewriter.setInsertionPointToEnd(&thenRegion.back());
235-
rewriter.create<spirv::BranchOp>(loc, mergeBlock);
236-
rewriter.inlineRegionBefore(thenRegion, mergeBlock);
237-
238-
auto *elseBlock = mergeBlock;
239-
// If `else` region is not empty, inline that region before the merge block
240-
// and branch to it.
241-
if (!ifOp.elseRegion().empty()) {
242-
auto &elseRegion = ifOp.elseRegion();
243-
elseBlock = &elseRegion.front();
244-
rewriter.setInsertionPointToEnd(&elseRegion.back());
245-
rewriter.create<spirv::BranchOp>(loc, mergeBlock);
246-
rewriter.inlineRegionBefore(elseRegion, mergeBlock);
247-
}
248-
249-
// Create a `spv.BranchConditional` operation for selection header block.
250-
rewriter.setInsertionPointToEnd(selectionHeaderBlock);
251-
rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
252-
thenBlock, ArrayRef<Value>(),
253-
elseBlock, ArrayRef<Value>());
254-
255-
rewriter.eraseOp(ifOp);
256-
return success();
257-
}
258-
25995
//===----------------------------------------------------------------------===//
26096
// Builtins.
26197
//===----------------------------------------------------------------------===//
@@ -479,8 +315,7 @@ void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
479315
OwningRewritePatternList &patterns) {
480316
populateWithGenerated(context, &patterns);
481317
patterns.insert<
482-
ForOpConversion, GPUFuncOpConversion, GPUModuleConversion,
483-
GPUReturnOpConversion, IfOpConversion,
318+
GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
484319
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
485320
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
486321
LaunchConfigConversion<gpu::ThreadIdOp,
@@ -491,5 +326,5 @@ void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
491326
spirv::BuiltIn::NumSubgroups>,
492327
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
493328
spirv::BuiltIn::SubgroupSize>,
494-
TerminatorOpConversion, WorkGroupSizeConversion>(context, typeConverter);
329+
WorkGroupSizeConversion>(context, typeConverter);
495330
}

mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
1515
#include "../PassDetail.h"
1616
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h"
17+
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
1718
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
1819
#include "mlir/Dialect/GPU/GPUDialect.h"
1920
#include "mlir/Dialect/SCF/SCF.h"
@@ -59,6 +60,7 @@ void GPUToSPIRVPass::runOnOperation() {
5960
SPIRVTypeConverter typeConverter(targetAttr);
6061
OwningRewritePatternList patterns;
6162
populateGPUToSPIRVPatterns(context, typeConverter, patterns);
63+
populateSCFToSPIRVPatterns(context, typeConverter, patterns);
6264
populateStandardToSPIRVPatterns(context, typeConverter, patterns);
6365

6466
if (failed(applyFullConversion(kernelModules, *target, patterns)))
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
add_mlir_conversion_library(MLIRSCFToSPIRV
2+
SCFToSPIRV.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToSPIRV
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRAffineOps
12+
MLIRAffineToStandard
13+
MLIRSPIRV
14+
MLIRIR
15+
MLIRLinalgOps
16+
MLIRPass
17+
MLIRStandardOps
18+
MLIRSupport
19+
MLIRTransforms
20+
)

0 commit comments

Comments
 (0)