Skip to content

Commit f6cb2ee

Browse files
committed
diego comments
1 parent ae74220 commit f6cb2ee

File tree

2 files changed

+92
-48
lines changed

2 files changed

+92
-48
lines changed

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ static Type matchContainerType(Type element, Type container) {
3838
return element;
3939
}
4040

41-
/// Lowering from a vector::contractOp arm neon smmla intrinsic. This up to an
42-
/// 8x8x8 vector contract that is tiled (up to 16) smmla instructions with
43-
/// unrolling. If no unrolling is necessary, a single smmla instruction is
44-
/// emitted.
41+
/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
42+
/// any vector.contract into multiple smmla instructions with unrolling so long
43+
/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
44+
/// smmla instruction is emitted.
4545
class LowerContractionToSMMLAPattern
4646
: public OpRewritePattern<vector::ContractionOp> {
4747
public:
@@ -72,9 +72,9 @@ class LowerContractionToSMMLAPattern
7272
auto dimN = rhsType.getDimSize(0);
7373
auto dimK = lhsType.getDimSize(1);
7474

75-
// Unrolling patterns can handle [(2|4|8), (2|4|8), 8] shaped inputs for
75+
// Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
7676
// tiling.
77-
if (dimM % 2 != 0 || dimM > 8 || dimN % 2 != 0 || dimN > 8 || dimK != 8) {
77+
if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
7878
return failure();
7979
}
8080

@@ -139,15 +139,15 @@ class LowerContractionToSMMLAPattern
139139
AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
140140
SmallVector<int64_t> lhsOffsets =
141141
applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
142-
auto tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
142+
Value tiledLhs = extractOperand(extsiLhs, lhsPermutationMap, lhsOffsets);
143143
AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
144144
SmallVector<int64_t> rhsOffsets =
145145
applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
146-
auto tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
146+
Value tiledRhs = extractOperand(extsiRhs, rhsPermutationMap, rhsOffsets);
147147
AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
148148
SmallVector<int64_t> accOffsets =
149149
applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
150-
auto tiledAcc =
150+
Value tiledAcc =
151151
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
152152

153153
// Collapse tiled operands to 1D vectors required by smmla intrinsic

mlir/test/Dialect/ArmNeon/lower-to-arm-neon.mlir

Lines changed: 83 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -45,48 +45,92 @@ func.func @test_lower_vector_arm_neon_without_extsi(%lhs: vector<2x8xi32>, %rhs:
4545

4646
// CHECK-LABEL: test_lower_vector_arm_neon_unroll
4747
// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>, %[[VAL_1:.*]]: vector<4x8xi8>, %[[VAL_2:.*]]: vector<4x4xi32>
48-
// CHECK: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
49-
// CHECK: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
50-
// CHECK: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
51-
// CHECK: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
52-
// CHECK: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
53-
// CHECK: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
54-
// CHECK: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
55-
// CHECK: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
56-
// CHECK: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
57-
// CHECK: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
58-
// CHECK: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
59-
// CHECK: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
60-
// CHECK: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
61-
// CHECK: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
62-
// CHECK: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x8xi8> to vector<16xi8>
63-
// CHECK: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32>
64-
// CHECK: %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32>
65-
// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32>
66-
// CHECK: %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
67-
// CHECK: %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
68-
// CHECK: %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
69-
// CHECK: %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
70-
// CHECK: %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8>
71-
// CHECK: %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8>
72-
// CHECK: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32>
73-
// CHECK: %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32>
74-
// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32>
75-
// CHECK: %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
76-
// CHECK: %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
77-
// CHECK: %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
78-
// CHECK: %[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
79-
// CHECK: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8>
80-
// CHECK: %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8>
81-
// CHECK: %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32>
82-
// CHECK: %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32>
83-
// CHECK: %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32>
84-
// CHECK: %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
85-
// CHECK: return %[[VAL_39]] : vector<4x4xi32>
86-
// CHECK: }
48+
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x4xi32>
49+
// CHECK-DAG: %[[VAL_4:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
50+
// CHECK-DAG: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
51+
// CHECK-DAG: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
52+
// CHECK-DAG: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
53+
// CHECK-DAG: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
54+
// CHECK-DAG: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
55+
// CHECK-DAG: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
56+
// CHECK-DAG: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
57+
// CHECK-DAG: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
58+
// CHECK-DAG: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
59+
// CHECK-DAG: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
60+
// CHECK-DAG: %[[VAL_15:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
61+
// CHECK-DAG: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
62+
// CHECK-DAG: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x8xi8> to vector<16xi8>
63+
// CHECK-DAG: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_15]] : vector<2x2xi32> to vector<4xi32>
64+
// CHECK-DAG: %[[VAL_19:.*]] = arm_neon.intr.smmla %[[VAL_18]], %[[VAL_16]], %[[VAL_17]] : vector<16xi8> to vector<4xi32>
65+
// CHECK-DAG: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_19]] : vector<4xi32> to vector<2x2xi32>
66+
// CHECK-DAG: %[[VAL_21:.*]] = vector.insert_strided_slice %[[VAL_20]], %[[VAL_12]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
67+
// CHECK-DAG: %[[VAL_22:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
68+
// CHECK-DAG: %[[VAL_23:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
69+
// CHECK-DAG: %[[VAL_24:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
70+
// CHECK-DAG: %[[VAL_25:.*]] = vector.shape_cast %[[VAL_22]] : vector<2x8xi8> to vector<16xi8>
71+
// CHECK-DAG: %[[VAL_26:.*]] = vector.shape_cast %[[VAL_23]] : vector<2x8xi8> to vector<16xi8>
72+
// CHECK-DAG: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_24]] : vector<2x2xi32> to vector<4xi32>
73+
// CHECK-DAG: %[[VAL_28:.*]] = arm_neon.intr.smmla %[[VAL_27]], %[[VAL_25]], %[[VAL_26]] : vector<16xi8> to vector<4xi32>
74+
// CHECK-DAG: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_28]] : vector<4xi32> to vector<2x2xi32>
75+
// CHECK-DAG: %[[VAL_30:.*]] = vector.insert_strided_slice %[[VAL_29]], %[[VAL_21]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
76+
// CHECK-DAG: %[[VAL_31:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
77+
// CHECK-DAG: %[[VAL_32:.*]] = vector.extract_strided_slice %[[VAL_1]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
78+
// CHECK-DAG: %[[VAL_33:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xi32> to vector<2x2xi32>
79+
// CHECK-DAG: %[[VAL_34:.*]] = vector.shape_cast %[[VAL_31]] : vector<2x8xi8> to vector<16xi8>
80+
// CHECK-DAG: %[[VAL_35:.*]] = vector.shape_cast %[[VAL_32]] : vector<2x8xi8> to vector<16xi8>
81+
// CHECK-DAG: %[[VAL_36:.*]] = vector.shape_cast %[[VAL_33]] : vector<2x2xi32> to vector<4xi32>
82+
// CHECK-DAG: %[[VAL_37:.*]] = arm_neon.intr.smmla %[[VAL_36]], %[[VAL_34]], %[[VAL_35]] : vector<16xi8> to vector<4xi32>
83+
// CHECK-DAG: %[[VAL_38:.*]] = vector.shape_cast %[[VAL_37]] : vector<4xi32> to vector<2x2xi32>
84+
// CHECK-DAG: %[[VAL_39:.*]] = vector.insert_strided_slice %[[VAL_38]], %[[VAL_30]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xi32> into vector<4x4xi32>
85+
// CHECK-DAG: return %[[VAL_39]] : vector<4x4xi32>
86+
// CHECK-DAG: }
8787
func.func @test_lower_vector_arm_neon_unroll(%lhs: vector<4x8xi8>, %rhs: vector<4x8xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
8888
%lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
8989
%rhs_extsi = arith.extsi %rhs : vector<4x8xi8> to vector<4x8xi32>
9090
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<4x8xi32> into vector<4x4xi32>
9191
return %res : vector<4x4xi32>
9292
}
93+
94+
// -----
95+
96+
// CHECK-LABEL: func.func @test_lower_vector_arm_neon_mixed_unroll(
97+
// CHECK-SAME: %[[VAL_0:.*]]: vector<4x8xi8>,
98+
// CHECK-SAME: %[[VAL_1:.*]]: vector<2x8xi4>,
99+
// CHECK-SAME: %[[VAL_2:.*]]: vector<4x2xi32>) -> vector<4x2xi32> {
100+
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant dense<0> : vector<4x2xi32>
101+
// CHECK-DAG: %[[VAL_4:.*]] = arith.extsi %[[VAL_1]] : vector<2x8xi4> to vector<2x8xi8>
102+
// CHECK-DAG: %[[VAL_5:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [0, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
103+
// CHECK-DAG: %[[VAL_6:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32>
104+
// CHECK-DAG: %[[VAL_7:.*]] = vector.shape_cast %[[VAL_5]] : vector<2x8xi8> to vector<16xi8>
105+
// CHECK-DAG: %[[VAL_8:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
106+
// CHECK-DAG: %[[VAL_9:.*]] = vector.shape_cast %[[VAL_6]] : vector<2x2xi32> to vector<4xi32>
107+
// CHECK-DAG: %[[VAL_10:.*]] = arm_neon.intr.smmla %[[VAL_9]], %[[VAL_7]], %[[VAL_8]] : vector<16xi8> to vector<4xi32>
108+
// CHECK-DAG: %[[VAL_11:.*]] = vector.shape_cast %[[VAL_10]] : vector<4xi32> to vector<2x2xi32>
109+
// CHECK-DAG: %[[VAL_12:.*]] = vector.insert_strided_slice %[[VAL_11]], %[[VAL_3]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32>
110+
// CHECK-DAG: %[[VAL_13:.*]] = vector.extract_strided_slice %[[VAL_0]] {offsets = [2, 0], sizes = [2, 8], strides = [1, 1]} : vector<4x8xi8> to vector<2x8xi8>
111+
// CHECK-DAG: %[[VAL_14:.*]] = vector.extract_strided_slice %[[VAL_2]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xi32> to vector<2x2xi32>
112+
// CHECK-DAG: %[[VAL_15:.*]] = vector.shape_cast %[[VAL_13]] : vector<2x8xi8> to vector<16xi8>
113+
// CHECK-DAG: %[[VAL_16:.*]] = vector.shape_cast %[[VAL_4]] : vector<2x8xi8> to vector<16xi8>
114+
// CHECK-DAG: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_14]] : vector<2x2xi32> to vector<4xi32>
115+
// CHECK-DAG: %[[VAL_18:.*]] = arm_neon.intr.smmla %[[VAL_17]], %[[VAL_15]], %[[VAL_16]] : vector<16xi8> to vector<4xi32>
116+
// CHECK-DAG: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_18]] : vector<4xi32> to vector<2x2xi32>
117+
// CHECK-DAG: %[[VAL_20:.*]] = vector.insert_strided_slice %[[VAL_19]], %[[VAL_12]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xi32> into vector<4x2xi32>
118+
// CHECK-DAG: return %[[VAL_20]] : vector<4x2xi32>
119+
// CHECK-DAG: }
120+
func.func @test_lower_vector_arm_neon_mixed_unroll(%lhs: vector<4x8xi8>, %rhs: vector<2x8xi4>, %acc : vector<4x2xi32>) -> vector<4x2xi32> {
121+
%lhs_extsi = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
122+
%rhs_extsi = arith.extsi %rhs : vector<2x8xi4> to vector<2x8xi32>
123+
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x8xi32>, vector<2x8xi32> into vector<4x2xi32>
124+
return %res : vector<4x2xi32>
125+
}
126+
127+
// -----
128+
129+
// CHECK-LABEL: func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(
130+
// CHECK-DAG: %[[result:.*]] = vector.contract
131+
func.func @test_lower_vector_arm_neon_unroll_incompatible_shape(%lhs: vector<4x12xi8>, %rhs: vector<4x12xi8>, %acc : vector<4x4xi32>) -> vector<4x4xi32> {
132+
%lhs_extsi = arith.extsi %lhs : vector<4x12xi8> to vector<4x12xi32>
133+
%rhs_extsi = arith.extsi %rhs : vector<4x12xi8> to vector<4x12xi32>
134+
%res = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs_extsi, %rhs_extsi, %acc : vector<4x12xi32>, vector<4x12xi32> into vector<4x4xi32>
135+
return %res : vector<4x4xi32>
136+
}

0 commit comments

Comments
 (0)