[MLIR][tensor] Improve tensor.pack verifier to catch more cases with unconditional runtime errors #77217

Merged: 7 commits, merged on Feb 19, 2024
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (7 changes: 4 additions & 3 deletions)

@@ -1943,11 +1943,12 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
 
     // Returns true if we have enough static information to catch undefined
     // behavior when the tile size does not divide perfectly the dimension of
-    // the input tensor. If a given dimension or a tile associated with it is
-    // dynamic, the dimension is not considered as we don't have enough static
-    // information to understand if the tile perfectly divides that dimension.
+    // the input tensor. Detecting UB requires that the input size and either
+    // corresponding tile or output size are static.
    static bool requirePaddingValue(ArrayRef<int64_t> inputShape,
                                    ArrayRef<int64_t> innerDimsPos,
+                                   ArrayRef<int64_t> outputShape,
+                                   ArrayRef<int64_t> outerDimsPerm,
                                    ArrayRef<OpFoldResult> innerTiles);
 
    static Value createDestinationTensor(OpBuilder &b, Location loc,
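For context, a minimal IR sketch of the property this helper checks (SSA names are illustrative; the 13x15 shape mirrors the pad_and_pack test updated later in this diff): a tile that divides its dimension exactly needs no padding value, while a partial tile does.

// 13 % 8 != 0 and 15 % 2 != 0: the last tiles are partial, so the op must
// carry a padding_value to define the contents of the padded region.
%packed = tensor.pack %src padding_value(%cst : f32)
    inner_dims_pos = [0, 1] inner_tiles = [8, 2]
    into %dst : tensor<13x15xf32> -> tensor<2x8x8x2xf32>

// 16 % 8 == 0 and 16 % 2 == 0: full tiles only, no padding_value required.
%packed2 = tensor.pack %src2 inner_dims_pos = [0, 1] inner_tiles = [8, 2]
    into %dst2 : tensor<16x16xf32> -> tensor<2x8x8x2xf32>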
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (6 changes: 4 additions & 2 deletions)

@@ -563,8 +563,10 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
           return getConstantIntValue(tile).has_value();
         });
     if (areConstantTiles && operandType.hasStaticShape() &&
-        !tensor::PackOp::requirePaddingValue(operandType.getShape(), innerPos,
-                                             innerPackSizes)) {
+        !tensor::PackOp::requirePaddingValue(
+            operandType.getShape(), innerPos,
+            dest.getType().cast<ShapedType>().getShape(), {},
+            innerPackSizes)) {
       packOps.push_back(rewriter.create<tensor::PackOp>(
           loc, operand, dest, innerPos, innerPackSizes));
     } else {
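In linalg::pack this predicate decides which form of tensor.pack gets emitted. A hedged sketch of the two outcomes (shapes and SSA names are hypothetical):

// Static operand, constant tile, 128 % 32 == 0: the branch above emits a
// pack with no padding value.
%p = tensor.pack %operand inner_dims_pos = [0] inner_tiles = [32]
    into %dest : tensor<128xf32> -> tensor<4x32xf32>

// 130 % 32 != 0: requirePaddingValue returns true, so the else branch must
// supply a padding value for the resulting op to verify.
%q = tensor.pack %operand2 padding_value(%zero : f32)
    inner_dims_pos = [0] inner_tiles = [32]
    into %dest2 : tensor<130xf32> -> tensor<5x32xf32>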
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (31 changes: 24 additions & 7 deletions)

@@ -3746,15 +3746,29 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
 
 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                  ArrayRef<int64_t> innerDimsPos,
+                                 ArrayRef<int64_t> outputShape,
+                                 ArrayRef<int64_t> outerDimsPerm,
                                  ArrayRef<OpFoldResult> innerTiles) {
+  SmallVector<int64_t> outputTileSizes(
+      outputShape.take_front(inputShape.size()));
+  if (!outerDimsPerm.empty()) {
+    assert(outerDimsPerm.size() == outputTileSizes.size() &&
+           "expected output and outer_dims_perm to have same size");
+    applyPermutationToVector(outputTileSizes,
+                             invertPermutationVector(outerDimsPerm));
+  }
   for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
     if (ShapedType::isDynamic(inputShape[pos]))
       continue;
     std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
-    if (!constantTile)
-      continue;
-    if (inputShape[pos] % (*constantTile) != 0)
+
+    if (!constantTile) {
+      if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
+          (inputShape[pos] % outputTileSizes[pos] != 0))
+        return true;
+    } else if (inputShape[pos] % (*constantTile) != 0) {
       return true;
+    }
   }
   return false;
 }
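A worked example of the new branch, using the dyn_tiles test added below (input tensor<256x128xf32>, dest tensor<10x8x?x?xf32>, both tiles dynamic): outputTileSizes — despite the name, the leading outer dimensions of the dest, i.e. the per-dimension tile counts — is [10, 8], and there is no outer_dims_perm to invert. Neither tile is a constant, so the check falls through to the output sizes. For pos = 1, 128 % 8 == 0, so that dimension is fine; for pos = 0, 256 % 10 != 0, meaning 256 = 10 * t has no integer solution and no tile size can cover dimension 0 with full tiles, so the function returns true.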
@@ -3776,9 +3790,11 @@ LogicalResult PackOp::verify() {
 
   if (!paddingValue &&
       requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
+                          getDestType().getShape(), getOuterDimsPerm(),
                           getMixedTiles())) {
-    return emitOpError("invalid tile factor provided. Only full tiles are "
-                       "supported when padding_value is not set");
+    return emitOpError(
+        "invalid tile factor or output size provided. Only full tiles are "
+        "supported when padding_value is not set");
   }
   return success();
 }
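Note the verifier still rejects only provable UB. A hedged sketch of a case that continues to verify (names hypothetical): when both the tiles and the corresponding outer dims of the dest are dynamic, there is no static evidence of a partial tile, so requirePaddingValue returns false and the op is accepted without a padding_value.

%ok = tensor.pack %input inner_dims_pos = [1, 0]
    inner_tiles = [%t0, %t1]
    into %out : tensor<256x128xf32> -> tensor<?x?x?x?xf32>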
@@ -3979,8 +3995,9 @@ static bool paddingIsNotNeeded(PackOp op) {
     return false;
   if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
     return false;
-  return !PackOp::requirePaddingValue(srcType.getShape(), op.getInnerDimsPos(),
-                                      op.getMixedTiles());
+  return !PackOp::requirePaddingValue(
+      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
+      op.getOuterDimsPerm(), op.getMixedTiles());
 }
 
 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
mlir/test/Dialect/Tensor/invalid.mlir (18 changes: 17 additions & 1 deletion)

@@ -597,13 +597,29 @@ func.func @empty_wrong_number_of_operands(%sz : index) {
 // -----
 
 func.func @pack_invalid_no_padding_no_full_tiles(%input: tensor<256x128xf32>, %output: tensor<8x8x16x33xf32>) -> tensor<8x8x16x33xf32> {
-  // expected-error@+1 {{invalid tile factor provided. Only full tiles are supported when padding_value is not set}}
+  // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
   %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [16, 33] into %output : tensor<256x128xf32> -> tensor<8x8x16x33xf32>
   return %0 : tensor<8x8x16x33xf32>
 }
 
 // -----
 
+func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles(%input: tensor<256x128xf32>, %output: tensor<10x8x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<10x8x?x?xf32> {
+  // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+  %0 = tensor.pack %input inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32> -> tensor<10x8x?x?xf32>
+  return %0 : tensor<10x8x?x?xf32>
+}
+
+// -----
+
+func.func @pack_invalid_no_padding_no_full_tiles_dyn_tiles_outperm(%input: tensor<256x128xf32>, %output: tensor<8x10x?x?xf32>, %tile_size_0: index, %tile_size_1: index) -> tensor<8x10x?x?xf32> {
+  // expected-error@+1 {{invalid tile factor or output size provided. Only full tiles are supported when padding_value is not set}}
+  %0 = tensor.pack %input outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [%tile_size_0, %tile_size_1] into %output : tensor<256x128xf32> -> tensor<8x10x?x?xf32>
+  return %0 : tensor<8x10x?x?xf32>
+}
+
+// -----
+
 func.func @pad_and_pack_invalid_type(%input: tensor<13x15xf32>, %output: tensor<2x8x8x2xf32>, %pad: i32) -> tensor<2x8x8x2xf32> {
   // expected-error@+1 {{expected padding_value has 'f32' but got: 'i32'}}
   %0 = tensor.pack %input padding_value(%pad: i32) inner_dims_pos = [0, 1] inner_tiles = [8, 2] into %output : tensor<13x15xf32> -> tensor<2x8x8x2xf32>
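Both new tests hinge on the same arithmetic: the dest carries 10 tiles along a source dimension of size 256, and 256 % 10 != 0, so a partial tile is unavoidable even though the tile sizes themselves are dynamic. In the outperm variant the outer dims appear permuted as 8x10; inverting outer_dims_perm = [1, 0] restores them to source order [10, 8] before the divisibility check, which is exactly what the invertPermutationVector call in requirePaddingValue handles.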