@@ -1390,16 +1390,25 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
1390
1390
// Convolution vectorization patterns
1391
1391
// ===----------------------------------------------------------------------===//
1392
1392
namespace {
1393
- // / Generate a vector implementation for:
1393
+ /// Generate a vector implementation for either:
1394
1394
// / ```
1395
1395
// / Op def: ( n, w, c, kw, f )
1396
1396
// / Iters: ({Par(), Par(), Par(), Red(), Red()})
1397
1397
// / Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
1398
1398
// / ```
1399
1399
// / kw is unrolled, w is unrolled iff dilationW > 1.
1400
- struct Conv1D_NWC_WCF_Generator : public StructuredGenerator <LinalgOp> {
1401
- Conv1D_NWC_WCF_Generator (OpBuilder &builder, LinalgOp linalgOp, int strideW,
1402
- int dilationW)
1400
+ // /
1401
+ // / or
1402
+ // /
1403
+ // / ```
1404
+ // / Op def: ( n, w, c, kw )
1405
+ // / Iters: ({Par(), Par(), Par(), Red()})
1406
+ // / Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
1407
+ // / ```
1408
+ // / kw is unrolled, w is unrolled iff dilationW > 1.
1409
+ struct Conv1D_NWC_Generator : public StructuredGenerator <LinalgOp> {
1410
+ Conv1D_NWC_Generator (OpBuilder &builder, LinalgOp linalgOp, int strideW,
1411
+ int dilationW)
1403
1412
: StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false ),
1404
1413
strideW (strideW), dilationW(dilationW) {
1405
1414
// Determine whether `linalgOp` can be generated with this generator
@@ -1413,7 +1422,8 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
1413
1422
resShapedType = resShaped.getType ().dyn_cast <ShapedType>();
1414
1423
if (!lhsShapedType || !rhsShapedType || !resShapedType)
1415
1424
return ;
1416
- if (lhsShapedType.getRank () != 3 || rhsShapedType.getRank () != 3 ||
1425
+ if (lhsShapedType.getRank () != 3 ||
1426
+ (rhsShapedType.getRank () != 2 && rhsShapedType.getRank () != 3 ) ||
1417
1427
resShapedType.getRank () != 3 )
1418
1428
return ;
1419
1429
@@ -1553,12 +1563,130 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
1553
1563
/* iteratorTypes=*/ ArrayRef<StringRef>{par, par, par, red});
1554
1564
}
1555
1565
1566
+ // / Generate a vector implementation for:
1567
+ // / ```
1568
+ // / Op def: ( n, w, c, kw)
1569
+ // / Iters: ({Par(), Par(), Par(), Red()})
1570
+ // / Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
1571
+ // / ```
1572
+ // / kw is always unrolled.
1573
+ // / TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
1574
+ FailureOr<Operation *> dilated_conv () {
1575
+ if (!valid)
1576
+ return failure ();
1577
+
1578
+ int nSize = lhsShapedType.getShape ()[0 ];
1579
+ int wSize = resShapedType.getShape ()[1 ];
1580
+ int cSize = lhsShapedType.getShape ()[2 ];
1581
+ int kwSize = rhsShapedType.getShape ()[0 ];
1582
+
1583
+ vector::TransferWriteOp write;
1584
+ Value zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
1585
+
1586
+ // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
1587
+ // When strideW == 1, we can batch the contiguous loads and avoid unrolling
1588
+ int64_t wSizeStep = strideW == 1 ? wSize : 1 ;
1589
+
1590
+ Type lhsEltType = lhsShapedType.getElementType ();
1591
+ Type rhsEltType = rhsShapedType.getElementType ();
1592
+ Type resEltType = resShapedType.getElementType ();
1593
+ VectorType lhsType = VectorType::get (
1594
+ {nSize, (wSize - 1 ) * strideW + 1 + (kwSize - 1 ) * dilationW + 1 ,
1595
+ cSize},
1596
+ lhsEltType);
1597
+ VectorType rhsType = VectorType::get ({kwSize, cSize}, rhsEltType);
1598
+ VectorType resType = VectorType::get ({nSize, wSize, cSize}, resEltType);
1599
+
1600
+ // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
1601
+ Value lhs = builder.create <vector::TransferReadOp>(
1602
+ loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
1603
+ // Read rhs slice of size {kw, c} @ [0, 0].
1604
+ Value rhs = builder.create <vector::TransferReadOp>(loc, rhsType, rhsShaped,
1605
+ ValueRange{zero, zero});
1606
+ // Read res slice of size {n, w, c} @ [0, 0, 0].
1607
+ Value res = builder.create <vector::TransferReadOp>(
1608
+ loc, resType, resShaped, ValueRange{zero, zero, zero});
1609
+
1610
+ // ===------------------------------------------------------------------===//
1611
+ // Begin vector-only rewrite part
1612
+ // ===------------------------------------------------------------------===//
1613
+ // Unroll along kw and read slices of lhs and rhs.
1614
+ SmallVector<Value> lhsVals, rhsVals, resVals;
1615
+ for (int64_t kw = 0 ; kw < kwSize; ++kw) {
1616
+ // Extract rhs slice of size {c} @ [kw].
1617
+ rhsVals.push_back (builder.create <vector::ExtractOp>(
1618
+ loc, rhs, /* offsets=*/ ArrayRef<int64_t >{kw}));
1619
+
1620
+ for (int64_t w = 0 ; w < wSize; w += wSizeStep) {
1621
+ // Extract lhs slice of size {n, wSizeStep, c}
1622
+ // @ [0, sw * w + dw * kw, 0].
1623
+ lhsVals.push_back (builder.create <vector::ExtractStridedSliceOp>(
1624
+ loc, lhs,
1625
+ /* offsets=*/ ArrayRef<int64_t >{0 , w * strideW + kw * dilationW, 0 },
1626
+ /* sizes=*/ ArrayRef<int64_t >{nSize, wSizeStep, cSize},
1627
+ /* strides=*/ ArrayRef<int64_t >{1 , 1 , 1 }));
1628
+
1629
+ // This does not depend on kw.
1630
+ if (kw == 0 ) {
1631
+ // Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
1632
+ resVals.push_back (builder.create <vector::ExtractStridedSliceOp>(
1633
+ loc, res,
1634
+ /* offsets=*/ ArrayRef<int64_t >{0 , w, 0 },
1635
+ /* sizes=*/ ArrayRef<int64_t >{nSize, wSizeStep, cSize},
1636
+ /* strides=*/ ArrayRef<int64_t >{1 , 1 , 1 }));
1637
+ }
1638
+ }
1639
+ }
1640
+
1641
+ auto linearIndex = [&](int64_t kw, int64_t w) {
1642
+ return kw * (wSize / wSizeStep) + w;
1643
+ };
1644
+
1645
+ // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
1646
+ for (int64_t kw = 0 ; kw < kwSize; ++kw) {
1647
+ for (int64_t w = 0 ; w < wSize; w += wSizeStep) {
1648
+ resVals[w] = dilatedConv1dSliceAsContraction (
1649
+ builder, loc, lhsVals[linearIndex (kw, w)], rhsVals[kw], resVals[w]);
1650
+ }
1651
+ }
1652
+
1653
+ // Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
1654
+ // This does not depend on kw.
1655
+ for (int64_t w = 0 ; w < wSize; w += wSizeStep) {
1656
+ res = builder.create <vector::InsertStridedSliceOp>(
1657
+ loc, resVals[w], res,
1658
+ /* offsets=*/ ArrayRef<int64_t >{0 , w, 0 },
1659
+ /* strides=*/ ArrayRef<int64_t >{1 , 1 , 1 });
1660
+ }
1661
+ // ===------------------------------------------------------------------===//
1662
+ // End vector-only rewrite part
1663
+ // ===------------------------------------------------------------------===//
1664
+
1665
+ // Write back res slice of size {n, w, c} @ [0, 0, 0].
1666
+ return builder
1667
+ .create <vector::TransferWriteOp>(loc, res, resShaped,
1668
+ ValueRange{zero, zero, zero})
1669
+ .getOperation ();
1670
+ }
1671
+
1672
+ // Create a contraction: lhs{n, w, c} * rhs{c} -> res{n, w, c}
1673
+ vector::ContractionOp dilatedConv1dSliceAsContraction (OpBuilder &b,
1674
+ Location loc, Value lhs,
1675
+ Value rhs, Value res) {
1676
+ StringRef par = Par ().strRef , red = Red ().strRef ;
1677
+ AffineExpr n, w, c;
1678
+ bindDims (ctx, n, w, c);
1679
+ return builder.create <vector::ContractionOp>(
1680
+ loc, lhs, rhs, res,
1681
+ /* indexingMaps=*/ MapList{{n, w, c}, {c}, {n, w, c}},
1682
+ /* iteratorTypes=*/ ArrayRef<StringRef>{par, par, red});
1683
+ }
1684
+
1556
1685
// / Entry point that transposes into the common form:
1557
1686
// / {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
1558
1687
FailureOr<Operation *> generateConv () {
1559
1688
AffineExpr n, w, f, kw, c;
1560
1689
bindDims (ctx, n, w, f, kw, c);
1561
-
1562
1690
if (!iters ({Par (), Par (), Par (), Red (), Red ()}))
1563
1691
return failure ();
1564
1692
@@ -1570,6 +1698,22 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
1570
1698
return failure ();
1571
1699
}
1572
1700
1701
+ // / Entry point that transposes into the common form:
1702
+ // / {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
1703
+ FailureOr<Operation *> generateDilatedConv () {
1704
+ AffineExpr n, w, c, kw;
1705
+ bindDims (ctx, n, w, c, kw);
1706
+ if (!iters ({Par (), Par (), Par (), Red ()}))
1707
+ return failure ();
1708
+
1709
+ // No transposition needed.
1710
+ if (layout ({/* lhsIndex*/ {n, strideW * w + dilationW * kw, c},
1711
+ /* rhsIndex*/ {kw, c},
1712
+ /* resIndex*/ {n, w, c}}))
1713
+ return dilated_conv ();
1714
+ return failure ();
1715
+ }
1716
+
1573
1717
private:
1574
1718
bool valid;
1575
1719
int strideW, dilationW;
@@ -1588,8 +1732,11 @@ vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
1588
1732
auto stride = strides ? *strides.getValues <uint64_t >().begin () : 1 ;
1589
1733
auto dilation = dilations ? *dilations.getValues <uint64_t >().begin () : 1 ;
1590
1734
LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation ());
1591
- Conv1D_NWC_WCF_Generator e (b, linalgOp, stride, dilation);
1592
- return e.generateConv ();
1735
+ Conv1D_NWC_Generator e (b, linalgOp, stride, dilation);
1736
+ auto res = e.generateConv ();
1737
+ if (succeeded (res))
1738
+ return res;
1739
+ return e.generateDilatedConv ();
1593
1740
}
1594
1741
1595
1742
struct VectorizeConvolution
0 commit comments