Commit 66f84c8
[mlir][tensor] Extend the logic to generalise tensor.pack (#109815)
Extends the logic to generalise tensor.pack (into e.g. tensor.pad + tensor.transpose) so that it also works when one of the inner tile sizes is scalable (i.e. a multiple of `vector.vscale`). For example:

```mlir
%c8 = arith.constant 8 : index
%vscale = vector.vscale
%c8_vscale = arith.muli %vscale, %c8 : index
%0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
```

is generalised as:

```mlir
%c8 = arith.constant 8 : index
%vscale = vector.vscale
%c8_vscale = arith.muli %vscale, %c8 : index
%0 = affine.apply #map()[%c8_vscale, %c5]
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg3: index, %arg4: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
```

At the Tensor level, scalability is modelled with dynamic shapes, so this change essentially extends the relevant logic to also handle dynamic shapes.
1 parent 37e717e commit 66f84c8
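The pattern extended here is `GeneralizeOuterUnitDimsPackOpPattern`. As a minimal sketch of how it could be exercised from C++ (this driver is an assumption for illustration and is not part of the commit; only the pattern class itself comes from the patch):

```cpp
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: rewrite tensor.pack ops (with all-unit outer dims) in
// `func` into tensor.pad + linalg.transpose + slice ops via a greedy rewrite.
static LogicalResult generalizePackOps(func::FuncOp func) {
  RewritePatternSet patterns(func.getContext());
  patterns.add<linalg::GeneralizeOuterUnitDimsPackOpPattern>(func.getContext());
  return applyPatternsAndFoldGreedily(func, std::move(patterns));
}
```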

File tree

4 files changed: +161 lines added, -46 lines removed


mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 16 additions & 6 deletions
```diff
@@ -14,12 +14,22 @@
 namespace mlir {
 namespace tensor {
 
-// Return a PadOp that pads `source` to `type` size where the static
-// sizes are assumed to be greater than the dynamic sizes. If `type` has dynamic
-// dimensions the padding width is set to zero. The op performs "high" padding
-// (i.e. it adds trailing padding values until the desired size is met).
-PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
-                      bool nofold, Location loc, OpBuilder &builder);
+// Return a PadOp that pads `source` to `resType` size. The op performs "high"
+// padding, i.e. it adds trailing padding values until the desired size is met.
+// Output sizes are assumed to be greater than the input sizes. The padding
+// width is calculated as: resDim - sourceDim.
+//
+// Handling static sizes is trivial. Dynamic dimensions are trickier (*):
+//  1. dynamic input sizes are extracted from `source`
+//  2. for dynamic output dims, there are two options:
+//    2.1 all output dynamic dim sizes are specified in `dynOutDim`,
+//    2.2 `dynOutDim` is empty and the corresponding padding width is set to 0.
+//
+// (*) Note that `resType` is just a shape and it only encodes the actual sizes
+// for _static_ dimensions.
+PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad,
+                      bool nofold, Location loc, OpBuilder &builder,
+                      SmallVector<Value> dynOutDim = {});
 
 // Creates dim ops for each dynamic dimension of the ranked tensor argument and
 // returns these as values.
```
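As a usage illustration of the extended signature (a hedged sketch, not code from this commit): to pad a `tensor<5x1xf32>` source up to `tensor<?x2xf32>` where the dynamic output size comes from an SSA value, the caller passes that value via `dynOutDim`. The names `builder`, `loc`, `source`, `padValue` and `dynSize` are assumed to be in scope.

```cpp
// Hypothetical call site: the first result dim is dynamic (e.g. 8 * vscale),
// so its actual size is passed in `dynOutDim`; the second dim (2) is static.
auto resType = RankedTensorType::get({ShapedType::kDynamic, 2},
                                     builder.getF32Type());
tensor::PadOp padOp = tensor::createPadHighOp(
    resType, source, /*pad=*/padValue, /*nofold=*/false, loc, builder,
    /*dynOutDim=*/{dynSize});
```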

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 56 additions & 26 deletions
```diff
@@ -1021,8 +1021,11 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   return success();
 }
 
-/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
-/// source directly. The method assumes that the `packOp` has static shapes.
+/// If padding value is set, returns a tensor.pad Op for the source tensor,
+/// with the output shape matching the output of `packOp`. Otherwise, returns
+/// the source directly.
+///
+/// This method assumes that all outer dims for this pack Op are 1.
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            tensor::PackOp packOp) {
   Value input = packOp.getSource();
@@ -1038,26 +1041,48 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
 
-  SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
-  for (int64_t dim = 0; dim < inputRank; ++dim) {
-    int64_t size = inputType.getDimSize(dim);
-    if (!tileAndPosMapping.count(dim)) {
-      paddedShape.push_back(size);
+
+  // The sizes of dynamic tiles
+  SmallVector<Value> dynamicTileSizes;
+
+  // Collect dims for the padded shape.
+  SmallVector<int64_t> paddedShape;
+  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
+    // 1. Non-tiled outer dims.
+    // These dims should be 1 and we simply preserve them.
+    if (!tileAndPosMapping.count(dimIdx)) {
+      int64_t inputDimSize = inputType.getDimSize(dimIdx);
+      assert(inputDimSize == 1 &&
+             "with all outer dims == 1, this non-tiled input dim should be 1!");
+      paddedShape.push_back(inputDimSize);
+      continue;
+    }
+
+    // 2. Tiled outer dims
+    // As all outer dims == 1, it is safe to use the tile size for the padded
+    // shape.
+    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
+
+    // 2.1 Static tile sizes
+    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
+    if (cstTileSize.has_value()) {
+      paddedShape.push_back(cstTileSize.value());
       continue;
     }
 
-    // The size is less than or equal to tileSize because outer dims are all 1s.
-    std::optional<int64_t> tileSize =
-        getConstantIntValue(tileAndPosMapping.lookup(dim));
-    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
-    paddedShape.push_back(tileSize.value());
+    // 2.2 Dynamic tile sizes
+    paddedShape.push_back(ShapedType::kDynamic);
+
+    // Get the value that holds the dynamic size.
+    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
-                                 /*nofold=*/false, loc, builder);
+                                 /*nofold=*/false, loc, builder,
+                                 dynamicTileSizes);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1120,10 +1145,10 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::any_of(packOp.getMixedTiles(),
-                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
-    return rewriter.notifyMatchFailure(packOp,
-                                       "require inner tile sizes being static");
+  if (llvm::count_if(packOp.getMixedTiles(),
+                     [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
+    return rewriter.notifyMatchFailure(
+        packOp, "at most one dynamic tile size is supported");
   }
 
   // TODO: support the case that outer dimensions are not all 1s. A
@@ -1147,12 +1172,15 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
-  SmallVector<int64_t> readShape;
+  SmallVector<OpFoldResult> transShapeForEmpty;
+  SmallVector<int64_t> readShapeForExtractSlice;
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
-                              .value_or(ShapedType::kDynamic));
+      readShapeForExtractSlice.push_back(
+          getConstantIntValue(dimAndTileMapping[i])
+              .value_or(ShapedType::kDynamic));
       readSizes.push_back(dimAndTileMapping[i]);
+      transShapeForEmpty.push_back(dimAndTileMapping[i]);
       continue;
     }
     if (ShapedType::isDynamic(inputShape[i])) {
@@ -1161,12 +1189,14 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     } else {
       readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
     }
-    if (inputShape[i] != 1)
-      readShape.push_back(inputShape[i]);
+    if (inputShape[i] != 1) {
+      readShapeForExtractSlice.push_back(inputShape[i]);
+      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+    }
   }
 
   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShape, elemType);
+  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
 
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
       loc, readType, input, readOffsets, readSizes, readStrides);
@@ -1178,10 +1208,10 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
 
-  SmallVector<int64_t> transpShape = readShape;
-  applyPermutationToVector<int64_t>(transpShape, perm);
+  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  Value empty =
+      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
```
mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 34 additions & 14 deletions
```diff
@@ -16,28 +16,48 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR//VectorOps.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
 
-PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
+PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source,
                                     Value pad, bool nofold, Location loc,
-                                    OpBuilder &b) {
-  SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
-  SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
-  for (const auto &en : enumerate(type.getShape())) {
-    // Pad only the static dimensions of the result tensor type.
-    if (ShapedType::isDynamic(en.value()))
+                                    OpBuilder &b,
+                                    SmallVector<Value> dynOutDims) {
+
+  assert((resType.getNumDynamicDims() == dynOutDims.size()) ||
+         dynOutDims.empty() &&
+             "Either none or all output dynamic dims must be specified!");
+
+  // Init "low" and "high" padding values ("low" is kept as is, "high" is
+  // computed below).
+  SmallVector<OpFoldResult> low(resType.getRank(), b.getIndexAttr(0));
+  SmallVector<OpFoldResult> high(resType.getRank(), b.getIndexAttr(0));
+
+  size_t outDimIdx = 0;
+
+  for (const auto [idx, val] : enumerate(resType.getShape())) {
+    bool isDimDynamic = ShapedType::isDynamic(val);
+    bool updatePadHigh = !isDimDynamic || !dynOutDims.empty();
+
+    // Keep the default padding width (i.e. "0") when the output dim is dynamic
+    // and no actual output sizes have been provided.
+    if (!updatePadHigh)
       continue;
-    // Compute the padding width.
-    AffineExpr d0;
-    bindDims(b.getContext(), d0);
-    OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
-    high[en.index()] =
-        affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+
+    // Compute the padding width: resDim - sourceDim.
+    AffineExpr d0, d1;
+    bindDims(b.getContext(), d0, d1);
+    OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx);
+    OpFoldResult outDim = isDimDynamic ? OpFoldResult(dynOutDims[outDimIdx++])
+                                       : OpFoldResult(b.getIndexAttr(val));
+
+    high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
+                                                      {outDim, sourceDim});
   }
-  return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
+  return b.create<PadOp>(loc, resType, source, low, high, pad, nofold);
 }
 
 SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
```
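For intuition, a hedged sketch (not code from this commit) of the padding-width computation above for one dimension: with a static source size of 5 and a dynamic output size held in an SSA value, the composed apply folds the constant, yielding `affine_map<()[s0] -> (s0 - 5)>`, which is the `#map` the updated tests check for. `b`, `loc` and `dynTileSize` are assumed to be in scope.

```cpp
// Illustrative recomputation of the "high" padding width: outDim - sourceDim,
// with the static source size (5) folded into the resulting affine map.
AffineExpr d0, d1;
bindDims(b.getContext(), d0, d1);
OpFoldResult sourceDim = b.getIndexAttr(5);  // static source size
OpFoldResult outDim = dynTileSize;           // e.g. %c8_vscale
OpFoldResult padHigh = affine::makeComposedFoldedAffineApply(
    b, loc, d0 - d1, {outDim, sourceDim});
```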

mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir

Lines changed: 55 additions & 0 deletions
```diff
@@ -23,6 +23,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
 }
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 - 5)>
+
 // CHECK-LABEL: func.func @simple_pad_and_pack
 // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
@@ -34,6 +36,59 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK: return %[[INSERT]]
 
+/// Same as example above, but with dynamic tile size.
+
+func.func @simple_pad_and_pack_dynamic(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+
+// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[HIGH_VAL:.*]]: index) -> tensor<1x1x?x2xf32> {
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
+// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK: tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[VAL_10:.*]][0, 0] {{\[}}%[[HIGH_VAL]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+
+/// Same as example above, but with scalable tile size.
+
+/// NOTE: For this example to make sense in practice, the "?" in the output shape
+///       should effectively be 8 * vector.vscale (and that's what tensor.dim
+///       below should return).
+
+func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %c8_vscale = arith.muli %vscale, %c8 : index
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+
+// CHECK-LABEL: func.func @simple_pad_and_pack_scalable(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[VS:.+]] = vector.vscale
+// CHECK: %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
+// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]]]
+// CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK: tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+
 // -----
 
 func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
```