Commit 210fd91

fixup! fixup! fixup! [mlir][tensor] Extend the logic to generalise tensor.pack
Incorporating suggestions from @hanhanW
1 parent: bfa6a2d

3 files changed: 6 additions, 17 deletions


mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 1 addition & 11 deletions
@@ -1054,10 +1054,10 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
   // Collect dims for the padded shape.
   SmallVector<int64_t> paddedShape;
   for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
-    int64_t inputDimSize = inputType.getDimSize(dimIdx);
     // 1. Non-tiled outer dims.
     // These dims should be 1 and we simply preserve them.
     if (!tileAndPosMapping.count(dimIdx)) {
+      int64_t inputDimSize = inputType.getDimSize(dimIdx);
       assert(inputDimSize == 1 &&
              "with all outer dims == 1, this non-tiled input dim should be 1!");
       paddedShape.push_back(inputDimSize);
@@ -1214,16 +1214,6 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(

   applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);

-  // If there's a tile with a dynamic size, retrieve its size. ATM only 1
-  // dynamic tile is allowed.
-  Value dynDimSize;
-  for (auto tile : packOp.getMixedTiles()) {
-    if (tile.is<Value>()) {
-      assert(!dynDimSize && "Only one scalable size is supported ATM.");
-      dynDimSize = cast<Value>(tile);
-    }
-  }
-
   Value empty =
       rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
   auto transposedOp =
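
Note on this file: the first hunk narrows inputDimSize to the only branch that reads it; the second hunk deletes the dynDimSize scan, which has no remaining uses once the dynamic tile size is consumed as an OpFoldResult (see the Utils.cpp change below). For reference, a minimal sketch of the Value/Attribute duality inside OpFoldResult that the deleted loop relied on (variable names are illustrative, not from this patch):

// Sketch only: each entry of packOp.getMixedTiles() is an OpFoldResult,
// i.e. either an SSA Value (dynamic size, e.g. 8 * vector.vscale) or an
// Attribute (static size). The deleted loop picked out the dynamic case:
for (OpFoldResult tile : packOp.getMixedTiles()) {
  if (tile.is<Value>()) {
    Value dynSize = cast<Value>(tile); // dynamic tile size
    (void)dynSize;                     // unused in this sketch
  }
  // else: the tile size is a static IntegerAttr
}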

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 2 additions & 3 deletions
@@ -51,9 +51,8 @@ PadOp mlir::tensor::createPadHighOp(RankedTensorType resType, Value source,
     AffineExpr d0, d1;
     bindDims(b.getContext(), d0, d1);
     OpFoldResult sourceDim = tensor::getMixedSize(b, loc, source, idx);
-    Value outDim = isDimDynamic
-                       ? dynOutDims[outDimIdx++]
-                       : b.create<arith::ConstantIndexOp>(loc, val).getResult();
+    OpFoldResult outDim = isDimDynamic ? OpFoldResult(dynOutDims[outDimIdx++])
+                                       : OpFoldResult(b.getIndexAttr(val));

     high[idx] = affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
                                                       {outDim, sourceDim});
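
Why OpFoldResult here: makeComposedFoldedAffineApply folds the d0 - d1 expression to an attribute when both operands are static, so a static output dim no longer materializes an arith.constant in the IR. A minimal self-contained sketch of that folding behavior, with assumed sizes 8 and 5 (illustrative, not from this patch):

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Sketch: with two static operands the apply folds away entirely; an
// affine.apply op is created only when an operand is a dynamic Value.
static OpFoldResult padHighAmountSketch(OpBuilder &b, Location loc) {
  AffineExpr d0, d1;
  bindDims(b.getContext(), d0, d1);
  OpFoldResult outDim = b.getIndexAttr(8); // assumed static output dim
  OpFoldResult srcDim = b.getIndexAttr(5); // assumed static source dim
  // Folds to the index attribute 3; no ops are inserted into the IR.
  return affine::makeComposedFoldedAffineApply(b, loc, d0 - d1,
                                               {outDim, srcDim});
}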

mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir

Lines changed: 3 additions & 3 deletions
@@ -76,9 +76,9 @@ func.func @simple_pad_and_pack_scalable(%input: tensor<5x1xf32>, %output: tensor
 // CHECK-SAME:  %[[SRC:[a-zA-Z0-9]+]]: tensor<5x1xf32>,
 // CHECK-SAME:  %[[DEST:[a-zA-Z0-9]+]]: tensor<1x1x?x2xf32>,
 // CHECK-SAME:  %[[PAD_VAL:[a-zA-Z0-9]+]]: f32) -> tensor<1x1x?x2xf32> {
-// CHECK:       %[[C2:.+]] = arith.constant 2 : index
-// CHECK:       %[[C8:.+]] = arith.constant 8 : index
-// CHECK:       %[[VS:.+]] = vector.vscale
+// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C8:.+]] = arith.constant 8 : index
+// CHECK-DAG:   %[[VS:.+]] = vector.vscale
 // CHECK:       %[[C8_VS:.+]] = arith.muli %[[VS]], %[[C8]] : index
 // CHECK:       %[[PAD_HIGH:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[C8_VS]]]
 // CHECK:       %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] {
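
Note on the test change: the CHECK to CHECK-DAG switch is the standard FileCheck idiom for order-insensitive matching. Consecutive CHECK-DAG directives may match in any order within the region bounded by the surrounding CHECK lines, so the test no longer depends on the relative order in which the two arith.constant ops and vector.vscale are emitted after canonicalization.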
