Commit d0da3d8

[mlir][tensor] Fold when source is const (#71643)
Fixes #60656. This patch implements a basic fold for various reshape/resize tensor operations. Specifically, the fold removes tensor reshape/resize ops when they are applied to a constant tensor. For example, the following function:

```mlir
func.func @main(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
  %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
      inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
  return %0 : tensor<8x16x8x32xf32>
}
```

becomes the following after `mlir-opt -canonicalize`:

```mlir
func.func @main(%arg0: tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  %cst = arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
  return %cst : tensor<8x16x8x32xf32>
}
```

As a side note, this patch is essentially an extension of f79f430.
1 parent bede010 commit d0da3d8
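One nuance worth calling out: the new fold only fires when the source constant is a *splat* (every element identical), because only then can the constant simply be retyped to the result shape (see `reshapeConstantSource` below). A hypothetical non-splat input such as the following would be left alone by this fold (example mine, not from the patch):

```mlir
// Hypothetical example: %cst is not a splat, so PackOp::fold produces no
// result and the tensor.pack op survives this fold.
func.func @nonsplat(%dest : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32> {
  %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>
  %0 = tensor.pack %cst inner_dims_pos = [0, 1] inner_tiles = [2, 2]
      into %dest : tensor<2x2xf32> -> tensor<1x1x2x2xf32>
  return %0 : tensor<1x1x2x2xf32>
}
```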

3 files changed: +102 −6 lines changed
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 6 additions & 0 deletions
```diff
@@ -659,6 +659,7 @@ def Tensor_GatherOp : Tensor_Op<"gather", [
     }
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -986,6 +987,7 @@ def Tensor_ReshapeOp: Tensor_Op<"reshape", [
     $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1867,6 +1869,8 @@ def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
   }];
 
   let hasCanonicalizeMethod = 1;
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1948,6 +1952,8 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
   }];
 
   let hasCanonicalizeMethod = 1;
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
```
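For context on the TableGen side: setting `let hasFolder = 1;` on a single-result op makes `mlir-tblgen` declare a fold hook on the generated op class, which the C++ changes below then define. Roughly (a sketch of the generated declaration, not verbatim tblgen output):

```cpp
// Declared on each generated op class when `let hasFolder = 1;` is set (for
// single-result ops). Returning a null OpFoldResult means "nothing folded";
// returning an Attribute or Value replaces the op's single result.
::mlir::OpFoldResult fold(FoldAdaptor adaptor);
```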

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 48 additions & 6 deletions
```diff
@@ -834,6 +834,17 @@ void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
               ReplaceEmptyTensorStaticShapeDims>(context);
 }
 
+/// Try to remove a tensor operation if it would only reshape a constant.
+/// Removes the op and replaces the constant with a new constant of the result
+/// shape.
+static OpFoldResult reshapeConstantSource(DenseElementsAttr source,
+                                          TensorType result) {
+  if (source && source.isSplat() && result.hasStaticShape())
+    return source.resizeSplat(result);
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractOp
 //===----------------------------------------------------------------------===//
```
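The helper's guard is the whole trick: a splat attribute stores one value independent of shape, so `DenseElementsAttr::resizeSplat` (an existing MLIR API, also used by the `ExtractSliceOp` fold below) can retype it to any statically shaped result without touching element data. A minimal standalone sketch of that behavior, with made-up shapes (assumes an `MLIRContext &context` is available):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Build a splat constant payload and retype it: the single stored value is
// reused, so no element data is copied or rearranged.
DenseElementsAttr retypeSplatExample(MLIRContext &context) {
  Builder b(&context);
  auto srcType = RankedTensorType::get({64, 128}, b.getF32Type());
  // A single value splats across the whole shape.
  DenseElementsAttr splat = DenseElementsAttr::get(srcType, 0.1f);
  auto dstType = RankedTensorType::get({8, 16, 8, 32}, b.getF32Type());
  // Yields dense<1.000000e-01> : tensor<8x16x8x32xf32>.
  return splat.resizeSplat(dstType);
}
```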
```diff
@@ -1089,6 +1100,14 @@ LogicalResult GatherOp::verify() {
   return success();
 }
 
+OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
+  if (OpFoldResult reshapedSource = reshapeConstantSource(
+          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+          getResult().getType()))
+    return reshapedSource;
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // InsertOp
 //===----------------------------------------------------------------------===//
@@ -1367,6 +1386,14 @@ LogicalResult ReshapeOp::verify() {
   return success();
 }
 
+OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
+  if (OpFoldResult reshapedSource = reshapeConstantSource(
+          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+          getResult().getType()))
+    return reshapedSource;
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // Reassociative reshape ops
 //===----------------------------------------------------------------------===//
@@ -2153,12 +2180,10 @@ static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
 }
 
 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
-  if (auto splat =
-          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
-    auto resultType = llvm::cast<ShapedType>(getResult().getType());
-    if (resultType.hasStaticShape())
-      return splat.resizeSplat(resultType);
-  }
+  if (OpFoldResult reshapedSource = reshapeConstantSource(
+          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
+          getResult().getType()))
+    return reshapedSource;
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->getSource();
```
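The `ExtractSliceOp` change is behavior-preserving refactoring: `SplatElementsAttr` is a subclass of `DenseElementsAttr` constrained to the splat case, so the cast result converts directly to the helper's parameter type and the helper's `isSplat()` check is satisfied by construction. Illustratively (hypothetical `attr` variable, not from the patch):

```cpp
// SplatElementsAttr IS-A DenseElementsAttr restricted to splats, so passing
// the dyn_cast result straight into reshapeConstantSource type-checks.
if (auto splat = llvm::dyn_cast_if_present<SplatElementsAttr>(attr)) {
  DenseElementsAttr dense = splat; // implicit upcast, no conversion cost
  assert(dense.isSplat() && "splat attrs report isSplat() by construction");
}
```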
```diff
@@ -3823,6 +3848,14 @@ bool PackOp::isLikePad() {
   return isLikePadUnPad(*this, packedTensorType);
 }
 
+OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+  if (OpFoldResult reshapedSource = reshapeConstantSource(
+          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+          getResult().getType()))
+    return reshapedSource;
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // UnPackOp
 //===----------------------------------------------------------------------===//
@@ -3951,6 +3984,15 @@ bool UnPackOp::isLikeUnPad() {
   RankedTensorType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
 }
+
+OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
+  if (OpFoldResult reshapedSource = reshapeConstantSource(
+          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
+          getResult().getType()))
+    return reshapedSource;
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//
```

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 48 additions & 0 deletions
```diff
@@ -672,6 +672,30 @@ func.func @fold_extract_insert(%input : tensor<?x?x?xf32>, %slice: tensor<4x?x8x
 
 // -----
 
+// CHECK-LABEL: func @fold_gather_constant_splat
+//   CHECK-NOT: tensor.gather
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<1x2x1x1x1xf32>
+func.func @fold_gather_constant_splat(%indices : tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<4x4x4xf32>
+  %0 = tensor.gather %cst[%indices] gather_dims([0, 1, 2]) :
+    (tensor<4x4x4xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
+  return %0 : tensor<1x2x1x1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_reshape_constant_splat
+//   CHECK-NOT: tensor.reshape
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<4xf32>
+func.func @fold_reshape_constant_splat(%shape : tensor<1xi32>) -> tensor<4xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<4x1xf32>
+  %0 = tensor.reshape %cst(%shape)
+    : (tensor<4x1xf32>, tensor<1xi32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract_constant_splat
 //   CHECK-NOT: tensor.extract_slice
 //       CHECK: arith.constant dense<42> : tensor<4x4xi32>
@@ -683,6 +707,30 @@ func.func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
 
 // -----
 
+// CHECK-LABEL: func @fold_pack_constant_splat
+//   CHECK-NOT: tensor.pack
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
+func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<64x128xf32>
+  %0 = tensor.pack %cst outer_dims_perm = [1, 0] inner_dims_pos = [0, 1]
+    inner_tiles = [8, 32] into %dest : tensor<64x128xf32> -> tensor<8x16x8x32xf32>
+  return %0 : tensor<8x16x8x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_unpack_constant_splat
+//   CHECK-NOT: tensor.unpack
+//       CHECK: arith.constant dense<1.000000e-01> : tensor<128x256xf32>
+func.func @fold_unpack_constant_splat(%dest : tensor<128x256xf32>) -> tensor<128x256xf32> {
+  %cst = arith.constant dense<1.000000e-01> : tensor<16x8x8x32xf32>
+  %0 = tensor.unpack %cst inner_dims_pos = [0, 1]
+    inner_tiles = [8, 32] into %dest : tensor<16x8x8x32xf32> -> tensor<128x256xf32>
+  return %0 : tensor<128x256xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 // CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func.func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {
```
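For concreteness, here is what the pack test above should look like after `-canonicalize`, reconstructed from its CHECK lines (mirroring the commit-message example; not part of the patch itself):

```mlir
func.func @fold_pack_constant_splat(%dest : tensor<8x16x8x32xf32>) -> tensor<8x16x8x32xf32> {
  // The tensor.pack is gone; only the retyped splat constant remains.
  %cst = arith.constant dense<1.000000e-01> : tensor<8x16x8x32xf32>
  return %cst : tensor<8x16x8x32xf32>
}
```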
