
Commit affd836

fixup! [mlir][tensor] Extend the logic to generalise tensor.pack
Address PR comments from Han-Chung: some clean-up, and relax the requirement that the dynamic dim has to be a constant multiple of vector.vscale.
1 parent 78247f5 commit affd836
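
For context, the kind of IR this fixup now accepts is a tensor.pack whose single dynamic inner tile is an arbitrary `index` value, no longer required to be a constant multiple of vector.vscale. A minimal sketch, mirroring the @simple_pad_and_pack_dynamic test added in this commit (the SSA names follow that test):

func.func @simple_pad_and_pack_dynamic(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
  // The first inner tile (%high) is dynamic and unconstrained; the second is the static 2.
  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
  return %0 : tensor<1x1x?x2xf32>
}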

File tree: 4 files changed (+62, -61 lines)

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 4 additions & 4 deletions
@@ -14,10 +14,10 @@
 namespace mlir {
 namespace tensor {

-// Return a PadOp that pads `source` to `type` size where the static
-// sizes are assumed to be greater than the dynamic sizes. If `type` has dynamic
-// dimensions the padding width is set to zero. The op performs "high" padding
-// (i.e. it adds trailing padding values until the desired size is met).
+// Return a PadOp that pads `source` to `type` size. Output sizes (from `type`)
+// are assumed to be static and greater than the potentially dynamic input sizes
+// (from `source`). The op performs "high" padding (i.e. it adds trailing
+// padding values until the desired size is met).
 PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
                       bool nofold, Location loc, OpBuilder &builder,
                       std::optional<Value> dynOutDim = {});
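
For reference (not part of the diff), this is the kind of op createPadHighOp builds: a tensor.pad that keeps "low" padding at zero and adds all padding at the upper end of each dimension. A minimal sketch for the static case from the tests below, a 5x1 source padded up to an 8x2 tile shape; the SSA names %src and %pad_val are illustrative assumptions, not values from the patch:

// High padding of a 5x1 source up to 8x2; low padding stays [0, 0].
%padded = tensor.pad %src low[0, 0] high[3, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad_val : f32
} : tensor<5x1xf32> to tensor<8x2xf32>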

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 15 additions & 27 deletions
@@ -1028,9 +1028,8 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
 /// This method assumes that all outer dims for this pack Op are 1.
 ///
 /// At most _one_ inner tile size can be _dynamic_, all other inner tiles are
-/// required to have static sizes. The inner tile that's dynamic must be a
-/// multiple of vector.vscale (to support scalable tile sizes). This condition
-/// can be relaxed in the future.
+/// required to have static sizes. This restriction can be relaxed in the
+/// future.
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            tensor::PackOp packOp) {
   Value input = packOp.getSource();
@@ -1049,8 +1048,8 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();

-  // The size of a scalable tile (if present).
-  Value scalableSize;
+  // The size of a dynamic tile (if present).
+  Value dynamicTileSize;

   // Collect dims for the padded shape.
   SmallVector<int64_t> paddedShape;
@@ -1080,16 +1079,15 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
     // 2.2 Dynamic tile sizes
     paddedShape.push_back(ShapedType::kDynamic);

-    // Get the value that holds the scalable size.
-    assert(!scalableSize && "Only one scalable size is supported ATM.");
-    scalableSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
-    assert(vector::getConstantVscaleMultiplier(scalableSize) &&
-           "This dynamic shape is not a multiple of vscale, this !");
+    // Get the value that holds the dynamic size.
+    assert(!dynamicTileSize && "Only one dynamic tile is supported ATM.");
+    dynamicTileSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
-                                 /*nofold=*/false, loc, builder, scalableSize);
+                                 /*nofold=*/false, loc, builder,
+                                 dynamicTileSize);
 }

 // Normalizes a permutation on a higher rank space to its actual size, e.g.
@@ -1152,14 +1150,6 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,

 LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
-  if (llvm::any_of(packOp.getMixedTiles(), [](OpFoldResult tile) {
-        return tile.is<Value>() && !vector::getConstantVscaleMultiplier(
-                                       llvm::dyn_cast<Value>(tile));
-      })) {
-    return rewriter.notifyMatchFailure(
-        packOp, "require inner tile sizes to be either static or a constant "
-                "multiple of vector.vscale");
-  }
   if (llvm::count_if(packOp.getMixedTiles(),
                      [](OpFoldResult tile) { return tile.is<Value>(); }) > 1) {
     return rewriter.notifyMatchFailure(
@@ -1221,22 +1211,20 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape = readShape;
   applyPermutationToVector<int64_t>(transpShape, perm);

-  // If there's a tile with a scalable size, retrieve its size. ATM only 1
-  // scalable tile is allowed.
-  Value scalableSize;
+  // If there's a tile with a dynamic size, retrieve its size. ATM only 1
+  // dynamic tile is allowed.
+  Value dynDimSize;
   for (auto tile : packOp.getMixedTiles()) {
     if (tile.is<Value>()) {
-      assert(!scalableSize && "Only one scalable size is supported ATM.");
-      scalableSize = cast<Value>(tile);
-      assert(vector::getConstantVscaleMultiplier(scalableSize) &&
-             "This dynamic shape is not a multiple of vscale!");
+      assert(!dynDimSize && "Only one scalable size is supported ATM.");
+      dynDimSize = cast<Value>(tile);
     }
   }

   Value empty =
       ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
           ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
-                                             scalableSize)
+                                             dynDimSize)
           : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 18 additions & 25 deletions
@@ -26,37 +26,30 @@ PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
                                     Value pad, bool nofold, Location loc,
                                     OpBuilder &b,
                                     std::optional<Value> dynOutDim) {
-  assert(llvm::count_if(
-             type.getShape(),
-             [](int64_t dim) { return ShapedType::isDynamic(dim); }) <= 1 &&
+
+  assert(type.getNumDynamicDims() <= 1 &&
          "At most one output dim can be dynamic!");

   // Init "low" and "high" padding values ("low" is kept as is, "high" is
   // computed below).
   SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
   SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
-  for (const auto &en : enumerate(type.getShape())) {
-    if (!ShapedType::isDynamic(en.value())) {
-      // Static sizes - the "high" value is computed based on the input and
-      // output dims. Compute the padding width.
-      AffineExpr d0;
-      bindDims(b.getContext(), d0);
-      OpFoldResult sz = tensor::getMixedSize(b, loc, source, en.index());
-      high[en.index()] =
-          affine::makeComposedFoldedAffineApply(b, loc, en.value() - d0, {sz});
-    } else {
-      // Dynamic sizes - the "high" value is computed based on the input dim
-      // and `dynOutDim`.
-      assert(dynOutDim.has_value() &&
-             "dynamic output dim requires dynOutDim to be set");
-
-      // Compute the padding width.
-      AffineExpr d0, d1;
-      auto dimVal = b.create<tensor::DimOp>(loc, source, en.index());
-      bindDims(b.getContext(), d0, d1);
-      high[en.index()] = affine::makeComposedFoldedAffineApply(
-          b, loc, d0 - d1, {dynOutDim.value(), dimVal.getResult()});
-    }
+
+  for (const auto [idx, val] : enumerate(type.getShape())) {
+    bool isOutDimDynamic = ShapedType::isDynamic(val);
+    assert((!isOutDimDynamic || dynOutDim.has_value()) &&
+           "dynamic output dim requires dynOutDim to be set");
+
+    // Compute the padding width: outDim - srcDim.
+    AffineExpr d0, d1;
+    bindDims(b.getContext(), d0, d1);
+    OpFoldResult srcDim = tensor::getMixedSize(b, loc, source, idx);
+    Value outDim = isOutDimDynamic
+                       ? dynOutDim.value()
+                       : b.create<arith::ConstantIndexOp>(loc, val).getResult();
+
+    high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
+                                                      {outDim, srcDim});
   }
   return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
 }
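
With the unified loop above, the padding width is always computed as outDim - srcDim through a single folded affine.apply. When the source dim is static (5 in the tests), tensor::getMixedSize returns it as an attribute and makeComposedFoldedAffineApply folds it into the map, which is presumably why the CHECK'd map in the tests below changes from (s0 - s1) to (s0 - 5). An illustrative sketch of the folded IR; the SSA name %out_dim is an assumption:

// Padding width for a dynamic output dim and a static source dim of 5.
#map = affine_map<()[s0] -> (s0 - 5)>
%pad_high = affine.apply #map()[%out_dim]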

mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir

Lines changed: 25 additions & 5 deletions
@@ -23,7 +23,7 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
   %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<5x1xf32> -> tensor<1x1x8x2xf32>
   return %0 : tensor<1x1x8x2xf32>
 }
-// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 - s1)>
+// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0] -> (s0 - 5)>

 // CHECK-LABEL: func.func @simple_pad_and_pack
 // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
@@ -36,7 +36,29 @@ func.func @simple_pad_and_pack(%input: tensor<5x1xf32>, %output: tensor<1x1x8x2x
 // CHECK-SAME: [0, 0, 0, 0] [1, 1, 8, 2] [1, 1, 1, 1]
 // CHECK: return %[[INSERT]]

-/// Same as example above, but with scalable sizes.
+/// Same as example above, but with dynamic tile size.
+
+func.func @simple_pad_and_pack_dynamic(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32, %high: index) -> tensor<1x1x?x2xf32> {
+  %0 = tensor.pack %input padding_value(%pad : f32) inner_dims_pos = [0, 1] inner_tiles = [%high, 2] into %output : tensor<5x1xf32> -> tensor<1x1x?x2xf32>
+  return %0 : tensor<1x1x?x2xf32>
+}
+
+// CHECK-LABEL: func.func @simple_pad_and_pack_dynamic(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[HIGH_VAL:.*]]: index) -> tensor<1x1x?x2xf32> {
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[HIGH_VAL]]]
+// CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
+// CHECK: tensor.yield %[[PAD_VAL]] : f32
+// CHECK-NOT: linalg.transpose
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[VAL_10:.*]][0, 0] {{\[}}%[[HIGH_VAL]], 2] [1, 1] : tensor<?x2xf32> to tensor<?x2xf32>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[DEST]], %[[C2]] : tensor<1x1x?x2xf32>
+// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[DIM]], 2] [1, 1, 1, 1] : tensor<?x2xf32> into tensor<1x1x?x2xf32>
+// CHECK: return %[[RES]] : tensor<1x1x?x2xf32>
+
+/// Same as example above, but with scalable tile size.

 /// NOTE: For this example to make sense in practice, the "?" in the output shape
 /// should effectively be 8 * vector.vscale (and that's what tensor.dim
@@ -50,17 +72,15 @@ func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor
   return %0 : tensor<1x1x?x2xf32>
 }

-
 // CHECK-LABEL: func.func @simple_pad_and_pack_scalable(
 // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
 // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
 // CHECK-SAME: %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
 // CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[C5:.+]] = arith.constant 5 : index
 // CHECK: %[[C8:.+]] = arith.constant 8 : index
 // CHECK: %[[VS:.+]] = vector.vscale
 // CHECK: %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
-// CHECK: %[[PAD_HIGH:.+]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]], %[[C5]]]
+// CHECK: %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]]]
 // CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
 // CHECK: tensor.yield %[[PAD_VAL]] : f32
 // CHECK-NOT: linalg.transpose
