Skip to content

Commit f06563a

Browse files
authored
[mlir][tensor] Add consumer fusion for tensor.pack op. (#103715)
Add missing `getIterationDomainTileFromOperandTile` and `getTiledImplementationFromOperandTile` to `tensor.pack` and enable fusing it as a consumer. NOTE that, it only expects perfect tiling scenario without padding semantic currently.
1 parent be8ee09 commit f06563a

File tree

2 files changed

+173
-0
lines changed

2 files changed

+173
-0
lines changed

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,120 @@ struct PackOpTiling
246246
return failure();
247247
return tilingResult.value();
248248
}
249+
250+
/// Method to return the position of iteration domain tile computed by the
251+
/// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
252+
/// `resultSizes` only cover outer dimensions.
253+
LogicalResult getIterationDomainTileFromOperandTile(
254+
Operation *op, OpBuilder &b, unsigned operandNumber,
255+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
256+
SmallVectorImpl<OpFoldResult> &resultOffsets,
257+
SmallVectorImpl<OpFoldResult> &resultSizes) const {
258+
if (operandNumber != 0)
259+
return failure();
260+
261+
auto packOp = cast<PackOp>(op);
262+
// It is not trivial to infer dest tile from source tile if `packOp` has
263+
// padding semantic.
264+
if (packOp.getPaddingValue())
265+
return failure();
266+
267+
Location loc = packOp.getLoc();
268+
269+
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
270+
DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
271+
packOp.getDimAndTileMapping();
272+
for (auto dim : packOp.getOuterDimsPerm()) {
273+
if (dimAndTileMapping.count(dim)) {
274+
FailureOr<int64_t> cstSize =
275+
ValueBoundsConstraintSet::computeConstantBound(
276+
presburger::BoundType::UB, sizes[dim],
277+
/*stopCondition=*/nullptr, /*closedUB=*/true);
278+
std::optional<int64_t> cstInnerSize =
279+
getConstantIntValue(dimAndTileMapping[dim]);
280+
// Currently fusing `packOp` as consumer only expects perfect tiling
281+
// scenario because even if without padding semantic, the `packOp` may
282+
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
283+
// where the `tileSize` from operand of `packOp` is 5, which is not
284+
// exactly divided by `innerTile`(=6) of `packOp`. As the result:
285+
// 1. the first slice is extracted from (0) to (4) and inserted into
286+
// (0,0)~(0,4) at first row.
287+
// 2. the second slice is extracted from (5) to (9) and SHOULD BE
288+
// respectively inserted into two rows with different length, including
289+
// first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate
290+
// them, thus adding below constraint to bypass them temporarily. In
291+
// another word, we can only support tiling with consumer if the tile
292+
// size for the producer is a multiple of the inner tile size for the
293+
// packed dimensions at this moment.
294+
if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) {
295+
return failure();
296+
}
297+
298+
using AV = affine::AffineValueExpr;
299+
affine::AffineBuilder ab(b, loc);
300+
AffineExpr dim0, sym;
301+
bindDims(b.getContext(), dim0);
302+
bindSymbols(b.getContext(), sym);
303+
auto avOffset = AV(dim0).bind(offsets[dim]);
304+
auto avSize = AV(dim0).bind(sizes[dim]);
305+
auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
306+
outerDimOffsets.push_back(ab.floor(avOffset, avTileSize));
307+
outerDimSizes.push_back(ab.ceil(avSize, avTileSize));
308+
} else {
309+
outerDimOffsets.push_back(offsets[dim]);
310+
outerDimSizes.push_back(sizes[dim]);
311+
}
312+
}
313+
314+
resultOffsets = outerDimOffsets;
315+
resultSizes = outerDimSizes;
316+
return success();
317+
}
318+
319+
/// Method to return the tiled implementation of tensor.pack as a consumer.
320+
FailureOr<TilingResult> getTiledImplementationFromOperandTile(
321+
Operation *op, OpBuilder &b, unsigned operandNumber,
322+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
323+
if (operandNumber != 0)
324+
return failure();
325+
326+
auto packOp = cast<PackOp>(op);
327+
Location loc = packOp.getLoc();
328+
329+
int64_t inputRank = packOp.getSourceRank();
330+
auto oneAttr = b.getI64IntegerAttr(1);
331+
SmallVector<OpFoldResult> strides(inputRank, oneAttr);
332+
333+
SmallVector<Value> tiledOperands;
334+
tiledOperands.push_back(b.create<ExtractSliceOp>(loc, packOp.getSource(),
335+
offsets, sizes, strides));
336+
337+
SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
338+
if (failed(getIterationDomainTileFromOperandTile(
339+
op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
340+
outerDimSizes)))
341+
return failure();
342+
343+
SmallVector<OpFoldResult> outputOffsets, outputSizes;
344+
if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes,
345+
outputOffsets, outputSizes)))
346+
return failure();
347+
348+
strides.append(packOp.getDestRank() - inputRank, oneAttr);
349+
auto extractSlice = b.create<ExtractSliceOp>(
350+
loc, packOp.getDest(), outputOffsets, outputSizes, strides);
351+
tiledOperands.push_back(extractSlice);
352+
353+
assert(!packOp.getPaddingValue() && "Expect no padding semantic");
354+
for (auto tile : packOp.getInnerTiles())
355+
tiledOperands.push_back(tile);
356+
357+
Operation *tiledPackOp = b.create<PackOp>(
358+
loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs());
359+
360+
return TilingResult{{tiledPackOp},
361+
SmallVector<Value>(tiledPackOp->getResults())};
362+
}
249363
};
250364

251365
struct UnpackTileDimInfo {

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,62 @@ module attributes {transform.with_named_sequence} {
315315
// CHECK: }
316316
// CHECK: }
317317
// CHECK: return %[[FINAL_RESULT]]#1 :
318+
319+
// -----
320+
321+
#map = affine_map<(d0, d1) -> (d0, d1)>
322+
module {
323+
func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> {
324+
%c4 = arith.constant 4 : index
325+
%c64 = arith.constant 64 : index
326+
%c0 = arith.constant 0 : index
327+
%1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
328+
%extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
329+
%3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
330+
^bb0(%in: f32, %in_16: f32, %out: f32):
331+
%13 = arith.mulf %in, %in_16 : f32
332+
%14 = arith.addf %out, %13 : f32
333+
linalg.yield %14 : f32
334+
} -> tensor<32x32xf32>
335+
scf.forall.in_parallel {
336+
tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
337+
}
338+
}
339+
%output = tensor.empty() : tensor<4x32x16xf32>
340+
%pack = tensor.pack %1 outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32>
341+
return %pack : tensor<4x32x16xf32>
342+
}
343+
}
344+
345+
module attributes {transform.with_named_sequence} {
346+
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
347+
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
348+
: (!transform.any_op) -> !transform.any_op
349+
%a, %b = transform.test.fuse_consumer %slice_op
350+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
351+
transform.yield
352+
}
353+
}
354+
// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)>
355+
// CHECK: func.func @fuse_pack_consumer_into_scf_forall(
356+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
357+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
358+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
359+
// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32>
360+
// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
361+
// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]])
362+
// CHECK-SAME: {
363+
// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
364+
// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
365+
// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
366+
// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]])
367+
// CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
368+
// CHECK: %[[TILED_PACK_OUT:.*]] = tensor.pack %[[GENERIC_OUT]]
369+
// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16]
370+
// CHECK-SAME: into %[[TILED_PACK_DEST]]
371+
// CHECK: scf.forall.in_parallel {
372+
// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1]
373+
// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
374+
// CHECK: }
375+
// CHECK: }
376+
// CHECK: return %[[FINAL_RESULT]]#1 :

0 commit comments

Comments
 (0)