
Commit b3da00e

fixup! fixup! [mlir][tensor] Extend the logic to generalise tensor.pack
Allow multiple dynamic dims and remove the dependency on #109667
Parent: 894974f
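To illustrate what "multiple dynamic dims" means here, a sketch of the kind of op this commit unblocks (shapes and SSA names are invented, not taken from the patch): a tensor.pack whose inner tiles are both dynamic, in the unit-outer-dims form that the decomposition pattern handles.

```mlir
// Illustrative only: both inner tile sizes are dynamic. Previously the
// generalisation asserted on the second dynamic tile; it now handles any
// number of them.
%pack = tensor.pack %src padding_value(%pad : f32)
    inner_dims_pos = [0, 1] inner_tiles = [%tile0, %tile1]
    into %dest : tensor<?x?xf32> -> tensor<1x1x?x?xf32>
```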

3 files changed (+42, -27 lines)

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 15 additions & 6 deletions
@@ -14,13 +14,22 @@
 namespace mlir {
 namespace tensor {
 
-// Return a PadOp that pads `source` to `type` size. Output sizes (from `type`)
-// are assumed to be static and greater than the potentially dynamic input sizes
-// (from `source`). The op performs "high" padding (i.e. it adds trailing
-// padding values until the desired size is met).
-PadOp createPadHighOp(RankedTensorType type, Value source, Value pad,
+// Return a PadOp that pads `source` to `resType` size. The op performs "high"
+// padding, i.e. it adds trailing padding values until the desired size is met.
+// Output sizes are assumed to be greater than the input sizes. The padding
+// width is calculated as: resDim - sourceDim.
+//
+// Handling static sizes is trivial. Dynamic dimensions are trickier (*):
+//   1. dynamic input sizes are extracted from `source`,
+//   2. for dynamic output dims, there are two options:
+//     2.1. all output dynamic dim sizes are specified in `dynOutDims`,
+//     2.2. `dynOutDims` is empty and the corresponding padding width is set to 0.
+//
+// (*) Note that `resType` is just a shape and it only encodes the actual sizes
+// for _static_ dimensions.
+PadOp createPadHighOp(RankedTensorType resType, Value source, Value pad,
                       bool nofold, Location loc, OpBuilder &builder,
-                      std::optional<Value> dynOutDim = {});
+                      SmallVector<Value> dynOutDims = {});
 
 // Creates dim ops for each dynamic dimension of the ranked tensor argument and
 // returns these as values.
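A minimal usage sketch of the updated helper. All names (`b`, `loc`, `src`, `zero`, `outDim0`) are hypothetical caller-side values, not part of the patch; `src` is assumed to have type `tensor<?x8xf32>`.

```cpp
// Pad `src` : tensor<?x8xf32> up to tensor<?x16xf32>.
auto resType =
    RankedTensorType::get({ShapedType::kDynamic, 16}, b.getF32Type());

// Case 2.1: the dynamic output size is supplied, so the padding width for
// dim 0 becomes outDim0 - dim(src, 0); for dim 1 it is 16 - 8.
tensor::PadOp padded = tensor::createPadHighOp(
    resType, src, /*pad=*/zero, /*nofold=*/false, loc, b,
    /*dynOutDims=*/{outDim0});

// Case 2.2: no dynamic output sizes are supplied, so dim 0 keeps the default
// padding width of 0 and only the static dim 1 is padded.
tensor::PadOp paddedDefault = tensor::createPadHighOp(
    resType, src, /*pad=*/zero, /*nofold=*/false, loc, b);
```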

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 4 additions & 5 deletions
@@ -1048,8 +1048,8 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
       packOp.getDimAndTileMapping();
 
-  // The size of a dynamic tile (if present).
-  Value dynamicTileSize;
+  // The sizes of dynamic tiles.
+  SmallVector<Value> dynamicTileSizes;
 
   // Collect dims for the padded shape.
   SmallVector<int64_t> paddedShape;
@@ -1080,14 +1080,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
     paddedShape.push_back(ShapedType::kDynamic);
 
     // Get the value that holds the dynamic size.
-    assert(!dynamicTileSize && "Only one dynamic tile is supported ATM.");
-    dynamicTileSize = llvm::dyn_cast_if_present<Value>(tileSizeForDim);
+    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
   }
   auto resultType =
       RankedTensorType::get(paddedShape, inputType.getElementType());
   return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                  /*nofold=*/false, loc, builder,
-                                 dynamicTileSize);
+                                 dynamicTileSizes);
 }
 
 // Normalizes a permutation on a higher rank space to its actual size, e.g.
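At the IR level, the padded source built by this helper for a pack with two dynamic tiles now looks roughly like this (a sketch with invented names; both tile sizes are forwarded to createPadHighOp as the dynamic output sizes):

```mlir
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%d0 = tensor.dim %src, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %src, %c1 : tensor<?x?xf32>
// Padding widths: tileN - dim(%src, N), one affine.apply per dynamic dim.
%h0 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%tile0, %d0)
%h1 = affine.apply affine_map<(d0, d1) -> (d0 - d1)>(%tile1, %d1)
%padded = tensor.pad %src low[0, 0] high[%h0, %h1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad_val : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
```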

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 23 additions & 16 deletions
@@ -22,36 +22,43 @@
 using namespace mlir;
 using namespace mlir::tensor;
 
-PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
+PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source,
                                     Value pad, bool nofold, Location loc,
                                     OpBuilder &b,
-                                    std::optional<Value> dynOutDim) {
+                                    SmallVector<Value> dynOutDims) {
 
-  assert(type.getNumDynamicDims() <= 1 &&
-         "At most one output dim can be dynamic!");
+  assert((resType.getNumDynamicDims() == dynOutDims.size() ||
+          dynOutDims.empty()) &&
+         "Either none or all output dynamic dims must be specified!");
 
   // Init "low" and "high" padding values ("low" is kept as is, "high" is
   // computed below).
-  SmallVector<OpFoldResult> low(type.getRank(), b.getIndexAttr(0));
-  SmallVector<OpFoldResult> high(type.getRank(), b.getIndexAttr(0));
+  SmallVector<OpFoldResult> low(resType.getRank(), b.getIndexAttr(0));
+  SmallVector<OpFoldResult> high(resType.getRank(), b.getIndexAttr(0));
 
-  for (const auto [idx, val] : enumerate(type.getShape())) {
-    bool isOutDimDynamic = ShapedType::isDynamic(val);
-    assert((!isOutDimDynamic || dynOutDim.has_value()) &&
-           "dynamic output dim requires dynOutDim to be set");
+  size_t outDimIdx = 0;
 
-    // Compute the padding width: outDim - srcDim.
+  for (const auto [idx, val] : enumerate(resType.getShape())) {
+    bool isDimDynamic = ShapedType::isDynamic(val);
+    bool updatePadHigh = !isDimDynamic || !dynOutDims.empty();
+
+    // Keep the default padding width (i.e. "0") when the output dim is dynamic
+    // and no actual output sizes have been provided.
+    if (!updatePadHigh)
+      continue;
+
+    // Compute the padding width: resDim - sourceDim.
     AffineExpr d0, d1;
     bindDims(b.getContext(), d0, d1);
-    OpFoldResult srcDim = tensor::getMixedSize(b, loc, source, idx);
-    Value outDim = isOutDimDynamic
-                       ? dynOutDim.value()
+    OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx);
+    Value outDim = isDimDynamic
+                       ? dynOutDims[outDimIdx++]
                        : b.create<arith::ConstantIndexOp>(loc, val).getResult();
 
     high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
-                                                      {outDim, srcDim});
+                                                      {outDim, sourceDim});
   }
-  return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
+  return b.create<PadOp>(loc, resType, source, low, high, pad, nofold);
 }
 
 SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
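For case 2.2 (empty `dynOutDims`), a sketch of the resulting op, assuming source type `tensor<?x8xf32>` and `resType` `tensor<?x16xf32>` (not test output from the patch): the dynamic dim keeps its default zero padding width, and the static width folds to a constant.

```mlir
// Dim 0: dynamic, no output size given, so high padding stays 0.
// Dim 1: static, high padding = 16 - 8 = 8.
%padded = tensor.pad %src low[0, 0] high[0, 8] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<?x8xf32> to tensor<?x16xf32>
```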
