Commit 4259a93

implement vecmat unroll for i8mm

1 parent 29bf32e commit 4259a93

2 files changed: +168 -27 lines changed

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

Lines changed: 46 additions & 27 deletions
@@ -40,41 +40,38 @@ static Type matchContainerType(Type element, Type container) {
 
 /// Lowering from a vector::contractOp to an arm neon smmla intrinsic. This will tile
 /// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
-/// smmla instruction is emitted.
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
+/// = 1 (either explicitly or inferred if LHS has only dimK). If no unrolling
+/// is necessary, a single smmla instruction is emitted.
 class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    // Check index maps that represent M N K in contract.
-    auto indexingMaps = op.getIndexingMapsArray();
-    if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
-          return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
-                 affineMap.getNumResults() != 2;
-        })) {
-      return failure();
-    }
-    // Check iterator types for contract.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    if (iteratorTypes.size() != 3 ||
-        iteratorTypes[0] != vector::IteratorType::parallel ||
-        iteratorTypes[1] != vector::IteratorType::parallel ||
-        iteratorTypes[2] != vector::IteratorType::reduction) {
-      return failure();
-    }
-    // Infer tile sizes from operands; Note: RHS is not transposed.
+    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
+    // Note: RHS is not transposed.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
-    auto dimM = lhsType.getDimSize(0);
+    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
-    auto dimK = lhsType.getDimSize(1);
-
+    auto dimK = rhsType.getDimSize(1);
+    bool isVecmat = dimM == 1;
+    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+        rhsType.getDimSize(rhsType.getRank() - 1)) {
+      return failure(); // dimK mismatch
+    }
     // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
     // tiling.
-    if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+      return failure();
+    }
+
+    // Check iterator types for contract.
+    auto iteratorTypes = op.getIteratorTypesArray();
+    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+                                        vector::IteratorType::reduction) {
       return failure();
     }
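For reference, a minimal sketch of the kind of contract the relaxed checks now admit, adapted from the new vecmat test added below (the 1-D LHS leaves dimM implicit; the value names %lhs, %rhs, %acc are illustrative):

%res = vector.contract {
    indexing_maps = [affine_map<(d0, d1) -> (d1)>,
                     affine_map<(d0, d1) -> (d0, d1)>,
                     affine_map<(d0, d1) -> (d0)>],
    iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>}
  %lhs, %rhs, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>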

@@ -120,11 +117,14 @@ class LowerContractionToSMMLAPattern
         loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
 
     SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
-    SmallVector<int64_t> smmlaShape{2, 2, 8};
+    SmallVector<int64_t> smmlaShape{isVecmat ? 1 : 2, 2, 8};
     SmallVector<int64_t> loopOrder{0, 1, 2};
+    if (unrolledSize.size() == 2) {
+      smmlaShape = {2, 8};
+      loopOrder = {0, 1};
+    }
     for (SmallVector<int64_t> offsets :
          StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-
       // Helper to compute the new shape of each operand and extract the slice.
       auto extractOperand = [&](Value operand, AffineMap permutationMap,
                                 ArrayRef<int64_t> operandOffsets) {
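A worked example of the tile iteration, assuming the implicit-vecmat shapes from the first new test below (LHS vector<8xi8>, RHS vector<8x8xi8>):

// unrolledSize = [8, 8] over (dimN, dimK); smmlaShape = [2, 8]; loopOrder = [0, 1].
// StaticTileOffsetRange then visits offsets [0, 0], [2, 0], [4, 0], [6, 0]:
// one smmla per 2-element slice of the vector<8xi32> result.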
@@ -150,16 +150,30 @@ class LowerContractionToSMMLAPattern
       Value tiledAcc =
           extractOperand(op.getAcc(), accPermutationMap, accOffsets);
 
+      // With vecmat, tiled LHS and ACC will contain only one of the 2
+      // necessary rows along dimM. Broadcast both to the full width.
+      if (isVecmat) {
+        auto lhsBroadcastType = VectorType::get(
+            {2, 8}, tiledLhs.getType().cast<ShapedType>().getElementType());
+        tiledLhs = rewriter.create<vector::BroadcastOp>(loc, lhsBroadcastType,
+                                                        tiledLhs);
+        auto accBroadcastType = VectorType::get(
+            {2, 2}, tiledAcc.getType().cast<ShapedType>().getElementType());
+        tiledAcc = rewriter.create<vector::BroadcastOp>(loc, accBroadcastType,
+                                                        tiledAcc);
+      }
+
       // Collapse tiled operands to 1D vectors required by smmla intrinsic
       auto collapsedInputType = VectorType::get(
           tiledLhs.getType().cast<ShapedType>().getNumElements(),
           tiledLhs.getType().cast<ShapedType>().getElementType());
-      auto collapsedOutputType = VectorType::get(
-          {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
       auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledLhs.getLoc(), collapsedInputType, tiledLhs);
       auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+      auto collapsedOutputType = VectorType::get(
+          tiledAcc.getType().cast<ShapedType>().getNumElements(),
+          tiledAcc.getType().cast<ShapedType>().getElementType());
       auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
           tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
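Per tile, the vecmat path thus emits the following sequence (a sketch condensed from the CHECK lines in the new tests below; value names are illustrative):

%lhs2x8 = vector.broadcast %lhs : vector<8xi8> to vector<2x8xi8>
%acc2x2 = vector.broadcast %accslice : vector<2xi32> to vector<2x2xi32>
%lhs16 = vector.shape_cast %lhs2x8 : vector<2x8xi8> to vector<16xi8>
%rhs16 = vector.shape_cast %rhsslice : vector<2x8xi8> to vector<16xi8>
%acc4 = vector.shape_cast %acc2x2 : vector<2x2xi32> to vector<4xi32>
%out = arm_neon.intr.smmla %acc4, %lhs16, %rhs16 : vector<16xi8> to vector<4xi32>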

@@ -172,6 +186,11 @@ class LowerContractionToSMMLAPattern
       Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
           smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
 
+      // With vecmat, only one row of tiled ACC can be inserted into the final
+      // result.
+      if (isVecmat) {
+        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+      }
+
       // Insert the tiled result back into the non-tiled result of the
       // contract op.
       SmallVector<int64_t> strides(
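Because the two rows of the broadcast accumulator are identical, only row 0 of each 2x2 smmla result is written back. In the tests below this shows up as (illustrative names):

%res2x2 = vector.shape_cast %out : vector<4xi32> to vector<2x2xi32>
%row0 = vector.extract %res2x2[0] : vector<2xi32> from vector<2x2xi32>
%acc1 = vector.insert_strided_slice %row0, %acc0 {offsets = [0], strides = [1]} : vector<2xi32> into vector<8xi32>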

mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir

Lines changed: 122 additions & 0 deletions
@@ -134,3 +134,125 @@ func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x1
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
   return %res : vector<4x4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_vecmat_unroll(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<8xi8>,
+// CHECK-SAME:    %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME:    %[[VAL_2:.*]]: vector<8xi32>) -> vector<8xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<8xi32>
+// CHECK: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_5]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_11:.*]] = arm_neon.intr.smmla %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_12]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_14:.*]] = vector.insert_strided_slice %[[VAL_13]], %[[VAL_3]] {offsets = [0], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_16]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_18]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_22:.*]] = arm_neon.intr.smmla %[[VAL_21]], %[[VAL_19]], %[[VAL_20]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_23]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_25:.*]] = vector.insert_strided_slice %[[VAL_24]], %[[VAL_14]] {offsets = [2], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_26:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [4], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_28:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_29:.*]] = vector.broadcast %[[VAL_27]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_31:.*]] = vector.shape_cast %[[VAL_26]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_33:.*]] = arm_neon.intr.smmla %[[VAL_32]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_33]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_35:.*]] = vector.extract %[[VAL_34]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_36:.*]] = vector.insert_strided_slice %[[VAL_35]], %[[VAL_25]] {offsets = [4], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: %[[VAL_37:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_38:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [6], sizes = [2], strides = [1]} : vector<8xi32> to vector<2xi32>
+// CHECK: %[[VAL_39:.*]] = vector.broadcast %[[VAL_0]] : vector<8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_40:.*]] = vector.broadcast %[[VAL_38]] : vector<2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_41:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_42:.*]] = vector.shape_cast %[[VAL_37]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[VAL_40]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_44:.*]] = arm_neon.intr.smmla %[[VAL_43]], %[[VAL_41]], %[[VAL_42]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_45:.*]] = vector.shape_cast %[[VAL_44]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_46:.*]] = vector.extract %[[VAL_45]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_47:.*]] = vector.insert_strided_slice %[[VAL_46]], %[[VAL_36]] {offsets = [6], strides = [1]} : vector<2xi32> into vector<8xi32>
+// CHECK: return %[[VAL_47]] : vector<8xi32>
+// CHECK: }
+func.func @test_lower_vector_arm_neon_vecmat_unroll(%lhs: vector<8xi8>, %rhs: vector<8x8xi8>, %acc : vector<8xi32>) -> vector<8xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<8xi8> to vector<8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<8xi32>, vector<8x8xi32> into vector<8xi32>
+  return %res : vector<8xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(
+// CHECK-SAME:    %[[VAL_0:.*]]: vector<1x8xi8>,
+// CHECK-SAME:    %[[VAL_1:.*]]: vector<8x8xi8>,
+// CHECK-SAME:    %[[VAL_2:.*]]: vector<1x8xi32>) -> vector<1x8xi32> {
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<1x8xi32>
+// CHECK: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_6:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_7:.*]] = vector.broadcast %[[VAL_5]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_10:.*]] = vector.shape_cast %[[VAL_7]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_11:.*]] = arm_neon.intr.smmla %[[VAL_10]], %[[VAL_8]], %[[VAL_9]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_12:.*]] = vector.shape_cast %[[VAL_11]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_13:.*]] = vector.extract %[[VAL_12]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_14:.*]] = vector.insert_strided_slice %[[VAL_13]], %[[VAL_3]] {offsets = [0, 0], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_16:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_16]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_17]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_18]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_22:.*]] = arm_neon.intr.smmla %[[VAL_21]], %[[VAL_19]], %[[VAL_20]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_23:.*]] = vector.shape_cast %[[VAL_22]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_24:.*]] = vector.extract %[[VAL_23]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_25:.*]] = vector.insert_strided_slice %[[VAL_24]], %[[VAL_14]] {offsets = [0, 2], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_26:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [4, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_27:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 4], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_28:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_29:.*]] = vector.broadcast %[[VAL_27]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_28]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_31:.*]] = vector.shape_cast %[[VAL_26]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_32:.*]] = vector.shape_cast %[[VAL_29]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_33:.*]] = arm_neon.intr.smmla %[[VAL_32]], %[[VAL_30]], %[[VAL_31]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_33]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_35:.*]] = vector.extract %[[VAL_34]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_36:.*]] = vector.insert_strided_slice %[[VAL_35]], %[[VAL_25]] {offsets = [0, 4], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: %[[VAL_37:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [6, 0], sizes = [2, 8], strides = [1, 1]} : vector<8x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_38:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 6], sizes = [1, 2], strides = [1, 1]} : vector<1x8xi32> to vector<1x2xi32>
+// CHECK: %[[VAL_39:.*]] = vector.broadcast %[[VAL_0]] : vector<1x8xi8> to vector<2x8xi8>
+// CHECK: %[[VAL_40:.*]] = vector.broadcast %[[VAL_38]] : vector<1x2xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_41:.*]] = vector.shape_cast %[[VAL_39]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_42:.*]] = vector.shape_cast %[[VAL_37]] : vector<2x8xi8> to vector<16xi8>
+// CHECK: %[[VAL_43:.*]] = vector.shape_cast %[[VAL_40]] : vector<2x2xi32> to vector<4xi32>
+// CHECK: %[[VAL_44:.*]] = arm_neon.intr.smmla %[[VAL_43]], %[[VAL_41]], %[[VAL_42]] : vector<16xi8> to vector<4xi32>
+// CHECK: %[[VAL_45:.*]] = vector.shape_cast %[[VAL_44]] : vector<4xi32> to vector<2x2xi32>
+// CHECK: %[[VAL_46:.*]] = vector.extract %[[VAL_45]][0] : vector<2xi32> from vector<2x2xi32>
+// CHECK: %[[VAL_47:.*]] = vector.insert_strided_slice %[[VAL_46]], %[[VAL_36]] {offsets = [0, 6], strides = [1]} : vector<2xi32> into vector<1x8xi32>
+// CHECK: return %[[VAL_47]] : vector<1x8xi32>
+// CHECK: }
+func.func @test_lower_vector_arm_neon_vecmat_unroll_leading_dim(%lhs: vector<1x8xi8>, %rhs: vector<8x8xi8>, %acc : vector<1x8xi32>) -> vector<1x8xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<1x8xi8> to vector<1x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<8x8xi8> to vector<8x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<1x8xi32>, vector<8x8xi32> into vector<1x8xi32>
+  return %res : vector<1x8xi32>
+}
