
Commit 78247f5

[mlir][tensor] Extend the logic to generalise tensor.pack
Extends the logic to generalise tensor.pack (into e.g. tensor.pad + tensor.transpose) so that it also works when one of the inner tile sizes is scalable (i.e. a multiple of `vector.vscale`). For example:

```mlir
%c8 = arith.constant 8 : index
%vscale = vector.vscale
%c8_vscale = arith.muli %vscale, %c8 : index
%0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
```

is generalised as:

```mlir
%c8 = arith.constant 8 : index
%vscale = vector.vscale
%c8_vscale = arith.muli %vscale, %c8 : index
%0 = affine.apply #map()[%c8_vscale, %c5]
%padded = tensor.pad %arg0 low[0, 0] high[%0, 1] {
^bb0(%arg3: index, %arg4: index):
  tensor.yield %arg2 : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
```

At the Tensor level, we model scalability using dynamic shapes, and this change basically extends the relevant logic so that it also works for dynamic shapes. However, rather than allowing arbitrary values and an arbitrary number of dynamic tile sizes, only _one_ tile size is allowed to be dynamic. In addition, it is required to be a constant multiple of `vector.vscale`. While these requirements can be relaxed, I wanted to avoid full generality for now, primarily to avoid complexity that's not yet needed and to make reviewing a bit easier.
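For reference, the CHECK lines in the updated test (mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir) suggest that the full generalised sequence for this example looks roughly as follows. The SSA names below are illustrative rather than what the pattern actually produces, and `#map` above corresponds to `affine_map<()[s0, s1] -> (s0 - s1)>`:

```mlir
// Rough sketch of the generalised form (illustrative names, assuming the
// same %input/%output/%pad as in the example above).
%c2 = arith.constant 2 : index
%c5 = arith.constant 5 : index
%c8 = arith.constant 8 : index
%vscale = vector.vscale
%c8_vscale = arith.muli %vscale, %c8 : index

// High padding for dim 0: (8 * vscale) - 5, i.e. the scalable tile size
// minus the static source size.
%high = affine.apply affine_map<()[s0, s1] -> (s0 - s1)>()[%c8_vscale, %c5]
%padded = tensor.pad %input low[0, 0] high[%high, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad : f32
} : tensor<5x1xf32> to tensor<?x2xf32>

// No linalg.transpose is needed for this example (see CHECK-NOT in the test).
%tile = tensor.extract_slice %padded[0, 0] [%c8_vscale, 2] [1, 1]
    : tensor<?x2xf32> to tensor<?x2xf32>
%dim = tensor.dim %output, %c2 : tensor<1x1x?x2xf32>
%res = tensor.insert_slice %tile into %output[0, 0, 0, 0] [1, 1, %dim, 2] [1, 1, 1, 1]
    : tensor<?x2xf32> into tensor<1x1x?x2xf32>
```

The scalable tile thus only surfaces as a dynamic dimension plus the `%c8_vscale` value that defines it; no scalable vector types appear at the Tensor level.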
1 parent 33c134e commit 78247f5

File tree

4 files changed (+142 -29 lines)

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,8 @@ namespace tensor {
 // dimensions the padding width is set to zero. The op performs "high" padding
 // (i.e. it adds trailing padding values until the desired size is met).
 PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
-                      bool nofold, Location loc, OpBuilder &builder);
+                      bool nofold, Location loc, OpBuilder &builder,
+                      std::optional<Value> dynOutDim = {});
 
 // Creates dim ops for each dynamic dimension of the ranked tensor argument and
 // returns these as values.

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 74 additions & 18 deletions
@@ -1021,8 +1021,16 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   return success();
 }
 
-/// Returns a tensor.pad op if padding value is set. Otherwise, returns the
-/// source directly. The method assumes that the `packOp` has static shapes.
+/// If padding value is set, returns a tensor.pad Op for the source tensor,
+/// with the output shape matching the output of `packOp`. Otherwise, returns
+/// the source directly.
+///
+/// This method assumes that all outer dims for this pack Op are 1.
+///
+/// At most _one_ inner tile size can be _dynamic_, all other inner tiles are
+/// required to have static sizes. The inner tile that's dynamic must be a
+/// multiple of vector.vscale (to support scalable tile sizes). This condition
+/// can be relaxed in the future.
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            tensor::PackOp packOp) {
   Value input = packOp.getSource();
@@ -1038,26 +1046,50 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
 
-  SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
-  for (int64_t dim = 0; dim < inputRank; ++dim) {
-    int64_t size = inputType.getDimSize(dim);
-    if (!tileAndPosMapping.count(dim)) {
-      paddedShape.push_back(size);
+
+  // The size of a scalable tile (if present).
+  Value scalableSize;
+
+  // Collect dims for the padded shape.
+  SmallVector<int64_t> paddedShape;
+  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
+    int64_t inputDimSize = inputType.getDimSize(dimIdx);
+    // 1. Non-tiled outer dims.
+    // These dims should be 1 and we simply preserve them.
+    if (!tileAndPosMapping.count(dimIdx)) {
+      assert(inputDimSize == 1 &&
+             "with all outer dims == 1, this non-tiled input dim should be 1!");
+      paddedShape.push_back(inputDimSize);
+      continue;
+    }
+
+    // 2. Tiled outer dims
+    // As all outer dims == 1, it is safe to use the tile size for the padded
+    // shape.
+    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);
+
+    // 2.1 Static tile sizes
+    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
+    if (cstTileSize.has_value()) {
+      paddedShape.push_back(cstTileSize.value());
       continue;
     }
 
-    // The size is less than or equal to tileSize because outer dims are all 1s.
-    std::optional<int64_t> tileSize =
-        getConstantIntValue(tileAndPosMapping.lookup(dim));
-    assert(tileSize.has_value() && "dynamic inner tile size is not supported");
-    paddedShape.push_back(tileSize.value());
+    // 2.2 Dynamic tile sizes
+    paddedShape.push_back(ShapedType::kDynamic);
+
+    // Get the value that holds the scalable size.
+    assert(!scalableSize && "Only one scalable size is supported ATM.");
+    scalableSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
+    assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+           "This dynamic shape is not a multiple of vscale!");
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
-                                 /*nofold=*/false, loc, builder);
+                                 /*nofold=*/false, loc, builder, scalableSize);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1120,10 +1152,18 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
 
 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::any_of(packOp.getMixedTiles(),
-                   [](OpFoldResult tile) { return tile.is<Value>(); })) {
-    return rewriter.notifyMatchFailure(packOp,
-                                       "require inner tile sizes being static");
+  if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) {
+        return tile.is<Value>() && !vector::getConstantVscaleMultiplier(
+                                       llvm::dyn_cast<Value>(tile));
+      })) {
+    return rewriter.notifyMatchFailure(
+        packOp, "require inner tile sizes to be either static or a constant "
+                "multiple of vector.vscale");
+  }
+  if (llvm::count_if(packOp.getMixedTiles(),
+                     [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
+    return rewriter.notifyMatchFailure(
+        packOp, "at most one dynamic tile size is supported");
   }
 
   // TODO: support the case that outer dimensions are not all 1s. A
@@ -1181,7 +1221,23 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  // If there's a tile with a scalable size, retrieve its size. ATM only 1
+  // scalable tile is allowed.
+  Value scalableSize;
+  for (auto tile : packOp.getMixedTiles()) {
+    if (tile.is<Value>()) {
+      assert(!scalableSize && "Only one scalable size is supported ATM.");
+      scalableSize = cast<Value>(tile);
+      assert(vector::getConstantVscaleMultiplier(scalableSize) &&
+             "This dynamic shape is not a multiple of vscale!");
+    }
+  }
+
+  Value empty =
+      ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
+          ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
+                                             scalableSize)
+          : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
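With a scalable inner tile, the destination of the transpose is now created with the scalable size as a dynamic-size operand of `tensor.empty`. A minimal sketch of the IR this produces, assuming an inner tile of `8 * vector.vscale` and an `f32` element type (value names are illustrative):

```mlir
%c8 = arith.constant 8 : index
%vscale = vector.vscale
%c8_vscale = arith.muli %vscale, %c8 : index
// The dynamic dimension of the destination is bound to the scalable tile size.
%empty = tensor.empty(%c8_vscale) : tensor<?x2xf32>
```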

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 31 additions & 10 deletions
@@ -16,26 +16,47 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Interfaces/ValueBoundsOpInterface.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
 
 PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
                                     Value pad, bool nofold, Location loc,
-                                    OpBuilder &b) {
+                                    OpBuilder &b,
+                                    std::optional<Value> dynOutDim) {
+  assert(llvm::count_if(
+             type.getShape(),
+             [](int64_t dim) { return ShapedType::isDynamic(dim); }) <= 1 &&
+         "At most one output dim can be dynamic!");
+
+  // Init "low" and "high" padding values ("low" is kept as is, "high" is
+  // computed below).
   SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
   SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
   for (const auto &en : enumerate(type.getShape())) {
-    // Pad only the static dimensions of the result tensor type.
-    if (ShapedType::isDynamic(en.value()))
-      continue;
-    // Compute the padding width.
-    AffineExpr d0;
-    bindDims(b.getContext(), d0);
-    OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
-    high[en.index()] =
-        affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+    if (!ShapedType::isDynamic(en.value())) {
+      // Static sizes - the "high" value is computed based on the input and
+      // output dims. Compute the padding width.
+      AffineExpr d0;
+      bindDims(b.getContext(), d0);
+      OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
+      high[en.index()] =
+          affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
+    } else {
+      // Dynamic sizes - the "high" value is computed based on the input dim
+      // and `dynOutDim`.
+      assert(dynOutDim.has_value() &&
+             "dynamic output dim requires dynOutDim to be set");
+
+      // Compute the padding width.
+      AffineExpr d0, d1;
+      auto dimVal = b.create<tensor::DimOp>(loc, source, en.index());
+      bindDims(b.getContext(), d0, d1);
+      high[en.index()] = affine::makeComposedFoldedAffineApply(
+          b, loc, d0 - d1, {dynOutDim.value(), dimVal.getResult()});
+    }
   }
   return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
 }
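For a dynamic output dimension, the new branch in `createPadHighOp` computes the "high" padding as `dynOutDim - dim(source)` via a `tensor.dim` plus an `affine.apply`. A rough sketch of the IR this builds for the 5x1 example, before any folding, assuming hypothetical values `%src`, `%dyn_out_dim` and `%pad_val`:

```mlir
// %src : tensor<5x1xf32>; %dyn_out_dim holds the dynamic output size
// (e.g. 8 * vector.vscale); %pad_val is the padding value.
%c0 = arith.constant 0 : index
%dim0 = tensor.dim %src, %c0 : tensor<5x1xf32>
// high[0] = dynOutDim - dim(src, 0); high[1] stays the static 2 - 1 = 1.
%high = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%dyn_out_dim, %dim0)
%padded = tensor.pad %src low[0, 0] high[%high, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad_val : f32
} : tensor<5x1xf32> to tensor<?x2xf32>
```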

mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir

Lines changed: 35 additions & 0 deletions
@@ -23,6 +23,8 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
 }
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
+
 // CHECK-LABEL: func.func @simple_pad_and_pack
 // CHECK-SAME:    %[[SRC:[a-zA-Z0-9]+]]
 // CHECK-SAME:    %[[DEST:[a-zA-Z0-9]+]]
@@ -34,6 +36,39 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME:      [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK:         return %[[INSERT]]
 
+/// Same as example above, but with scalable sizes.
+
+/// NOTE: For this example to make sense in practice, the "?" in the output shape
+/// should effectively be 8 * vector.vscale (and that's what tensor.dim
+/// below should return).
+
+func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> {
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %c8_vscale = arith.muli %vscale, %c8 : index
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%c8_vscale, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+
+
+// CHECK-LABEL:   func.func @simple_pad_and_pack_scalable(
+// CHECK-SAME:      %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
+// CHECK-SAME:      %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
+// CHECK-SAME:      %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
+// CHECK:           %[[C2:.+]] = arith.constant 2 : index
+// CHECK:           %[[C5:.+]] = arith.constant 5 : index
+// CHECK:           %[[C8:.+]] = arith.constant 8 : index
+// CHECK:           %[[VS:.+]] = vector.vscale
+// CHECK:           %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
+// CHECK:           %[[PAD_HIGH:.+]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]], %[[C5]]]
+// CHECK:           %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK:             tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT:       linalg.transpose
+// CHECK:           %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK:           %[[DIM:.+]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK:           %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK:           return %[[RES]] : tensor<1x1x?x2xf32>
+
 // -----
 
 func.func @simple_NC_to_CNnc(%arg0: tensor<32x8xf32>, %arg1: tensor<1x1x32x8xf32>) -> tensor<1x1x32x8xf32>{
