19
19
#include " mlir/Dialect/Tensor/IR/Tensor.h"
20
20
#include " mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
21
21
#include " mlir/Dialect/Utils/StaticValueUtils.h"
22
+ #include " mlir/IR/BuiltinTypeInterfaces.h"
22
23
#include " mlir/IR/Dialect.h"
23
24
#include " mlir/IR/Operation.h"
24
25
@@ -636,6 +637,34 @@ struct InsertOpInterface
636
637
}
637
638
};
638
639
640
+ template <typename InsertOpTy>
641
+ static bool insertSliceOpRequiresRead (InsertOpTy insertSliceOp,
642
+ OpOperand &opOperand) {
643
+ RankedTensorType destType = insertSliceOp.getDestType ();
644
+
645
+ // The source is always read.
646
+ if (opOperand == insertSliceOp.getSourceMutable ())
647
+ return true ;
648
+
649
+ // For the destination, it depends...
650
+ assert (opOperand == insertSliceOp.getDestMutable () && " expected dest" );
651
+
652
+ // Dest is not read if it is entirely overwritten. E.g.:
653
+ // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
654
+ bool allOffsetsZero =
655
+ llvm::all_of (insertSliceOp.getMixedOffsets (),
656
+ [](OpFoldResult ofr) { return isConstantIntValue (ofr, 0 ); });
657
+ bool sizesMatchDestSizes = llvm::all_of (
658
+ llvm::enumerate (insertSliceOp.getMixedSizes ()), [&](const auto &it) {
659
+ return getConstantIntValue (it.value ()) ==
660
+ destType.getDimSize (it.index ());
661
+ });
662
+ bool allStridesOne =
663
+ llvm::all_of (insertSliceOp.getMixedStrides (),
664
+ [](OpFoldResult ofr) { return isConstantIntValue (ofr, 1 ); });
665
+ return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
666
+ }
667
+
639
668
// / Bufferization of tensor.insert_slice. Replace with a memory copy. Under
640
669
// / certain circumstances, this op can also be a no-op.
641
670
// /
@@ -646,32 +675,8 @@ struct InsertSliceOpInterface
646
675
tensor::InsertSliceOp> {
647
676
bool bufferizesToMemoryRead (Operation *op, OpOperand &opOperand,
648
677
const AnalysisState &state) const {
649
- auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
650
- RankedTensorType destType = insertSliceOp.getDestType ();
651
-
652
- // The source is always read.
653
- if (opOperand == insertSliceOp.getSourceMutable ())
654
- return true ;
655
-
656
- // For the destination, it depends...
657
- assert (opOperand == insertSliceOp.getDestMutable () && " expected dest" );
658
-
659
- // Dest is not read if it is entirely overwritten. E.g.:
660
- // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
661
- bool allOffsetsZero =
662
- llvm::all_of (insertSliceOp.getMixedOffsets (), [](OpFoldResult ofr) {
663
- return isConstantIntValue (ofr, 0 );
664
- });
665
- bool sizesMatchDestSizes = llvm::all_of (
666
- llvm::enumerate (insertSliceOp.getMixedSizes ()), [&](const auto &it) {
667
- return getConstantIntValue (it.value ()) ==
668
- destType.getDimSize (it.index ());
669
- });
670
- bool allStridesOne =
671
- llvm::all_of (insertSliceOp.getMixedStrides (), [](OpFoldResult ofr) {
672
- return isConstantIntValue (ofr, 1 );
673
- });
674
- return !(allOffsetsZero && sizesMatchDestSizes && allStridesOne);
678
+ return insertSliceOpRequiresRead (cast<tensor::InsertSliceOp>(op),
679
+ opOperand);
675
680
}
676
681
677
682
LogicalResult bufferize (Operation *op, RewriterBase &rewriter,
@@ -931,7 +936,8 @@ struct ParallelInsertSliceOpInterface
931
936
932
937
bool bufferizesToMemoryRead (Operation *op, OpOperand &opOperand,
933
938
const AnalysisState &state) const {
934
- return true ;
939
+ return insertSliceOpRequiresRead (cast<tensor::ParallelInsertSliceOp>(op),
940
+ opOperand);
935
941
}
936
942
937
943
bool bufferizesToMemoryWrite (Operation *op, OpOperand &opOperand,
0 commit comments