@@ -240,4 +240,134 @@ TEST(isRowMajorBatchMatmul, FirstInputSwapped) {
240
240
EXPECT_THAT (maps, Not (Truly (isRowMajorBatchMatmul)));
241
241
}
242
242
243
+ TEST (isVecmat, Simple) {
244
+ MLIRContext context;
245
+
246
+ AffineExpr k, n;
247
+ bindDims (&context, k, n);
248
+ auto mapA = AffineMapAttr::get (AffineMap::get (2 , 0 , {k}, &context));
249
+ auto mapB = AffineMapAttr::get (AffineMap::get (2 , 0 , {k, n}, &context));
250
+ auto mapC = AffineMapAttr::get (AffineMap::get (2 , 0 , {n}, &context));
251
+ auto maps = ArrayAttr::get (&context, {mapA, mapB, mapC});
252
+
253
+ EXPECT_THAT (maps, Truly (isVecmat));
254
+ }
255
+
256
+ TEST (isVecmat, BindingSwapped) {
257
+ MLIRContext context;
258
+
259
+ AffineExpr k, n;
260
+ bindDims (&context, n, k); // bind in different order
261
+ auto mapA = AffineMapAttr::get (AffineMap::get (2 , 0 , {k}, &context));
262
+ auto mapB = AffineMapAttr::get (AffineMap::get (2 , 0 , {k, n}, &context));
263
+ auto mapC = AffineMapAttr::get (AffineMap::get (2 , 0 , {n}, &context));
264
+ auto maps = ArrayAttr::get (&context, {mapA, mapB, mapC});
265
+
266
+ EXPECT_THAT (maps, Truly (isVecmat));
267
+ }
268
+
269
+ TEST (isVecmat, WrongDimOrderMatrix) {
270
+ MLIRContext context;
271
+
272
+ AffineExpr k, n;
273
+ bindDims (&context, k, n);
274
+ auto mapA = AffineMapAttr::get (AffineMap::get (2 , 0 , {k}, &context));
275
+ auto mapB = AffineMapAttr::get (AffineMap::get (2 , 0 , {n, k}, &context));
276
+ auto mapC = AffineMapAttr::get (AffineMap::get (2 , 0 , {n}, &context));
277
+ auto maps = ArrayAttr::get (&context, {mapA, mapB, mapC});
278
+
279
+ EXPECT_THAT (maps, Not (Truly (isVecmat)));
280
+ }
281
+
282
+ TEST (isMatvec, Simple) {
283
+ MLIRContext context;
284
+
285
+ AffineExpr k, n;
286
+ bindDims (&context, k, n);
287
+ auto mapA = AffineMapAttr::get (AffineMap::get (2 , 0 , {n, k}, &context));
288
+ auto mapB = AffineMapAttr::get (AffineMap::get (2 , 0 , {k}, &context));
289
+ auto mapC = AffineMapAttr::get (AffineMap::get (2 , 0 , {n}, &context));
290
+ auto maps = ArrayAttr::get (&context, {mapA, mapB, mapC});
291
+
292
+ EXPECT_THAT (maps, Truly (isMatvec));
293
+ }
294
+
295
+ TEST (isMatvec, BindingSwapped) {
296
+ MLIRContext context;
297
+
298
+ AffineExpr k, n;
299
+ bindDims (&context, n, k); // bind in different order
300
+ auto mapA = AffineMapAttr::get (AffineMap::get (2 , 0 , {n, k}, &context));
301
+ auto mapB = AffineMapAttr::get (AffineMap::get (2 , 0 , {k}, &context));
302
+ auto mapC = AffineMapAttr::get (AffineMap::get (2 , 0 , {n}, &context));
303
+ auto maps = ArrayAttr::get (&context, {mapA, mapB, mapC});
304
+
305
+ EXPECT_THAT (maps, Truly (isMatvec));
306
+ }
307
+
308
+ TEST (isMatvec, WrongDimOrderMatrix) {
309
+ MLIRContext context;
310
+
311
+ AffineExpr k, n;
312
+ bindDims (&context, k, n);
313
+ auto mapA = AffineMapAttr::get (AffineMap::get (2 , 0 , {k, n}, &context));
314
+ auto mapB = AffineMapAttr::get (AffineMap::get (2 , 0 , {k}, &context));
315
+ auto mapC = AffineMapAttr::get (AffineMap::get (2 , 0 , {n}, &context));
316
+ auto maps = ArrayAttr::get (&context, {mapA, mapB, mapC});
317
+
318
+ EXPECT_THAT (maps, Not (Truly (isMatvec)));
319
+ }
320
+
321
+ TEST (isBatchMatvec, Simple) {
322
+ MLIRContext context;
323
+
324
+ AffineExpr batch, k, n;
325
+ bindDims (&context, batch, k, n);
326
+ auto mapA = AffineMapAttr::get (AffineMap::get (3 , 0 , {batch, n, k}, &context));
327
+ auto mapB = AffineMapAttr::get (AffineMap::get (3 , 0 , {batch, k}, &context));
328
+ auto mapC = AffineMapAttr::get (AffineMap::get (3 , 0 , {batch, n}, &context));
329
+ auto maps = ArrayAttr::get (&context, {mapA, mapB, mapC});
330
+
331
+ EXPECT_THAT (maps, Truly (isBatchMatvec));
332
+ }
333
+
334
+ TEST (isBatchMatvec, BindingSwapped) {
335
+ MLIRContext context;
336
+
337
+ AffineExpr batch, k, n;
338
+ bindDims (&context, batch, n, k); // bind in different order
339
+ auto mapA = AffineMapAttr::get (AffineMap::get (3 , 0 , {batch, n, k}, &context));
340
+ auto mapB = AffineMapAttr::get (AffineMap::get (3 , 0 , {batch, k}, &context));
341
+ auto mapC = AffineMapAttr::get (AffineMap::get (3 , 0 , {batch, n}, &context));
342
+ auto maps = ArrayAttr::get (&context, {mapA, mapB, mapC});
343
+
344
+ EXPECT_THAT (maps, Truly (isBatchMatvec));
345
+ }
346
+
347
+ TEST (isBatchMatvec, Matmul) {
348
+ MLIRContext context;
349
+
350
+ AffineExpr m, n, k;
351
+ bindDims (&context, m, n, k);
352
+ auto mapA = AffineMapAttr::get (AffineMap::get (3 , 0 , {m, k}, &context));
353
+ auto mapB = AffineMapAttr::get (AffineMap::get (3 , 0 , {k, n}, &context));
354
+ auto mapC = AffineMapAttr::get (AffineMap::get (3 , 0 , {m, n}, &context));
355
+ auto maps = ArrayAttr::get (&context, {mapA, mapB, mapC});
356
+
357
+ EXPECT_THAT (maps, Not (Truly (isBatchMatvec)));
358
+ }
359
+
360
+ TEST (isBatchMatvec, WrongDimOrderMatrix) {
361
+ MLIRContext context;
362
+
363
+ AffineExpr batch, k, n;
364
+ bindDims (&context, batch, k, n);
365
+ auto mapA = AffineMapAttr::get (AffineMap::get (3 , 0 , {batch, k, n}, &context));
366
+ auto mapB = AffineMapAttr::get (AffineMap::get (3 , 0 , {batch, k}, &context));
367
+ auto mapC = AffineMapAttr::get (AffineMap::get (3 , 0 , {batch, n}, &context));
368
+ auto maps = ArrayAttr::get (&context, {mapA, mapB, mapC});
369
+
370
+ EXPECT_THAT (maps, Not (Truly (isBatchMatvec)));
371
+ }
372
+
243
373
} // namespace
0 commit comments