Skip to content

Commit 9466c4e

Browse files
authored
[MLIR][tensor] Improve tensor.pack verifier to catch more cases with unconditional runtime errors (#77217)
Previously, the `tensor.pack` verifier detected unconditional runtime errors only when tile sizes were static. Now, dynamic tiles are also considered: to determine whether the op will unconditionally produce errors at runtime, we only require that the input size and either the corresponding tile size or the output size be static.
1 parent 5bd374d commit 9466c4e

File tree

4 files changed

+49
-13
lines changed

4 files changed

+49
-13
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,11 +1943,12 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
19431943

19441944
// Returns true if we have enough static information to catch undefined
19451945
// behavior when the tile size does not divide perfectly the dimension of
1946-
// the input tensor. If a given dimension or a tile associated with it is
1947-
// dynamic, the dimension is not considered as we don't have enough static
1948-
// information to understand if the tile perfectly divides that dimension.
1946+
// the input tensor. Detecting UB requires that the input size and either
1947+
// corresponding tile or output size are static.
19491948
static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
19501949
ArrayRef<int64_t> innerDimsPos,
1950+
ArrayRef<int64_t> outputShape,
1951+
ArrayRef<int64_t> outerDimsPerm,
19511952
ArrayRef<OpFoldResult> innerTiles);
19521953

19531954
static Value createDestinationTensor(OpBuilder &b, Location loc,

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,8 +563,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
563563
return getConstantIntValue(tile).has_value();
564564
});
565565
if (areConstantTiles && operandType.hasStaticShape() &&
566-
!tensor::PackOp::requirePaddingValue(operandType.getShape(), innerPos,
567-
innerPackSizes)) {
566+
!tensor::PackOp::requirePaddingValue(
567+
operandType.getShape(), innerPos,
568+
dest.getType().cast<ShapedType>().getShape(), {},
569+
innerPackSizes)) {
568570
packOps.push_back(rewriter.create<tensor::PackOp>(
569571
loc, operand, dest, innerPos, innerPackSizes));
570572
} else {

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3746,15 +3746,29 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
37463746

37473747
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
37483748
ArrayRef<int64_t> innerDimsPos,
3749+
ArrayRef<int64_t> outputShape,
3750+
ArrayRef<int64_t> outerDimsPerm,
37493751
ArrayRef<OpFoldResult> innerTiles) {
3752+
SmallVector<int64_t> outputTileSizes(
3753+
outputShape.take_front(inputShape.size()));
3754+
if (!outerDimsPerm.empty()) {
3755+
assert(outerDimsPerm.size() == outputTileSizes.size() &&
3756+
"expected output and outer_dims_perm to have same size");
3757+
applyPermutationToVector(outputTileSizes,
3758+
invertPermutationVector(outerDimsPerm));
3759+
}
37503760
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
37513761
if (ShapedType::isDynamic(inputShape[pos]))
37523762
continue;
37533763
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
3754-
if (!constantTile)
3755-
continue;
3756-
if (inputShape[pos] % (*constantTile) != 0)
3764+
3765+
if (!constantTile) {
3766+
if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
3767+
(inputShape[pos] % outputTileSizes[pos] != 0))
3768+
return true;
3769+
} else if (inputShape[pos] % (*constantTile) != 0) {
37573770
return true;
3771+
}
37583772
}
37593773
return false;
37603774
}
@@ -3776,9 +3790,11 @@ LogicalResult PackOp::verify() {
37763790

37773791
if (!paddingValue &&
37783792
requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
3793+
getDestType().getShape(), getOuterDimsPerm(),
37793794
getMixedTiles())) {
3780-
return emitOpError("invalid tile factor provided. Only full tiles are "
3781-
"supported when padding_value is not set");
3795+
return emitOpError(
3796+
"invalid tile factor or output size provided. Only full tiles are "
3797+
"supported when padding_value is not set");
37823798
}
37833799
return success();
37843800
}
@@ -3979,8 +3995,9 @@ static bool paddingIsNotNeeded(PackOp op) {
39793995
return false;
39803996
if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
39813997
return false;
3982-
return !PackOp::requirePaddingValue(srcType.getShape(), op.getInnerDimsPos(),
3983-
op.getMixedTiles());
3998+
return !PackOp::requirePaddingValue(
3999+
srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4000+
op.getOuterDimsPerm(), op.getMixedTiles());
39844001
}
39854002

39864003
/// Returns true if the `srcShape` or `destShape` is different from the one in

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,13 +597,29 @@ func.func @empty_wrong_number_of_operands(%sz : index) {
597597
// -----
598598

599599
func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> {
600-
// expected-error@+1 {{invalid tile factor provided. Only full tiles are supported when padding_value is not set}}
600+
// expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
601601
%0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32> -> tensor<8x8x16x33xf32>
602602
return %0 : tensor<8x8x16x33xf32>
603603
}
604604

605605
// -----
606606

607+
func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles(%input: tensor<256x128xf32>, %output: tensor<10x8x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<10x8x?x?xf32> {
608+
// expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
609+
%0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32> -> tensor<10x8x?x?xf32>
610+
return %0 : tensor<10x8x?x?xf32>
611+
}
612+
613+
// -----
614+
615+
func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles_outperm(%input: tensor<256x128xf32>, %output: tensor<8x10x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<8x10x?x?xf32> {
616+
// expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
617+
%0 = tensor.pack %input outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32> -> tensor<8x10x?x?xf32>
618+
return %0 : tensor<8x10x?x?xf32>
619+
}
620+
621+
// -----
622+
607623
func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> {
608624
// expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}}
609625
%0 = tensor.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>

0 commit comments

Comments
 (0)