@@ -246,6 +246,120 @@ struct PackOpTiling
       return failure();
     return tilingResult.value();
   }
+
+  /// Method to return the position of the iteration domain tile computed by
+  /// the tiled operation. In the current `tensor.pack` context, the
+  /// `resultOffsets` and `resultSizes` only cover the outer dimensions.
+  LogicalResult getIterationDomainTileFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &resultOffsets,
+      SmallVectorImpl<OpFoldResult> &resultSizes) const {
+    if (operandNumber != 0)
+      return failure();
+
+    auto packOp = cast<PackOp>(op);
+    // It is not trivial to infer the dest tile from the source tile if
+    // `packOp` has padding semantics.
+    if (packOp.getPaddingValue())
+      return failure();
+
+    Location loc = packOp.getLoc();
+
+    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
+    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+        packOp.getDimAndTileMapping();
+    for (auto dim : packOp.getOuterDimsPerm()) {
+      if (dimAndTileMapping.count(dim)) {
+        FailureOr<int64_t> cstSize =
+            ValueBoundsConstraintSet::computeConstantBound(
+                presburger::BoundType::UB, sizes[dim],
+                /*stopCondition=*/nullptr, /*closedUB=*/true);
+        std::optional<int64_t> cstInnerSize =
+            getConstantIntValue(dimAndTileMapping[dim]);
+        // Currently, fusing `packOp` as a consumer only expects perfect
+        // tiling, because even without padding semantics `packOp` may still
+        // yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
+        // where the `tileSize` on the operand of `packOp` is 5, which is not
+        // a multiple of the `innerTile` (= 6) of `packOp`. As a result:
+        // 1. The first slice is extracted from (0) to (4) and inserted into
+        //    (0,0)~(0,4) in the first row.
+        // 2. The second slice is extracted from (5) to (9) and SHOULD BE
+        //    inserted into two rows of different lengths: the first row at
+        //    (0,5) and the second row at (1,0)~(1,3). It is hard to
+        //    coordinate them, so the constraint below bypasses such cases
+        //    for now. In other words, tiling with a consumer is only
+        //    supported when the tile size of the producer is a multiple of
+        //    the inner tile size for the packed dimensions.
+        if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
+          return failure();
+        }
+
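+        // Map the operand-tile offset/size on this dimension to the packed
+        // outer dimension: the outer offset is offsets[dim] floor-divided by
+        // the inner tile size, and the outer size is sizes[dim] ceil-divided
+        // by it.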
+        using AV = affine::AffineValueExpr;
+        affine::AffineBuilder ab(b, loc);
+        AffineExpr dim0, sym;
+        bindDims(b.getContext(), dim0);
+        bindSymbols(b.getContext(), sym);
+        auto avOffset = AV(dim0).bind(offsets[dim]);
+        auto avSize = AV(dim0).bind(sizes[dim]);
+        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
+        outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
+        outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
+      } else {
+        outerDimOffsets.push_back(offsets[dim]);
+        outerDimSizes.push_back(sizes[dim]);
+      }
+    }
+
+    resultOffsets = outerDimOffsets;
+    resultSizes = outerDimSizes;
+    return success();
+  }
+
+  /// Method to return the tiled implementation of tensor.pack as a consumer.
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    if (operandNumber != 0)
+      return failure();
+
+    auto packOp = cast<PackOp>(op);
+    Location loc = packOp.getLoc();
+
+    int64_t inputRank = packOp.getSourceRank();
+    auto oneAttr = b.getI64IntegerAttr(1);
+    SmallVector<OpFoldResult> strides(inputRank, oneAttr);
+
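+    // The operand tile supplied by the caller directly defines the source
+    // slice of the tiled pack op.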
+    SmallVector<Value> tiledOperands;
+    tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
+                                                     offsets, sizes, strides));
+
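+    // Derive the outer-dimension iteration tile from the operand tile, then
+    // map it to the position of the corresponding dest slice.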
+    SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
+    if (failed(getIterationDomainTileFromOperandTile(
+            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+            outerDimSizes)))
+      return failure();
+
+    SmallVector<OpFoldResult> outputOffsets, outputSizes;
+    if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
+                                     outputOffsets, outputSizes)))
+      return failure();
+
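+    // The dest slice keeps the inner tile dimensions whole, so extend the
+    // unit strides from the source rank up to the dest rank.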
+    strides.append(packOp.getDestRank() - inputRank, oneAttr);
+    auto extractSlice = b.create<ExtractSliceOp>(
+        loc, packOp.getDest(), outputOffsets, outputSizes, strides);
+    tiledOperands.push_back(extractSlice);
+
+    assert(!packOp.getPaddingValue() && "Expect no padding semantic");
+    for (auto tile : packOp.getInnerTiles())
+      tiledOperands.push_back(tile);
+
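+    // Recreate the pack op on the sliced operands; its result type is the
+    // type of the extracted dest slice.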
+    Operation *tiledPackOp = b.create<PackOp>(
+        loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
+
+    return TilingResult{{tiledPackOp},
+                        SmallVector<Value>(tiledPackOp->getResults())};
+  }
 };
 
 struct UnpackTileDimInfo {