Skip to content

Commit e016ccb

Browse files
committed
Include output size in determining UB for tensor.pack
1 parent a085402 commit e016ccb

File tree

4 files changed

+40
-13
lines changed

4 files changed

+40
-13
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,11 +1943,12 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
19431943

19441944
// Returns true if we have enough static information to catch undefined
19451945
// behavior when the tile size does not perfectly divide the dimension of
1946-
// the input tensor. If a given dimension or a tile associated with it is
1947-
// dynamic, the dimension is not considered as we don't have enough static
1948-
// information to understand if the tile perfectly divides that dimension.
1946+
// the input tensor. Detecting UB requires that the input size and either
1947+
// the corresponding tile or the output size are static.
19491948
static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
19501949
ArrayRef<int64_t> innerDimsPos,
1950+
ArrayRef<int64_t> outputShape,
1951+
ArrayRef<int64_t> outerDimsPerm,
19511952
ArrayRef<OpFoldResult> innerTiles);
19521953

19531954
static Value createDestinationTensor(OpBuilder &b, Location loc,

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,8 +582,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
582582
return getConstantIntValue(tile).has_value();
583583
});
584584
if (areConstantTiles && operandType.hasStaticShape() &&
585-
!tensor::PackOp::requirePaddingValue(operandType.getShape(), innerPos,
586-
innerPackSizes)) {
585+
!tensor::PackOp::requirePaddingValue(
586+
operandType.getShape(), innerPos,
587+
dest.getType().cast<ShapedType>().getShape(), {},
588+
innerPackSizes)) {
587589
packOps.push_back(rewriter.create<tensor::PackOp>(
588590
loc, operand, dest, innerPos, innerPackSizes));
589591
} else {

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3742,14 +3742,27 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
37423742

37433743
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
37443744
ArrayRef<int64_t> innerDimsPos,
3745+
ArrayRef<int64_t> outputShape,
3746+
ArrayRef<int64_t> outerDimsPerm,
37453747
ArrayRef<OpFoldResult> innerTiles) {
3748+
SmallVector<int64_t> outputTileSizes(
3749+
outputShape.take_front(inputShape.size()));
3750+
if (!outerDimsPerm.empty()) {
3751+
assert(outerDimsPerm.size() == outputTileSizes.size() &&
3752+
"expected output and outer_dims_perm to have same size");
3753+
applyPermutationToVector(outputTileSizes,
3754+
invertPermutationVector(outerDimsPerm));
3755+
}
37463756
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
37473757
if (ShapedType::isDynamic(inputShape[pos]))
37483758
continue;
37493759
std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
3750-
if (!constantTile)
3751-
continue;
3752-
if (inputShape[pos] % (*constantTile) != 0)
3760+
3761+
if (!constantTile) {
3762+
if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
3763+
(inputShape[pos] % outputTileSizes[pos] != 0))
3764+
return true;
3765+
} else if (inputShape[pos] % (*constantTile) != 0)
37533766
return true;
37543767
}
37553768
return false;
@@ -3772,9 +3785,11 @@ LogicalResult PackOp::verify() {
37723785

37733786
if (!paddingValue &&
37743787
requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
3788+
getDestType().getShape(), getOuterDimsPerm(),
37753789
getMixedTiles())) {
3776-
return emitOpError("invalid tile factor provided. Only full tiles are "
3777-
"supported when padding_value is not set");
3790+
return emitOpError(
3791+
"invalid tile factor or output size provided. Only full tiles are "
3792+
"supported when padding_value is not set");
37783793
}
37793794
return success();
37803795
}
@@ -3975,8 +3990,9 @@ static bool paddingIsNotNeeded(PackOp op) {
39753990
return false;
39763991
if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
39773992
return false;
3978-
return !PackOp::requirePaddingValue(srcType.getShape(), op.getInnerDimsPos(),
3979-
op.getMixedTiles());
3993+
return !PackOp::requirePaddingValue(
3994+
srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
3995+
op.getOuterDimsPerm(), op.getMixedTiles());
39803996
}
39813997

39823998
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -597,13 +597,21 @@ func.func @empty_wrong_number_of_operands(%sz : index) {
597597
// -----
598598

599599
func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> {
600-
// expected-error@+1 {{invalid tile factor provided. Only full tiles are supported when padding_value is not set}}
600+
// expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
601601
%0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32> -> tensor<8x8x16x33xf32>
602602
return %0 : tensor<8x8x16x33xf32>
603603
}
604604

605605
// -----
606606

607+
func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles(%input: tensor<256x128xf32>, %output: tensor<10x8x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<10x8x?x?xf32> {
608+
// expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
609+
%0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32> -> tensor<10x8x?x?xf32>
610+
return %0 : tensor<10x8x?x?xf32>
611+
}
612+
613+
// -----
614+
607615
func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> {
608616
// expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}}
609617
%0 = tensor.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>

0 commit comments

Comments
 (0)