@@ -186,6 +186,52 @@ func.func @matvec_mk_k_m_max(%A: vector<2x2xf32>,
186
186
return %0 : vector <2 xf32 >
187
187
}
188
188
189
+ // CHECK-LABEL: func.func @masked_matvec_mk_k_m_max(
190
+ // CHECK-SAME: %{{.*}}: vector<2x3xf32>,
191
+ // CHECK-SAME: %{{.*}}: vector<3xf32>,
192
+ // CHECK-SAME: %{{.*}}: vector<2xf32>,
193
+ // CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
194
+ // CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
195
+ // CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
196
+ // CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
197
+
198
+ // CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
199
+ // CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
200
+
201
+ // CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
202
+ // CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
203
+ func.func @masked_matvec_mk_k_m_max (%A: vector <2 x3 xf32 >,
204
+ %x: vector <3 xf32 >,
205
+ %b: vector <2 xf32 >,
206
+ %m: vector <2 x3 xi1 >) -> vector <2 xf32 > {
207
+ %0 = vector.mask %m { vector.contract #matvecmax_trait %A , %x , %b
208
+ : vector <2 x3 xf32 >, vector <3 xf32 > into vector <2 xf32 > } : vector <2 x3 xi1 > -> vector <2 xf32 >
209
+ return %0 : vector <2 xf32 >
210
+ }
211
+
212
+ // CHECK-LABEL: func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim(
213
+ // CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
214
+ // CHECK-SAME: %{{.*}}: vector<3xf32>,
215
+ // CHECK-SAME: %{{.*}}: vector<[2]xf32>,
216
+ // CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
217
+ // CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
218
+ // CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
219
+ // CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
220
+
221
+ // CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
222
+ // CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
223
+
224
+ // CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
225
+ // CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<maxf>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
226
+ func.func @masked_matvec_mk_k_m_max_scalable_parallel_dim (%A: vector <[2 ]x3 xf32 >,
227
+ %x: vector <3 xf32 >,
228
+ %b: vector <[2 ]xf32 >,
229
+ %m: vector <[2 ]x3 xi1 >) -> vector <[2 ]xf32 > {
230
+ %0 = vector.mask %m { vector.contract #matvecmax_trait %A , %x , %b
231
+ : vector <[2 ]x3 xf32 >, vector <3 xf32 > into vector <[2 ]xf32 > } : vector <[2 ]x3 xi1 > -> vector <[2 ]xf32 >
232
+ return %0 : vector <[2 ]xf32 >
233
+ }
234
+
189
235
// ============================================================================
190
236
// Matvec 2 (plain + masked + scalable)
191
237
// ============================================================================
0 commit comments