
Commit fe84369

[mlir][ArmNeon] Implements unrolling patterns for LowerContractionToSMMLAPattern (#84848)
This patch updates `LowerContractionToSMMLAPattern` to unroll larger vector contracts into multiple smmla instructions. It now accepts tiles of up to [8,8,8] (previously only [2,2,8]); the N/M dimensions must be powers of 2. `vector.extract_strided_slice`/`vector.insert_strided_slice` divide the contract into tiles that are processed one at a time.
1 parent: ab76052 · commit: fe84369
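Not part of the commit, but for context: the pattern is applied through the usual greedy-rewrite flow. A minimal sketch of such a driver follows; the populate-function name is taken from `mlir/Dialect/ArmNeon/Transforms.h` as of this revision and should be treated as an assumption to verify against your tree.

```cpp
#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Collects the ArmNeon contract-lowering pattern (which now performs the
// unrolling described above) and applies it greedily under `root`.
static mlir::LogicalResult lowerContractsToSMMLA(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Assumed entry point, declared in ArmNeon's Transforms.h at this revision.
  mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```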

File tree: 2 files changed (+174, -32 lines)


mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

Lines changed: 80 additions & 32 deletions
```diff
@@ -16,7 +16,9 @@
 #include "mlir/Dialect/ArmNeon/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
```
```diff
@@ -36,19 +38,17 @@ static Type matchContainerType(Type element, Type container) {
   return element;
 }
 
-/// Lowering from a single vector::contractOp directly to the arm neon smmla
-/// intrinsic. The shapes of the contract and intrinsic must match.
+/// Lowering from a vector::contractOp to the arm neon smmla intrinsic. This
+/// will tile any vector.contract into multiple smmla instructions with
+/// unrolling so long as [2,2,8] is a divisor of its shape. If no unrolling is
+/// necessary, a single smmla instruction is emitted.
 class LowerContractionToSMMLAPattern
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    Value lhs = op.getLhs();
-    Value rhs = op.getRhs();
-    Value res = op.getAcc();
-
     // Check index maps that represent M N K in contract.
     auto indexingMaps = op.getIndexingMapsArray();
     if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
```
```diff
@@ -57,7 +57,6 @@ class LowerContractionToSMMLAPattern
         })) {
       return failure();
     }
-
     // Check iterator types for contract.
     auto iteratorTypes = op.getIteratorTypesArray();
     if (iteratorTypes.size() != 3 ||
```
```diff
@@ -66,22 +65,24 @@ class LowerContractionToSMMLAPattern
         iteratorTypes[2] != vector::IteratorType::reduction) {
       return failure();
     }
-
-    // Check the tile size by mapping the dimensions of the contract.
+    // Infer tile sizes from operands; note that the RHS is not transposed.
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
     auto dimM = lhsType.getDimSize(0);
     auto dimN = rhsType.getDimSize(0);
     auto dimK = lhsType.getDimSize(1);
-    if (rhsType.getDimSize(1) != dimK || dimM != 2 || dimN != 2 || dimK != 8) {
+
+    // Unrolling patterns can handle any input whose shape is a multiple of
+    // [2, 2, 8].
+    if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
       return failure();
     }
 
     // Check two extsi inputs Rhs Lhs for contract.
     arith::ExtSIOp origLhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(lhs.getDefiningOp());
+        dyn_cast_or_null<arith::ExtSIOp>(op.getLhs().getDefiningOp());
     arith::ExtSIOp origRhsExtOp =
-        dyn_cast_or_null<arith::ExtSIOp>(rhs.getDefiningOp());
+        dyn_cast_or_null<arith::ExtSIOp>(op.getRhs().getDefiningOp());
     if (!origLhsExtOp || !origRhsExtOp) {
       return failure();
     }
```
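To make the relaxed guard above concrete: the old code required the contract shape to be exactly [2,2,8], while the new code only requires that [2,2,8] divide [M,N,K]. A standalone illustration (`fitsSmmlaTiling` is a hypothetical helper mirroring the new check, not part of the patch):

```cpp
#include <cstdint>
#include <cstdio>

// Mirrors the new guard: a contract is lowered iff [2, 2, 8] divides [M, N, K].
static bool fitsSmmlaTiling(int64_t dimM, int64_t dimN, int64_t dimK) {
  return dimM % 2 == 0 && dimN % 2 == 0 && dimK % 8 == 0;
}

int main() {
  std::printf("%d\n", fitsSmmlaTiling(2, 2, 8));  // 1: single smmla, as before
  std::printf("%d\n", fitsSmmlaTiling(4, 4, 8));  // 1: unrolls to four smmla ops
  std::printf("%d\n", fitsSmmlaTiling(4, 2, 8));  // 1: mixed M/N tile counts
  std::printf("%d\n", fitsSmmlaTiling(4, 4, 12)); // 0: K = 12 is not a multiple of 8
  return 0;
}
```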
```diff
@@ -113,26 +114,73 @@ class LowerContractionToSMMLAPattern
       return failure();
     }
 
-    // Collapse to 1D vectors required by smmla intrinsic
-    auto collapsedInputType = VectorType::get(
-        {16}, extsiLhs.getType().cast<ShapedType>().getElementType());
-    auto collapsedOutputType =
-        VectorType::get({4}, res.getType().cast<ShapedType>().getElementType());
-    auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
-        extsiLhs.getLoc(), collapsedInputType, extsiLhs);
-    auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
-        extsiRhs.getLoc(), collapsedInputType, extsiRhs);
-    auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
-        res.getLoc(), collapsedOutputType, res);
-
-    // Replace the contract with a neon op
-    auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
-        op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
-        collapsedRhs);
-
-    // Reshape output back to 2D
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getResultType(),
-                                                     smmlaOp);
+    // Initial accumulator for the final result. This is the un-tiled result if
+    // tiling is done.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+
+    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
+    SmallVector<int64_t> smmlaShape{2, 2, 8};
+    SmallVector<int64_t> loopOrder{0, 1, 2};
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
+
+      // Helper to compute the new shape of each operand and extract the slice.
+      auto extractOperand = [&](Value operand, AffineMap permutationMap,
+                                ArrayRef<int64_t> operandOffsets) {
+        SmallVector<int64_t> operandShape =
+            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
+        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
+        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand, operandOffsets, operandShape, operandStrides);
+      };
+
+      // Extract tiled lhs, rhs, and acc.
+      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
+      SmallVector<int64_t> lhsOffsets =
+          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+      Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
+      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
+      SmallVector<int64_t> rhsOffsets =
+          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+      Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
+      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
+      SmallVector<int64_t> accOffsets =
+          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
+      Value tiledAcc =
+          extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+
+      // Collapse tiled operands to 1D vectors required by smmla intrinsic.
+      auto collapsedInputType = VectorType::get(
+          tiledLhs.getType().cast<ShapedType>().getNumElements(),
+          tiledLhs.getType().cast<ShapedType>().getElementType());
+      auto collapsedOutputType = VectorType::get(
+          {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
+      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
+      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+      auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+
+      // Insert contract op.
+      auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
+          op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
+          collapsedRhs);
+
+      // Reshape output back to 2D.
+      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
+          smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+
+      // Insert the tiled result back into the non-tiled result of the
+      // contract op.
+      SmallVector<int64_t> strides(
+          tiledRes.getType().cast<ShapedType>().getRank(), 1);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tiledRes, result, accOffsets, strides);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
```
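To see what the new loop visits, the sketch below (illustration only; it reuses the same upstream `StaticTileOffsetRange` and `applyPermutationMap` utilities the pattern calls) enumerates the [2,2,8] tiles of a 4x4x8 contract and projects each tile offset through the LHS indexing map `(d0, d1, d2) -> (d0, d2)`:

```cpp
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  mlir::MLIRContext ctx;
  llvm::SmallVector<int64_t> shape{4, 4, 8}; // [M, N, K] of the contract
  llvm::SmallVector<int64_t> tile{2, 2, 8};  // the smmla tile
  llvm::SmallVector<int64_t> order{0, 1, 2}; // M outermost, then N, then K
  // LHS indexing map of the contract: (d0, d1, d2) -> (d0, d2).
  mlir::AffineMap lhsMap = mlir::AffineMap::get(
      3, 0, {mlir::getAffineDimExpr(0, &ctx), mlir::getAffineDimExpr(2, &ctx)},
      &ctx);
  for (llvm::SmallVector<int64_t> offsets :
       mlir::StaticTileOffsetRange(shape, tile, order)) {
    // Tile offsets visit [0,0,0], [0,2,0], [2,0,0], [2,2,0]; the projected
    // LHS offsets are [0,0], [0,0], [2,0], [2,0], matching the test below.
    llvm::SmallVector<int64_t> lhsOffsets =
        mlir::applyPermutationMap(lhsMap, llvm::ArrayRef<int64_t>(offsets));
    llvm::errs() << "[" << lhsOffsets[0] << ", " << lhsOffsets[1] << "]\n";
  }
  return 0;
}
```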

mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir

Lines changed: 94 additions & 0 deletions
```diff
@@ -40,3 +40,97 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs:
   %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %acc : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
   return %res : vector<2x2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: test_lower_vector_arm_neon_unroll
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32>
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
+// CHECK-DAG: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG: %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG: %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
+// CHECK-DAG: return %[[VAL_39]] : vector<4x4xi32>
+// CHECK-DAG: }
+func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_mixed_unroll(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2x8xi4>,
+// CHECK-SAME: %[[VAL_2:.*]]: vector<4x2xi32>) -> vector<4x2xi32> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x2xi32>
+// CHECK-DAG: %[[VAL_4:.*]] = arith.extsi %[[VAL_1]] : vector<2x8xi4> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32>
+// CHECK-DAG: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
+// CHECK-DAG: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
+// CHECK-DAG: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x2xi32> to vector<4xi32>
+// CHECK-DAG: %[[VAL_18:.*]] = arm_neon.intr.smmla %[[VAL_17]], %[[VAL_15]], %[[VAL_16]] : vector<16xi8> to vector<4xi32>
+// CHECK-DAG: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_18]] : vector<4xi32> to vector<2x2xi32>
+// CHECK-DAG: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_19]], %[[VAL_12]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32>
+// CHECK-DAG: return %[[VAL_20]] : vector<4x2xi32>
+// CHECK-DAG: }
+func.func @test_lower_vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<4x2xi32>) -> vector<4x2xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<2x8xi32> into vector<4x2xi32>
+  return %res : vector<4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(
+// CHECK-DAG: %[[result:.*]] = vector.contract
+func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x12xi8>, %rhs: vector<4x12xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
+  %lhs_extsi = arith.extsi %lhs : vector<4x12xi8> to vector<4x12xi32>
+  %rhs_extsi = arith.extsi %rhs : vector<4x12xi8> to vector<4x12xi32>
+  %res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
+  return %res : vector<4x4xi32>
+}
```
