16
16
#include " mlir/Dialect/ArmNeon/Transforms.h"
17
17
#include " mlir/Dialect/Func/IR/FuncOps.h"
18
18
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
19
+ #include " mlir/Dialect/Utils/IndexingUtils.h"
19
20
#include " mlir/Dialect/Vector/IR/VectorOps.h"
21
+ #include " mlir/IR/AffineMap.h"
20
22
#include " mlir/IR/PatternMatch.h"
21
23
#include " mlir/Support/LogicalResult.h"
22
24
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -36,19 +38,17 @@ static Type matchContainerType(Type element, Type container) {
36
38
return element;
37
39
}
38
40
39
- // / Lowering from a single vector::contractOp directly to the arm neon smmla
40
- // / intrinsic. The shapes of the contract and intrinsic must match.
41
+ // / Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
42
+ // / supports up to an 8x8x8 vector contract, which is tiled into (up to 16)
43
+ // / smmla instructions with unrolling. If no unrolling is necessary, a single
44
+ // / smmla instruction is emitted.
41
45
class LowerContractionToSMMLAPattern
42
46
: public OpRewritePattern<vector::ContractionOp> {
43
47
public:
44
48
using OpRewritePattern::OpRewritePattern;
45
49
LogicalResult matchAndRewrite (vector::ContractionOp op,
46
50
PatternRewriter &rewriter) const override {
47
51
Location loc = op.getLoc ();
48
- Value lhs = op.getLhs ();
49
- Value rhs = op.getRhs ();
50
- Value res = op.getAcc ();
51
-
52
52
// Check index maps that represent M N K in contract.
53
53
auto indexingMaps = op.getIndexingMapsArray ();
54
54
if (llvm::any_of (indexingMaps, [](mlir::AffineMap affineMap) {
@@ -57,7 +57,6 @@ class LowerContractionToSMMLAPattern
57
57
})) {
58
58
return failure ();
59
59
}
60
-
61
60
// Check iterator types for contract.
62
61
auto iteratorTypes = op.getIteratorTypesArray ();
63
62
if (iteratorTypes.size () != 3 ||
@@ -66,22 +65,24 @@ class LowerContractionToSMMLAPattern
66
65
iteratorTypes[2 ] != vector::IteratorType::reduction) {
67
66
return failure ();
68
67
}
69
-
70
- // Check the tile size by mapping the dimensions of the contract.
68
+ // Infer tile sizes from operands; Note: RHS is not transposed.
71
69
mlir::VectorType lhsType = op.getLhsType ();
72
70
mlir::VectorType rhsType = op.getRhsType ();
73
71
auto dimM = lhsType.getDimSize (0 );
74
72
auto dimN = rhsType.getDimSize (0 );
75
73
auto dimK = lhsType.getDimSize (1 );
76
- if (rhsType.getDimSize (1 ) != dimK || dimM != 2 || dimN != 2 || dimK != 8 ) {
74
+
75
+ // Unrolling patterns can handle [M, N, 8] shaped inputs for tiling, where M
76
+ // and N are even and at most 8.
77
+ if (dimM % 2 != 0 || dimM > 8 || dimN % 2 != 0 || dimN > 8 || dimK != 8 ) {
77
78
return failure ();
78
79
}
79
80
80
81
// Check two extsi inputs Rhs Lhs for contract.
81
82
arith::ExtSIOp origLhsExtOp =
82
- dyn_cast_or_null<arith::ExtSIOp>(lhs .getDefiningOp ());
83
+ dyn_cast_or_null<arith::ExtSIOp>(op. getLhs () .getDefiningOp ());
83
84
arith::ExtSIOp origRhsExtOp =
84
- dyn_cast_or_null<arith::ExtSIOp>(rhs .getDefiningOp ());
85
+ dyn_cast_or_null<arith::ExtSIOp>(op. getRhs () .getDefiningOp ());
85
86
if (!origLhsExtOp || !origRhsExtOp) {
86
87
return failure ();
87
88
}
@@ -113,26 +114,73 @@ class LowerContractionToSMMLAPattern
113
114
return failure ();
114
115
}
115
116
116
- // Collapse to 1D vectors required by smmla intrinsic
117
- auto collapsedInputType = VectorType::get (
118
- {16 }, extsiLhs.getType ().cast <ShapedType>().getElementType ());
119
- auto collapsedOutputType =
120
- VectorType::get ({4 }, res.getType ().cast <ShapedType>().getElementType ());
121
- auto collapsedLhs = rewriter.createOrFold <vector::ShapeCastOp>(
122
- extsiLhs.getLoc (), collapsedInputType, extsiLhs);
123
- auto collapsedRhs = rewriter.createOrFold <vector::ShapeCastOp>(
124
- extsiRhs.getLoc (), collapsedInputType, extsiRhs);
125
- auto collapsedRes = rewriter.createOrFold <vector::ShapeCastOp>(
126
- res.getLoc (), collapsedOutputType, res);
127
-
128
- // Replace the contract with a neon op
129
- auto smmlaOp = rewriter.createOrFold <arm_neon::SmmlaOp>(
130
- op.getLoc (), collapsedRes.getType (), collapsedRes, collapsedLhs,
131
- collapsedRhs);
132
-
133
- // Reshape output back to 2D
134
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(op, op.getResultType (),
135
- smmlaOp);
117
+ // Initial accumulator for the final result. This is the un-tiled result if
118
+ // tiling is done.
119
+ Value result = rewriter.create <arith::ConstantOp>(
120
+ loc, op.getResultType (), rewriter.getZeroAttr (op.getResultType ()));
121
+
122
+ SmallVector<int64_t > unrolledSize = *op.getShapeForUnroll ();
123
+ SmallVector<int64_t > smmlaShape{2 , 2 , 8 };
124
+ SmallVector<int64_t > loopOrder{0 , 1 , 2 };
125
+ for (SmallVector<int64_t > offsets :
126
+ StaticTileOffsetRange (unrolledSize, smmlaShape, loopOrder)) {
127
+
128
+ // Helper to compute the new shape of each operand and extract the slice.
129
+ auto extractOperand = [&](Value operand, AffineMap permutationMap,
130
+ ArrayRef<int64_t > operandOffsets) {
131
+ SmallVector<int64_t > operandShape =
132
+ applyPermutationMap (permutationMap, ArrayRef<int64_t >(smmlaShape));
133
+ SmallVector<int64_t > operandStrides (operandOffsets.size (), 1 );
134
+ return rewriter.createOrFold <vector::ExtractStridedSliceOp>(
135
+ loc, operand, operandOffsets, operandShape, operandStrides);
136
+ };
137
+
138
+ // Extract tiled lhs, rhs, and acc
139
+ AffineMap lhsPermutationMap = op.getIndexingMapsArray ()[0 ];
140
+ SmallVector<int64_t > lhsOffsets =
141
+ applyPermutationMap (lhsPermutationMap, ArrayRef<int64_t >(offsets));
142
+ auto tiledLhs = extractOperand (extsiLhs, lhsPermutationMap, lhsOffsets);
143
+ AffineMap rhsPermutationMap = op.getIndexingMapsArray ()[1 ];
144
+ SmallVector<int64_t > rhsOffsets =
145
+ applyPermutationMap (rhsPermutationMap, ArrayRef<int64_t >(offsets));
146
+ auto tiledRhs = extractOperand (extsiRhs, rhsPermutationMap, rhsOffsets);
147
+ AffineMap accPermutationMap = op.getIndexingMapsArray ()[2 ];
148
+ SmallVector<int64_t > accOffsets =
149
+ applyPermutationMap (accPermutationMap, ArrayRef<int64_t >(offsets));
150
+ auto tiledAcc =
151
+ extractOperand (op.getAcc (), accPermutationMap, accOffsets);
152
+
153
+ // Collapse tiled operands to 1D vectors required by smmla intrinsic
154
+ auto collapsedInputType = VectorType::get (
155
+ tiledLhs.getType ().cast <ShapedType>().getNumElements (),
156
+ tiledLhs.getType ().cast <ShapedType>().getElementType ());
157
+ auto collapsedOutputType = VectorType::get (
158
+ {4 }, tiledAcc.getType ().cast <ShapedType>().getElementType ());
159
+ auto collapsedLhs = rewriter.createOrFold <vector::ShapeCastOp>(
160
+ tiledLhs.getLoc (), collapsedInputType, tiledLhs);
161
+ auto collapsedRhs = rewriter.createOrFold <vector::ShapeCastOp>(
162
+ tiledRhs.getLoc (), collapsedInputType, tiledRhs);
163
+ auto collapsedRes = rewriter.createOrFold <vector::ShapeCastOp>(
164
+ tiledAcc.getLoc (), collapsedOutputType, tiledAcc);
165
+
166
+ // Emit the smmla op for the current tile
167
+ auto smmlaOp = rewriter.createOrFold <arm_neon::SmmlaOp>(
168
+ op.getLoc (), collapsedRes.getType (), collapsedRes, collapsedLhs,
169
+ collapsedRhs);
170
+
171
+ // Reshape output back to 2D
172
+ Value tiledRes = rewriter.createOrFold <vector::ShapeCastOp>(
173
+ smmlaOp.getLoc (), tiledAcc.getType (), smmlaOp);
174
+
175
+ // Insert the tiled result back into the non tiled result of the
176
+ // contract op.
177
+ SmallVector<int64_t > strides (
178
+ tiledRes.getType ().cast <ShapedType>().getRank (), 1 );
179
+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
180
+ loc, tiledRes, result, accOffsets, strides);
181
+ }
182
+
183
+ rewriter.replaceOp (op, result);
136
184
return success ();
137
185
}
138
186
};
0 commit comments