Skip to content

Commit 1726b65

Browse files
authored
[MLIR][Vector] Refactor tests for contract -> OP transforms (4/N) (#73807)
This patch refactors tests for: vector.contract -> vector.outerproduct for matvec operations (b += Ax). Summary of changes: * add 2 missing cases (masked + scalable) when the operation kind is `maxf`. This is a part of a larger effort to add cases with scalable vectors to tests for the Vector dialect. Implements #72834.
1 parent 9557fcc commit 1726b65

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-matvec-transforms.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,52 @@ func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
186186
return %0 : vector<2xf32>
187187
}
188188

189+
// CHECK-LABEL: func.func @masked_matvec_mk_k_m_max(
190+
// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
191+
// CHECK-SAME: %{{.*}}: vector<3xf32>,
192+
// CHECK-SAME: %{{.*}}: vector<2xf32>,
193+
// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
194+
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
195+
// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
196+
// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
197+
198+
// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
199+
// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
200+
201+
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
202+
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
203+
func.func @masked_matvec_mk_k_m_max(%A: vector<2x3xf32>,
204+
%x: vector<3xf32>,
205+
%b: vector<2xf32>,
206+
%m: vector<2x3xi1>) -> vector<2xf32> {
207+
%0 = vector.mask %m { vector.contract #matvecmax_trait %A, %x, %b
208+
: vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
209+
return %0 : vector<2xf32>
210+
}
211+
212+
// CHECK-LABEL: func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(
213+
// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
214+
// CHECK-SAME: %{{.*}}: vector<3xf32>,
215+
// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
216+
// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
217+
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
218+
// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
219+
// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
220+
221+
// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
222+
// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
223+
224+
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
225+
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
226+
func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(%A: vector<[2]x3xf32>,
227+
%x: vector<3xf32>,
228+
%b: vector<[2]xf32>,
229+
%m: vector<[2]x3xi1>) -> vector<[2]xf32> {
230+
%0 = vector.mask %m { vector.contract #matvecmax_trait %A, %x, %b
231+
: vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
232+
return %0 : vector<[2]xf32>
233+
}
234+
189235
// ============================================================================
190236
// Matvec 2 (plain + masked + scalable)
191237
// ============================================================================

0 commit comments

Comments
 (0)