Commit 8cc616b

[mlir] Clamp UnPackOp tiling sizes from operand tile (#112429)
The `getIterationDomainTileFromOperandTile` implementation for tensor.unpack did not clamp sizes when the unpack op had extract_slice semantics, i.e., when the destination holds fewer elements than the packed source. This PR fixes the bug.

The PR also makes a minor change to `tileAndFuseConsumerOfSlice`. When replacing DPS inits, the iteration domain is needed, and it was computed from the tiled version of the operation after the initial tiling transformation. This can result in some extra indexing computation, so the PR changes it to use the original, full-sized cloned consumer op instead.

---------

Signed-off-by: Max Dawkins <[email protected]>
Parent: 1884ffc
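For illustration: tensor.unpack has extract_slice semantics when the result holds fewer elements than the packed source, so the final tile is only partially extracted. A minimal sketch of such an op, using the shapes from the new regression test below (%src and %dest are placeholder SSA names):

// 64 rows of 32 packed elements (2048 total) unpack into only 2047 result
// elements, so the last inner tile is truncated. Tiling such an op must
// clamp each iteration-domain tile size to what remains of the result.
%unpack = tensor.unpack %src outer_dims_perm = [0] inner_dims_pos = [0]
    inner_tiles = [32] into %dest : tensor<64x32xf32> -> tensor<2047xf32>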

File tree

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

3 files changed: +95 -14 lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 7 additions & 3 deletions

@@ -1996,13 +1996,17 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         candidateSliceOp, "containingOp's result yield with stride");
   }
 
-  // 10. Try to get iter domain position from input position.
+  // 10. Try to get iter domain position from input position. Use
+  // clonedConsumerOp instead of tiledConsumerOp, because the iteration domain
+  // may require index computation based on the result size. The sizes and
+  // offsets should be the same either way, but using tiledConsumerOp could
+  // lead to some chained unnecessary extra index computation.
   SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
-  if (failed(tiledConsumerOp.getIterationDomainTileFromOperandTile(
+  if (failed(clonedConsumerOp.getIterationDomainTileFromOperandTile(
           rewriter, operandNumber, offsets, sizes, iterDomainOffsets,
           iterDomainSizes))) {
     return rewriter.notifyMatchFailure(
-        tiledConsumerOp,
+        clonedConsumerOp,
         "can't get iter domain position from input position");
   }
 

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 19 additions & 5 deletions

@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
@@ -621,6 +622,12 @@ struct UnPackOpTiling
       SmallVectorImpl<OpFoldResult> &resultOffsets,
       SmallVectorImpl<OpFoldResult> &resultSizes) const {
     auto unPackOp = cast<UnPackOp>(op);
+    // If the operand tile is the dest, then no adjustment is needed.
+    if (operandNumber == unPackOp.getDestMutable().getOperandNumber()) {
+      resultOffsets = llvm::to_vector(offsets);
+      resultSizes = llvm::to_vector(sizes);
+      return success();
+    }
     Location loc = unPackOp.getLoc();
 
     int64_t numTiles = unPackOp.getInnerDimsPos().size();
@@ -629,6 +636,10 @@ struct UnPackOpTiling
     // The tiling is applied on interchanged dimensions. We have to undo the
     // interchange to map sizes and offsets to the original input.
     int64_t outputRank = unPackOp.getDestRank();
+    ReifiedRankedShapedTypeDims reifiedReturnShapes;
+    if (failed(reifyResultShapes(b, unPackOp, reifiedReturnShapes)))
+      return failure();
+    SmallVector<OpFoldResult> outputMixedSizes = reifiedReturnShapes.front();
     SmallVector<OpFoldResult> origOffsets(destOffsets);
     SmallVector<OpFoldResult> origSizes(destSizes);
     applyPermToRange(origOffsets, origSizes,
@@ -640,18 +651,21 @@ struct UnPackOpTiling
     for (auto dim : llvm::seq<int64_t>(0, outputRank)) {
       using AV = affine::AffineValueExpr;
       affine::AffineBuilder ab(b, loc);
-      AffineExpr dim0, dim1, sym;
+      AffineExpr dim0, dim1, sym0;
       bindDims(b.getContext(), dim0, dim1);
-      bindSymbols(b.getContext(), sym);
+      bindSymbols(b.getContext(), sym0);
       if (dimAndTileMapping.count(dim)) {
         // If the data dimension is tiled, the i-th index is the product of
         // offset_i and tile_i, and the i-th size is the product of sizes_i and
-        // tile_i.
+        // tile_i. The sizes must be clamped to the sizes of the unpack result.
         auto avOffset = AV(dim0).bind(origOffsets[dim]);
         auto avSize = AV(dim0).bind(origSizes[dim]);
-        auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]);
+        auto avTileSize = AV(sym0).bind(dimAndTileMapping[dim]);
+        auto avResultSize = AV(dim0).bind(outputMixedSizes[dim]);
         resultOffsets.push_back(ab.mul(avOffset, avTileSize));
-        resultSizes.push_back(ab.mul(avSize, avTileSize));
+        auto avResultOffset = AV(dim1).bind(resultOffsets.back());
+        resultSizes.push_back(ab.min({ab.mul(avSize, avTileSize),
+                                      ab.sub(avResultSize, avResultOffset)}));
       } else {
         resultOffsets.push_back(origOffsets[dim]);
         resultSizes.push_back(origSizes[dim]);
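With the clamp in place, each tiled size lowers to an affine.min between the full tile extent and the remaining extent of the result. A small standalone MLIR sketch of the indexing this produces, assuming the 2047-element result and inner tile of 32 from the new test below (the function name is illustrative; the maps match the test's CHECK lines):

#offset_map = affine_map<(d0) -> (d0 * 32)>
#size_map = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
func.func @clamped_unpack_tile(%iv: index) -> (index, index) {
  // offset_i = iv * 32
  %offset = affine.apply #offset_map(%iv)
  // size_i = min(32 * 32, 2047 - iv * 32):
  //   iv = 0  -> min(1024, 2047) = 1024 (full tile)
  //   iv = 32 -> min(1024, 1023) = 1023 (clamped at the result boundary)
  %size = affine.min #size_map(%iv)
  return %offset, %size : index, index
}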

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 69 additions & 6 deletions

@@ -265,7 +265,7 @@ module {
     %c4 = arith.constant 4 : index
     %c64 = arith.constant 64 : index
     %c0 = arith.constant 0 : index
-    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+    %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
       %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
       %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
       ^bb0(%in: f32, %in_16: f32, %out: f32):
@@ -292,26 +292,89 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
-// CHECK: #[[UNPACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2048)>
 // CHECK: func.func @fuse_unpack_consumer_into_scf_forall(
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
 // CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2048xf32>
-// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
+// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
+// CHECK-SAME: {
+// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
+// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
+// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+// CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
+// CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
+// CHECK-SAME: into %[[TILED_UNPACK_DEST]]
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[FINAL_RESULT]]#1 :
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module {
+  func.func @fuse_unaligned_unpack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<2047xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1 = scf.forall (%arg3, %arg4) = (0, 0) to (64, 32) step (32, 32) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) {
+      %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+      ^bb0(%in: f32, %in_16: f32, %out: f32):
+        %13 = arith.mulf %in, %in_16 : f32
+        %14 = arith.addf %out, %13 : f32
+        linalg.yield %14 : f32
+      } -> tensor<32x32xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32>
+      }
+    }
+    %output = tensor.empty() : tensor<2047xf32>
+    %unpack = tensor.unpack %1 outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32] into %output : tensor<64x32xf32> -> tensor<2047xf32>
+    return %unpack : tensor<2047xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_consumer %slice_op
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-DAG: #[[UNPACK_RESULT_OFFSET_MAP:.*]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[UNPACK_RESULT_SIZE_MAP:.*]] = affine_map<(d0) -> (1024, d0 * -32 + 2047)>
+// CHECK: func.func @fuse_unaligned_unpack_consumer_into_scf_forall(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>)
+// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<2047xf32>
+// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) = (0, 0) to (64, 32) step (32, 32)
 // CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[UNPACK_OUT_ARG:.*]] = %[[OUT_INIT]])
 // CHECK-SAME: {
 // CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
 // CHECK: %[[GENERIC_OUT:.*]] = linalg.generic
 // CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] :
-// CHECK: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_MAP]](%[[IV1]])
-// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
+// CHECK-DAG: %[[UNPACK_RESULT_OFFSET:.*]] = affine.apply #[[UNPACK_RESULT_OFFSET_MAP]](%[[IV1]])
+// CHECK-DAG: %[[UNPACK_RESULT_SIZE:.*]] = affine.min #[[UNPACK_RESULT_SIZE_MAP]](%[[IV1]])
+// CHECK: %[[TILED_UNPACK_DEST:.*]] = tensor.extract_slice %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
 // CHECK: %[[TILED_UNPACK_OUT:.*]] = tensor.unpack %[[GENERIC_OUT]]
 // CHECK-SAME: outer_dims_perm = [0] inner_dims_pos = [0] inner_tiles = [32]
 // CHECK-SAME: into %[[TILED_UNPACK_DEST]]
 // CHECK: scf.forall.in_parallel {
 // CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1]
-// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [1024] [1]
+// CHECK: tensor.parallel_insert_slice %[[TILED_UNPACK_OUT]] into %[[UNPACK_OUT_ARG]][%[[UNPACK_RESULT_OFFSET]]] [%[[UNPACK_RESULT_SIZE]]] [1]
 // CHECK: }
 // CHECK: }
 // CHECK: return %[[FINAL_RESULT]]#1 :
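The diff does not include the file's RUN line. Tests in this directory are driven through the transform interpreter, so a plausible invocation (an assumption, not quoted from this diff) would be:

// RUN (assumed, not shown above): mlir-opt --transform-interpreter --split-input-file %s | FileCheck %s

Note that transform.test.fuse_consumer is a test-only transform op, so it is only available in an mlir-opt build with the test transform dialect extensions registered.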
