Skip to content

Commit ebf3537

Browse files
[mlir][tensor] Insert explicit tensor.cast ops for insert_slice src
If additional static type information can be deduced from an insert_slice's size operands, insert an explicit cast of the op's source operand. This enables other canonicalization patterns that match tensor_cast ops, such as `ForOpTensorCastFolder` in SCF. Differential Revision: https://reviews.llvm.org/D108617
1 parent 0c36082 commit ebf3537

File tree

3 files changed

+96
-7
lines changed

3 files changed

+96
-7
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,7 +1085,24 @@ class InsertSliceOpConstantArgumentFolder final
10851085
}
10861086
};
10871087

1088-
/// Fold tensor_casts with insert_slice operations.
1088+
/// Fold tensor_casts with insert_slice operations. If the source or destination
1089+
/// tensor is a tensor_cast that removes static type information, the cast is
1090+
/// folded into the insert_slice operation. E.g.:
1091+
///
1092+
/// ```mlir
1093+
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
1094+
/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
1095+
/// ```
1096+
///
1097+
/// folds into:
1098+
///
1099+
/// ```mlir
1100+
/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
1101+
/// ```
1102+
///
1103+
/// Note: When folding a cast on the destination tensor, the result of the
1104+
/// insert_slice operation is cast to ensure that the type of the result did
1105+
/// not change.
10891106
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
10901107
using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
10911108

@@ -1123,12 +1140,63 @@ struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
11231140
return success();
11241141
}
11251142
};
1143+
1144+
/// If additional static type information can be deduced from an insert_slice's
1145+
/// size operands, insert an explicit cast of the op's source operand. This
1146+
/// enables other canonicalization patterns that match tensor_cast
1147+
/// ops such as `ForOpTensorCastFolder` in SCF.
1148+
///
1149+
/// Example:
1150+
///
1151+
/// ```mlir
1152+
/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
1153+
/// : tensor<?x?xf32> into ...
1154+
/// ```
1155+
///
1156+
/// folds into:
1157+
///
1158+
/// ```mlir
1159+
/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
1160+
/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
1161+
/// : tensor<64x64xf32> into ...
1162+
/// ```
1163+
struct InsertSliceOpSourceCastInserter final
1164+
: public OpRewritePattern<InsertSliceOp> {
1165+
using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
1166+
1167+
LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
1168+
PatternRewriter &rewriter) const override {
1169+
RankedTensorType srcType = insertSliceOp.getSourceType();
1170+
if (srcType.getRank() != insertSliceOp.getType().getRank())
1171+
return failure();
1172+
SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
1173+
srcType.getShape().end());
1174+
for (int64_t i = 0; i < srcType.getRank(); ++i) {
1175+
if (Optional<int64_t> constInt =
1176+
getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
1177+
newSrcShape[i] = *constInt;
1178+
}
1179+
RankedTensorType newSrcType =
1180+
RankedTensorType::get(newSrcShape, srcType.getElementType());
1181+
if (srcType == newSrcType)
1182+
return failure();
1183+
1184+
// srcType and newSrcType are different. Insert a cast.
1185+
Value cast = rewriter.create<tensor::CastOp>(
1186+
insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
1187+
rewriter.replaceOpWithNewOp<InsertSliceOp>(
1188+
insertSliceOp, cast, insertSliceOp.dest(),
1189+
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
1190+
insertSliceOp.getMixedStrides());
1191+
return success();
1192+
}
1193+
};
11261194
} // namespace
11271195

11281196
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
11291197
MLIRContext *context) {
1130-
results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder>(
1131-
context);
1198+
results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
1199+
InsertSliceOpSourceCastInserter>(context);
11321200
}
11331201

11341202
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ func @matmul_on_tensors(%t0: tensor<32x1024xf32>, %t1: tensor<1024x1024xf32>) ->
666666
return %res : tensor<1024x1024xf32>
667667
}
668668

669-
669+
// -----
670670

671671
// CHECK-LABEL: @cond_prop
672672
func @cond_prop(%arg0 : i1) -> index {
@@ -707,6 +707,8 @@ func @cond_prop(%arg0 : i1) -> index {
707707
// CHECK-NEXT: return %[[if]] : index
708708
// CHECK-NEXT:}
709709

710+
// -----
711+
710712
// CHECK-LABEL: @replace_if_with_cond1
711713
func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
712714
%true = constant true
@@ -729,6 +731,8 @@ func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
729731
// CHECK-NEXT: }
730732
// CHECK-NEXT: return %[[if]], %arg0 : i32, i1
731733

734+
// -----
735+
732736
// CHECK-LABEL: @replace_if_with_cond2
733737
func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
734738
%true = constant true
@@ -753,6 +757,7 @@ func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
753757
// CHECK-NEXT: }
754758
// CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
755759

760+
// -----
756761

757762
// CHECK-LABEL: @replace_if_with_cond3
758763
func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
@@ -774,6 +779,7 @@ func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
774779
// CHECK-NEXT: }
775780
// CHECK-NEXT: return %[[if]], %arg1 : i32, i64
776781

782+
// -----
777783

778784
// CHECK-LABEL: @while_cond_true
779785
func @while_cond_true() {

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,10 +366,11 @@ func @insert_slice_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index,
366366
}
367367
// CHECK-LABEL: func @insert_slice_canonicalize
368368
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
369-
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]]
369+
// CHECK: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32>
370+
// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
370371
// CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
371-
// CHECK-SAME: : tensor<?x?x?xf32> into tensor<?x?x?xf32>
372-
// CHEKC: return %[[RESULT]]
372+
// CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
373+
// CHECK: return %[[RESULT]]
373374

374375
// -----
375376

@@ -517,3 +518,17 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
517518
%2 = tensor.dim %0, %c1 : tensor<?x?xf32>
518519
return %1, %2: index, index
519520
}
521+
522+
// -----
523+
524+
// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
525+
// CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
526+
// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
527+
// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
528+
// CHECK: return %[[r]]
529+
// Test input: an insert_slice with all-constant sizes [64, 5, 64] whose source
// type still has dynamic dimensions. The expectations above require that
// canonicalization casts the source to tensor<64x5x64xf32> first.
func @insert_tensor_cast_on_insert_slice_src(
530+
%arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
531+
%r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1]
532+
: tensor<?x5x?xf32> into tensor<?x?x?xf32>
533+
return %r : tensor<?x?x?xf32>
534+
}

0 commit comments

Comments
 (0)