@@ -1324,15 +1324,14 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1324
1324
VectorType lhsType = op.getLhsType ();
1325
1325
Value lhs = op.lhs (), rhs = op.rhs (), res = op.acc ();
1326
1326
1327
- // Set up the parallel/reduction structure in right form.
1328
- AffineExpr m, n, k;
1329
- bindDims (rewriter.getContext (), m, n, k);
1330
-
1331
1327
//
1332
1328
// Two outer parallel, one inner reduction (matmat flavor).
1333
1329
//
1334
1330
UnrolledOuterProductEmitter e (rewriter, op);
1335
1331
if (e.iters ({Par (), Par (), Red ()})) {
1332
+ // Set up the parallel/reduction structure in right form.
1333
+ AffineExpr m, n, k;
1334
+ bindDims (rewriter.getContext (), m, n, k);
1336
1335
// Classical row-major matmul: Just permute the lhs.
1337
1336
if (e.layout ({{m, k}, {k, n}, {m, n}}))
1338
1337
return e.outer_prod (e.t (lhs), rhs, res, lhsType.getDimSize (1 ));
@@ -1367,17 +1366,42 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
1367
1366
// One outer parallel, one inner reduction (matvec flavor)
1368
1367
//
1369
1368
if (e.iters ({Par (), Red ()})) {
1369
+ AffineExpr m, k;
1370
+ bindDims (rewriter.getContext (), m, k);
1371
+
1372
+ // Case mat-vec: transpose.
1373
+ if (e.layout ({{m, k}, {k}, {m}}))
1374
+ return e.outer_prod (e.t (lhs), rhs, res, lhsType.getDimSize (1 ));
1375
+ // Case mat-trans-vec: ready to go.
1376
+ if (e.layout ({{k, m}, {k}, {m}}))
1377
+ return e.outer_prod (lhs, rhs, res, lhsType.getDimSize (0 ));
1378
+ // Case vec-mat: swap and transpose.
1379
+ if (e.layout ({{k}, {m, k}, {m}}))
1380
+ return e.outer_prod (e.t (rhs), lhs, res, lhsType.getDimSize (0 ));
1381
+ // Case vec-mat-trans: swap and ready to go.
1382
+ if (e.layout ({{k}, {k, m}, {m}}))
1383
+ return e.outer_prod (rhs, lhs, res, lhsType.getDimSize (0 ));
1384
+ return failure ();
1385
+ }
1386
+
1387
+ //
1388
+ // One outer reduction, one inner parallel (tmatvec flavor)
1389
+ //
1390
+ if (e.iters ({Red (), Par ()})) {
1391
+ AffineExpr k, m;
1392
+ bindDims (rewriter.getContext (), k, m);
1393
+
1370
1394
// Case mat-vec: transpose.
1371
- if (e.layout ({{m, n }, {n }, {m}}))
1395
+ if (e.layout ({{m, k }, {k }, {m}}))
1372
1396
return e.outer_prod (e.t (lhs), rhs, res, lhsType.getDimSize (1 ));
1373
1397
// Case mat-trans-vec: ready to go.
1374
- if (e.layout ({{n , m}, {n }, {m}}))
1398
+ if (e.layout ({{k , m}, {k }, {m}}))
1375
1399
return e.outer_prod (lhs, rhs, res, lhsType.getDimSize (0 ));
1376
1400
// Case vec-mat: swap and transpose.
1377
- if (e.layout ({{n }, {m, n }, {m}}))
1401
+ if (e.layout ({{k }, {m, k }, {m}}))
1378
1402
return e.outer_prod (e.t (rhs), lhs, res, lhsType.getDimSize (0 ));
1379
1403
// Case vec-mat-trans: swap and ready to go.
1380
- if (e.layout ({{n }, {n , m}, {m}}))
1404
+ if (e.layout ({{k }, {k , m}, {m}}))
1381
1405
return e.outer_prod (rhs, lhs, res, lhsType.getDimSize (0 ));
1382
1406
return failure ();
1383
1407
}
0 commit comments