33
33
#include " llvm/Support/Debug.h"
34
34
#include " llvm/Support/MathExtras.h"
35
35
36
- #include < cctype >
36
+ #include < functional >
37
37
#include < optional>
38
38
39
39
#define DEBUG_TYPE " mlir-spirv-conversion"
@@ -867,8 +867,7 @@ void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
867
867
namespace {
868
868
// / A pattern for rewriting function signature to convert vector arguments of
869
869
// / functions to be of valid types
870
- class FuncOpVectorUnroll : public OpRewritePattern <func::FuncOp> {
871
- public:
870
+ struct FuncOpVectorUnroll final : OpRewritePattern<func::FuncOp> {
872
871
using OpRewritePattern::OpRewritePattern;
873
872
874
873
LogicalResult matchAndRewrite (func::FuncOp funcOp,
@@ -922,8 +921,8 @@ class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
922
921
rewriter.replaceAllUsesWith (newFuncOp.getArgument (origInputNo), result);
923
922
tmpOps.insert ({result.getDefiningOp (), newInputNo});
924
923
oneToNTypeMapping.addInputs (origInputNo, origType);
925
- newInputNo++ ;
926
- newOpCount++ ;
924
+ ++newInputNo ;
925
+ ++newOpCount ;
927
926
continue ;
928
927
}
929
928
// Check whether the vector needs unrolling.
@@ -935,8 +934,8 @@ class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
935
934
rewriter.replaceAllUsesWith (newFuncOp.getArgument (origInputNo), result);
936
935
tmpOps.insert ({result.getDefiningOp (), newInputNo});
937
936
oneToNTypeMapping.addInputs (origInputNo, origType);
938
- newInputNo++ ;
939
- newOpCount++ ;
937
+ ++newInputNo ;
938
+ ++newOpCount ;
940
939
continue ;
941
940
}
942
941
VectorType unrolledType =
@@ -947,11 +946,11 @@ class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
947
946
// Prepare the result vector.
948
947
Value result = rewriter.create <arith::ConstantOp>(
949
948
loc, origVecType, rewriter.getZeroAttr (origVecType));
950
- newOpCount++ ;
949
+ ++newOpCount ;
951
950
// Prepare the placeholder for the new arguments that will be added later.
952
951
Value dummy = rewriter.create <arith::ConstantOp>(
953
952
loc, unrolledType, rewriter.getZeroAttr (unrolledType));
954
- newOpCount++ ;
953
+ ++newOpCount ;
955
954
956
955
// Create the `vector.insert_strided_slice` ops.
957
956
SmallVector<int64_t > strides (targetShape->size (), 1 );
@@ -962,8 +961,8 @@ class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
962
961
loc, dummy, result, offsets, strides);
963
962
newTypes.push_back (unrolledType);
964
963
unrolledInputNums.push_back (newInputNo);
965
- newInputNo++ ;
966
- newOpCount++ ;
964
+ ++newInputNo ;
965
+ ++newOpCount ;
967
966
}
968
967
rewriter.replaceAllUsesWith (newFuncOp.getArgument (origInputNo), result);
969
968
oneToNTypeMapping.addInputs (origInputNo, newTypes);
@@ -999,15 +998,14 @@ class FuncOpVectorUnroll : public OpRewritePattern<func::FuncOp> {
999
998
// not be touched.
1000
999
if (count >= newOpCount)
1001
1000
continue ;
1002
- auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op);
1003
- if (vecOp) {
1001
+ if (auto vecOp = dyn_cast<vector::InsertStridedSliceOp>(op)) {
1004
1002
size_t unrolledInputNo = unrolledInputNums[idx];
1005
1003
rewriter.modifyOpInPlace (&op, [&] {
1006
1004
op.setOperand (0 , newFuncOp.getArgument (unrolledInputNo));
1007
1005
});
1008
- idx++ ;
1006
+ ++idx ;
1009
1007
}
1010
- count++ ;
1008
+ ++count ;
1011
1009
}
1012
1010
1013
1011
// Erase the original funcOp. The `tmpOps` do not need to be erased since
@@ -1029,8 +1027,7 @@ void mlir::populateFuncOpVectorRewritePatterns(RewritePatternSet &patterns) {
1029
1027
namespace {
1030
1028
// / A pattern for rewriting function signature and the return op to convert
1031
1029
// / vectors to be of valid types.
1032
- class ReturnOpVectorUnroll : public OpRewritePattern <func::ReturnOp> {
1033
- public:
1030
+ struct ReturnOpVectorUnroll final : OpRewritePattern<func::ReturnOp> {
1034
1031
using OpRewritePattern::OpRewritePattern;
1035
1032
1036
1033
LogicalResult matchAndRewrite (func::ReturnOp returnOp,
0 commit comments