@@ -255,7 +255,15 @@ struct PackOpTiling
255
255
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
256
256
SmallVectorImpl<OpFoldResult> &resultOffsets,
257
257
SmallVectorImpl<OpFoldResult> &resultSizes) const {
258
+ if (operandNumber != 0 )
259
+ return failure ();
260
+
258
261
auto packOp = cast<PackOp>(op);
262
+ // It is not trivial to infer dest tile from source tile if `packOp` has
263
+ // padding semantic.
264
+ if (packOp.getPaddingValue ())
265
+ return failure ();
266
+
259
267
Location loc = packOp.getLoc ();
260
268
261
269
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
@@ -269,7 +277,20 @@ struct PackOpTiling
269
277
/* stopCondition=*/ nullptr , /* closedUB=*/ true );
270
278
std::optional<int64_t > cstInnerSize =
271
279
getConstantIntValue (dimAndTileMapping[dim]);
272
- // Currently only expect perfect tiling cases.
280
+ // Currently fusing `packOp` as consumer only expects perfect tiling
281
+ // scenario because even if without padding semantic, the `packOp` may
282
+ // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
283
+ // where the `tileSize` from operand of `packOp` is 5, which is not
284
+ // exactly divided by `innerTile`(=6) of `packOp`. As the result:
285
+ // 1. the first slice is extracted from (0) to (4) and inserted into
286
+ // (0,0)~(0,4) at first row.
287
+ // 2. the second slice is extracted from (5) to (9) and SHOULD BE
288
+ // respectively inserted into two rows with different length, including
289
+ // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
290
+ // them, thus adding below constraint to bypass them temporarily. In
291
+ // another word, we can only support tiling with consumer if the tile
292
+ // size for the producer is a multiple of the inner tile size for the
293
+ // packed dimensions at this moment.
273
294
if (failed (cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0 ) {
274
295
return failure ();
275
296
}
@@ -299,6 +320,9 @@ struct PackOpTiling
299
320
FailureOr<TilingResult> getTiledImplementationFromOperandTile (
300
321
Operation *op, OpBuilder &b, unsigned operandNumber,
301
322
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
323
+ if (operandNumber != 0 )
324
+ return failure ();
325
+
302
326
auto packOp = cast<PackOp>(op);
303
327
Location loc = packOp.getLoc ();
304
328
@@ -326,8 +350,7 @@ struct PackOpTiling
326
350
loc, packOp.getDest (), outputOffsets, outputSizes, strides);
327
351
tiledOperands.push_back (extractSlice);
328
352
329
- if (auto val = packOp.getPaddingValue ())
330
- tiledOperands.push_back (val);
353
+ assert (!packOp.getPaddingValue () && " Expect no padding semantic" );
331
354
for (auto tile : packOp.getInnerTiles ())
332
355
tiledOperands.push_back (tile);
333
356
0 commit comments