Skip to content

Commit e682f15

Browse files
committed
More refactoring
1 parent 995952b commit e682f15

File tree

2 files changed

+23
-29
lines changed

2 files changed

+23
-29
lines changed

mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ using namespace mlir;
3737
namespace {
3838

3939
/// A pass to perform the SPIR-V conversion.
40-
struct ConvertToSPIRVPass
41-
: public impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
40+
struct ConvertToSPIRVPass final
41+
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
4242
using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
4343

4444
void runOnOperation() override {
@@ -47,16 +47,13 @@ struct ConvertToSPIRVPass
4747

4848
if (runSignatureConversion) {
4949
// Unroll vectors in function signatures to native vector size.
50-
{
51-
RewritePatternSet patterns(context);
52-
populateFuncOpVectorRewritePatterns(patterns);
53-
populateReturnOpVectorRewritePatterns(patterns);
54-
GreedyRewriteConfig config;
55-
config.strictMode = GreedyRewriteStrictness::ExistingOps;
56-
if (failed(
57-
applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
58-
return signalPassFailure();
59-
}
50+
RewritePatternSet patterns(context);
51+
populateFuncOpVectorRewritePatterns(patterns);
52+
populateReturnOpVectorRewritePatterns(patterns);
53+
GreedyRewriteConfig config;
54+
config.strictMode = GreedyRewriteStrictness::ExistingOps;
55+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
56+
return signalPassFailure();
6057
return;
6158
}
6259

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
#include "llvm/Support/Debug.h"
3434
#include "llvm/Support/MathExtras.h"
3535

36-
#include <cctype>
36+
#include <functional>
3737
#include <optional>
3838

3939
#define DEBUG_TYPE "mlir-spirv-conversion"
@@ -867,8 +867,7 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
867867
namespace {
868868
/// A pattern for rewriting function signature to convert vector arguments of
869869
/// functions to be of valid types
870-
class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
871-
public:
870+
struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
872871
using OpRewritePattern::OpRewritePattern;
873872

874873
LogicalResult matchAndRewrite(func::FuncOp funcOp,
@@ -922,8 +921,8 @@ class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
922921
rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
923922
tmpOps.insert({result.getDefiningOp(), newInputNo});
924923
oneToNTypeMapping.addInputs(origInputNo, origType);
925-
newInputNo++;
926-
newOpCount++;
924+
++newInputNo;
925+
++newOpCount;
927926
continue;
928927
}
929928
// Check whether the vector needs unrolling.
@@ -935,8 +934,8 @@ class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
935934
rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
936935
tmpOps.insert({result.getDefiningOp(), newInputNo});
937936
oneToNTypeMapping.addInputs(origInputNo, origType);
938-
newInputNo++;
939-
newOpCount++;
937+
++newInputNo;
938+
++newOpCount;
940939
continue;
941940
}
942941
VectorType unrolledType =
@@ -947,11 +946,11 @@ class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
947946
// Prepare the result vector.
948947
Value result = rewriter.create<arith::ConstantOp>(
949948
loc, origVecType, rewriter.getZeroAttr(origVecType));
950-
newOpCount++;
949+
++newOpCount;
951950
// Prepare the placeholder for the new arguments that will be added later.
952951
Value dummy = rewriter.create<arith::ConstantOp>(
953952
loc, unrolledType, rewriter.getZeroAttr(unrolledType));
954-
newOpCount++;
953+
++newOpCount;
955954

956955
// Create the `vector.insert_strided_slice` ops.
957956
SmallVector<int64_t> strides(targetShape->size(), 1);
@@ -962,8 +961,8 @@ class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
962961
loc, dummy, result, offsets, strides);
963962
newTypes.push_back(unrolledType);
964963
unrolledInputNums.push_back(newInputNo);
965-
newInputNo++;
966-
newOpCount++;
964+
++newInputNo;
965+
++newOpCount;
967966
}
968967
rewriter.replaceAllUsesWith(newFuncOp.getArgument(origInputNo), result);
969968
oneToNTypeMapping.addInputs(origInputNo, newTypes);
@@ -999,15 +998,14 @@ class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
999998
// not be touched.
1000999
if (count >= newOpCount)
10011000
continue;
1002-
auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
1003-
if (vecOp) {
1001+
if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
10041002
size_t unrolledInputNo = unrolledInputNums[idx];
10051003
rewriter.modifyOpInPlace(&op, [&] {
10061004
op.setOperand(0, newFuncOp.getArgument(unrolledInputNo));
10071005
});
1008-
idx++;
1006+
++idx;
10091007
}
1010-
count++;
1008+
++count;
10111009
}
10121010

10131011
// Erase the original funcOp. The `tmpOps` do not need to be erased since
@@ -1029,8 +1027,7 @@ void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
10291027
namespace {
10301028
/// A pattern for rewriting function signature and the return op to convert
10311029
/// vectors to be of valid types.
1032-
class ReturnOpVectorUnroll : public OpRewritePattern<func::ReturnOp> {
1033-
public:
1030+
struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
10341031
using OpRewritePattern::OpRewritePattern;
10351032

10361033
LogicalResult matchAndRewrite(func::ReturnOp returnOp,

0 commit comments

Comments
 (0)