@@ -56,10 +56,21 @@ namespace {
56
56
// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
57
57
//
58
58
struct TileCheck : public AffineExprVisitor <TileCheck> {
59
- TileCheck (ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
59
+ TileCheck (ArrayRef<OpFoldResult> tileSizes, ArrayRef<int64_t > domainSizes)
60
+ : tileSizes(tileSizes), domainSizes(domainSizes) {}
60
61
61
62
void visitDimExpr (AffineDimExpr expr) {
62
- isTiled |= !isZeroIndex (tileSizes[expr.getPosition ()]);
63
+ unsigned pos = expr.getPosition ();
64
+
65
+ // There is no tile if all tile sizes correspond to the domain size
66
+ std::optional<int64_t > tileSize = getConstantIntValue (tileSizes[pos]);
67
+ if (tileSize && !domainSizes.empty ()) {
68
+ if (domainSizes[pos] == *tileSize) {
69
+ return ;
70
+ }
71
+ }
72
+
73
+ isTiled |= !isZeroIndex (tileSizes[pos]);
63
74
}
64
75
void visitAffineBinaryOpExpr (AffineBinaryOpExpr expr) {
65
76
visit (expr.getLHS ());
@@ -70,24 +81,28 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
70
81
}
71
82
bool isTiled = false ;
72
83
ArrayRef<OpFoldResult> tileSizes;
84
+ ArrayRef<int64_t > domainSizes;
73
85
};
74
86
75
87
} // namespace
76
88
77
- static bool isTiled (AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
89
+ static bool isTiled (AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
90
+ ArrayRef<int64_t > domainSizes) {
78
91
if (!expr)
79
92
return false ;
80
- TileCheck t (tileSizes);
93
+
94
+ TileCheck t (tileSizes, domainSizes);
81
95
t.visit (expr);
82
96
return t.isTiled ;
83
97
}
84
98
85
99
// Checks whether the `map varies with respect to a non-zero `tileSize`.
86
- static bool isTiled (AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
100
+ static bool isTiled (AffineMap map, ArrayRef<OpFoldResult> tileSizes,
101
+ ArrayRef<int64_t > domainSizes) {
87
102
if (!map)
88
103
return false ;
89
104
for (unsigned r = 0 ; r < map.getNumResults (); ++r)
90
- if (isTiled (map.getResult (r), tileSizes))
105
+ if (isTiled (map.getResult (r), tileSizes, domainSizes ))
91
106
return true ;
92
107
return false ;
93
108
}
@@ -556,19 +571,19 @@ Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
556
571
ArrayRef<OpFoldResult> lbs,
557
572
ArrayRef<OpFoldResult> ubs,
558
573
ArrayRef<OpFoldResult> subShapeSizes,
559
- bool omitPartialTileCheck) {
560
- SliceParameters sliceParams =
561
- computeSliceParameters (builder, loc, valueToTile, tileSizes, map, lbs,
562
- ubs, subShapeSizes, omitPartialTileCheck);
574
+ bool omitPartialTileCheck,
575
+ ArrayRef<int64_t > domainSizes) {
576
+ SliceParameters sliceParams = computeSliceParameters (
577
+ builder, loc, valueToTile, tileSizes, map, lbs, ubs, subShapeSizes,
578
+ omitPartialTileCheck, domainSizes);
563
579
return materializeTiledShape (builder, loc, valueToTile, sliceParams);
564
580
}
565
581
566
- SliceParameters
567
- computeSliceParameters (OpBuilder &builder, Location loc, Value valueToTile,
568
- ArrayRef<OpFoldResult> tileSizes, AffineMap map,
569
- ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
570
- ArrayRef<OpFoldResult> subShapeSizes,
571
- bool omitPartialTileCheck) {
582
+ SliceParameters computeSliceParameters (
583
+ OpBuilder &builder, Location loc, Value valueToTile,
584
+ ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
585
+ ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
586
+ bool omitPartialTileCheck, ArrayRef<int64_t > domainSizes) {
572
587
auto shapedType = dyn_cast<ShapedType>(valueToTile.getType ());
573
588
assert (shapedType && " only shaped types can be tiled" );
574
589
ArrayRef<int64_t > shape = shapedType.getShape ();
@@ -585,7 +600,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
585
600
// The offset & size computation below only handles the case when
586
601
// the map is monotonically increasing, i.e. the min and max values are
587
602
// attained at the lower and upper bounds of the iteration domain.
588
- if (!isTiled (m, tileSizes) || !m. isComponentWiseMonotonicallyIncreasing ( )) {
603
+ if (!isTiled (m, tileSizes, domainSizes )) {
589
604
sliceParams.offsets .push_back (builder.getIndexAttr (0 ));
590
605
OpFoldResult dim = createFoldedDimOp (builder, loc, valueToTile, r);
591
606
sliceParams.sizes .push_back (dim);
@@ -786,8 +801,9 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
786
801
// subdomains explicit.
787
802
788
803
Type operandType = opOperand.get ().getType ();
789
- if (!isTiled (map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
790
- linalgOp.isDpsInit (&opOperand))) {
804
+ if (!isTiled (map, tileSizes, linalgOp.getStaticLoopRanges ()) &&
805
+ !(isa<RankedTensorType>(operandType) &&
806
+ linalgOp.isDpsInit (&opOperand))) {
791
807
allSliceParams.push_back (std::nullopt);
792
808
LLVM_DEBUG (llvm::dbgs ()
793
809
<< " : not tiled: use shape: " << operandType << " \n " );
@@ -797,7 +813,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
797
813
798
814
allSliceParams.push_back (computeSliceParameters (
799
815
builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
800
- omitPartialTileCheck));
816
+ omitPartialTileCheck, linalgOp. getStaticLoopRanges () ));
801
817
}
802
818
803
819
return allSliceParams;
0 commit comments