#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
@@ -812,6 +813,289 @@ struct YieldOpInterface
  }
};

+using tensor::ExtractSliceOp;
+
+/// Return the destinations that a ForeachThreadOp is inserting into. One per
+/// ParallelInsertSliceOp.
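+///
+/// For example (illustrative IR only; value names are hypothetical), the
+/// terminator below yields two insertion destinations, %out0 and %out1:
+///
+///   scf.foreach_thread.perform_concurrently {
+///     scf.foreach_thread.parallel_insert_slice %a into %out0[...][...][...]
+///     scf.foreach_thread.parallel_insert_slice %b into %out1[...][...][...]
+///   }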
+static SmallVector<OpOperand *>
+getInsertionDest(ForeachThreadOp foreachThreadOp) {
+  PerformConcurrentlyOp terminator = foreachThreadOp.getTerminator();
+  SmallVector<OpOperand *> result;
+  terminator.walk([&](ParallelInsertSliceOp insertOp) {
+    result.push_back(&insertOp->getOpOperand(1) /*dest*/);
+  });
+  return result;
+}
+
+/// Bufferization of ForeachThreadOp. This also bufferizes the terminator of
+/// the region. There are op interfaces for the terminators
+/// (PerformConcurrentlyOp and ParallelInsertSliceOp), but these are only used
+/// during analysis, not for bufferization.
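+///
+/// A rough sketch of the rewrite (illustrative only; names hypothetical):
+/// every ParallelInsertSliceOp in the terminator turns into a subview of the
+/// bufferized destination plus a copy, e.g.:
+///
+///   %src_m = bufferization.to_memref %src
+///   %sv = memref.subview %dest_m[...][...][...]
+///   <memcpy from %src_m to %sv>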
+struct ForeachThreadOpInterface
+    : public BufferizableOpInterface::ExternalModel<ForeachThreadOpInterface,
+                                                    ForeachThreadOp> {
+  SmallVector<OpOperand *>
+  getAliasingOpOperand(Operation *op, OpResult opResult,
+                       const AnalysisState &state) const {
+    // Get OpOperand (dest) from corresponding ParallelInsertSliceOp.
+    auto foreachThreadOp = cast<ForeachThreadOp>(op);
+    return {getInsertionDest(foreachThreadOp)[opResult.getResultNumber()]};
+  }
+
+  bool isMemoryWrite(Operation *op, OpResult opResult,
+                     const AnalysisState &state) const {
+    // This op is a memory write. Stop lookup here to avoid finding false
+    // conflicts involving this op and one of the ops in the region. This is
+    // similar to how scf.if ops are analyzed.
+    return true;
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &b,
+                          BufferizationState &state) const {
+    OpBuilder::InsertionGuard g(b);
+    auto foreachThreadOp = cast<ForeachThreadOp>(op);
+
+    // Gather new results of the ForeachThreadOp.
+    SmallVector<Value> newResults;
+    for (OpResult opResult : foreachThreadOp->getOpResults()) {
+      SmallVector<OpOperand *> insertDestOperands =
+          state.getAnalysisState().getAliasingOpOperand(opResult);
+      assert(insertDestOperands.size() == 1 &&
+             "expected exactly one aliasing OpOperand");
+      // Insert copies right before the PerformConcurrentlyOp terminator. They
+      // should not be inside the terminator (which would be the default
+      // insertion point).
+      Value buffer = *state.getBuffer(b, *insertDestOperands.front(),
+                                      /*forceInPlace=*/llvm::None,
+                                      /*customCopyInsertionPoint=*/op);
+      newResults.push_back(buffer);
+    }
+
+    // Create a new ForeachThreadOp without any results and drop the
+    // automatically introduced terminator.
+    TypeRange newResultTypes;
+    auto newForeachThreadOp =
+        b.create<ForeachThreadOp>(foreachThreadOp.getLoc(), newResultTypes,
+                                  foreachThreadOp.getNumThreads());
+    newForeachThreadOp.getBody()->getTerminator()->erase();
+
+    // Move over block contents of the old op.
+    b.mergeBlocks(foreachThreadOp.getBody(), newForeachThreadOp.getBody(),
+                  {newForeachThreadOp.getBody()->getArguments()});
+
+    // Bufferize terminator.
+    auto performConcurrentlyOp = cast<PerformConcurrentlyOp>(
+        newForeachThreadOp.getBody()->getTerminator());
+    b.setInsertionPoint(performConcurrentlyOp);
+    unsigned resultCounter = 0;
+    WalkResult walkResult =
+        performConcurrentlyOp.walk([&](ParallelInsertSliceOp insertOp) {
+          Location loc = insertOp.getLoc();
+          Type srcType = getMemRefType(
+              insertOp.getSource().getType().cast<RankedTensorType>(),
+              state.getOptions());
+          // ParallelInsertSliceOp bufferizes to a copy.
+          auto srcMemref = b.create<bufferization::ToMemrefOp>(
+              loc, srcType, insertOp.getSource());
+          Value destMemref = newResults[resultCounter++];
+          Value subview = b.create<memref::SubViewOp>(
+              loc, destMemref, insertOp.getMixedOffsets(),
+              insertOp.getMixedSizes(), insertOp.getMixedStrides());
+          // This memcpy will fold away if everything bufferizes in-place.
+          if (failed(state.getOptions().createMemCpy(b, insertOp.getLoc(),
+                                                     srcMemref, subview)))
+            return WalkResult::interrupt();
+          b.eraseOp(insertOp);
+          return WalkResult::advance();
+        });
+    if (walkResult.wasInterrupted())
+      return failure();
+
+    // Replace the op.
+    replaceOpWithBufferizedValues(b, op, newResults);
+
+    return success();
+  }
+};
+
+/// Nothing to do for PerformConcurrentlyOp.
+struct PerformConcurrentlyOpInterface
+    : public BufferizableOpInterface::ExternalModel<
+          PerformConcurrentlyOpInterface, PerformConcurrentlyOp> {
+  LogicalResult bufferize(Operation *op, RewriterBase &b,
+                          BufferizationState &state) const {
+    assert(false && "op does not have any tensor OpOperands / OpResults");
+    return failure();
+  }
+};
+
+/// Return true if the (ExtractSliceOp, ParallelInsertSliceOp) pair matches
+/// (i.e., equivalent operand / result and same offset/sizes/strides
+/// specification).
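+///
+/// For example (illustrative IR, hypothetical names), the following pair
+/// matches because the slice is read from and written back to equivalent
+/// tensors with identical offsets, sizes, and strides:
+///
+///   %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1]
+///   ...
+///   scf.foreach_thread.parallel_insert_slice %1 into %t[%a, %b][%c, %d][1, 1]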
+static bool areEquivalentExtractSliceOps(const AnalysisState &state,
+                                         ExtractSliceOp st,
+                                         ParallelInsertSliceOp sti) {
+  if (!st || !sti)
+    return false;
+  if (st != sti &&
+      !state.areEquivalentBufferizedValues(st.source(), sti.getDest()))
+    return false;
+  if (!sameOffsetsSizesAndStrides(st, sti, isEqualConstantIntOrValue))
+    return false;
+  return true;
+}
+
+/// Return true if `value` originates from an ExtractSliceOp that matches the
+/// given ParallelInsertSliceOp.
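+///
+/// E.g. (illustrative, mirroring the examples further below): for
+///   %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1]
+///   %1 = linalg.fill %cst, %0
+/// the reverse use-def chain of %1 stops at %0; if that ExtractSliceOp
+/// matches the insert op, this function returns true for %1.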
+static bool hasMatchingExtractSliceOp(const AnalysisState &state, Value value,
+                                      ParallelInsertSliceOp insertOp) {
+  auto condition = [&](Value val) {
+    if (auto extractOp = val.getDefiningOp<ExtractSliceOp>())
+      if (areEquivalentExtractSliceOps(state, extractOp, insertOp))
+        return true;
+    return false;
+  };
+
+  return llvm::all_of(state.findValueInReverseUseDefChain(value, condition),
+                      condition);
+}
+
+/// Analysis of ParallelInsertSliceOp.
+struct ParallelInsertSliceOpInterface
+    : public BufferizableOpInterface::ExternalModel<
+          ParallelInsertSliceOpInterface, ParallelInsertSliceOp> {
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    if (&opOperand != &op->getOpOperand(1) /*dest*/)
+      return {};
+
+    // ParallelInsertSliceOp itself has no results. Tensors are returned via
+    // the parent op.
+    auto foreachThreadOp = op->getParentOfType<ForeachThreadOp>();
+    assert(foreachThreadOp &&
+           "could not find valid owner of parallel_insert_slice");
+
+    // The i-th ParallelInsertSliceOp result is returned via the i-th OpResult
+    // of the parent ForeachThreadOp.
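+    // E.g. (illustrative): the dest of the second ParallelInsertSliceOp in
+    // the terminator block aliases result #1 of the enclosing ForeachThreadOp.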
+    Block *block = op->getBlock();
+    unsigned int opIdx = 0;
+    for (ParallelInsertSliceOp insertOp :
+         block->getOps<ParallelInsertSliceOp>()) {
+      if (insertOp.getOperation() == op)
+        break;
+      ++opIdx;
+    }
+    assert(opIdx < foreachThreadOp->getNumResults() &&
+           "could not find op inside terminator op");
+
+    return {foreachThreadOp->getResult(opIdx)};
+  }
+
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return &opOperand == &op->getOpOperand(1) /*dest*/;
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &b,
+                          BufferizationState &state) const {
+    // Will be bufferized as part of ForeachThreadOp.
+    return failure();
+  }
+
+  // TODO: This is copied from TensorInterfaceImpl.cpp. Find a way to share
+  // the code.
+  bool isNotConflicting(Operation *op, OpOperand *uRead,
+                        OpOperand *uConflictingWrite,
+                        const AnalysisState &state) const {
+    Operation *readingOp = uRead->getOwner();
+    Operation *conflictingWritingOp = uConflictingWrite->getOwner();
+
+    // Special rules for matching ExtractSliceOp/ParallelInsertSliceOp pairs.
+    // If uRead is a ParallelInsertSliceOp...
+    if (auto insertSliceOp = dyn_cast<ParallelInsertSliceOp>(readingOp)) {
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
+      // %1 = linalg.fill %cst, %0 {inplace = [true]}
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //      {inplace = [true]}
+
+      // TODO: Use insertSliceOp.getDestOpOperand etc. when available.
+      if (uRead == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(state, uConflictingWrite->get(),
+                                    insertSliceOp))
+        // Case 1: The main insight is that InsertSliceOp reads only part of
+        // the destination tensor. The overwritten area is not read. If
+        // uConflictingWrite writes into exactly the memory location that is
+        // being read by uRead, this is not a conflict.
+        //
+        // In the above example:
+        //   uRead             = OpOperand 1 (%t) of tensor.insert_slice
+        //   uConflictingWrite = OpOperand 1 (%0) of linalg.fill
+        //
+        // The read of %t does not conflict with the write of the FillOp
+        // (same aliases!) because the area that the FillOp operates on is
+        // exactly the one that is *not* read via %t.
+        return true;
+
+      if (uRead == &insertSliceOp->getOpOperand(0) /*source*/ &&
+          uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          hasMatchingExtractSliceOp(state, uRead->get(), insertSliceOp))
+        // Case 2: The read of the source tensor and the write to the dest
+        // tensor via an InsertSliceOp is not a conflict if the read is
+        // reading exactly that part of an equivalent tensor that the
+        // InsertSliceOp is writing.
+        //
+        // In the above example:
+        //   uRead             = OpOperand 0 (%1) of tensor.insert_slice
+        //   uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+        return true;
+    }
+
+    // If uConflictingWrite is a ParallelInsertSliceOp...
+    if (auto insertSliceOp =
+            dyn_cast<ParallelInsertSliceOp>(conflictingWritingOp))
+      // As an example, consider the following IR.
+      //
+      // %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true]}
+      // %1 = linalg.fill %cst, %0 {inplace = [true]}
+      // %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
+      //      {inplace = [true]}
+      // %3 = vector.transfer_read %1, %cst
+      //
+      // In the above example:
+      //   uRead             = OpOperand 0 (%1) of vector.transfer_read
+      //   uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
+      //   lastWrite         = %1
+      //
+      // This is not a conflict because the InsertSliceOp overwrites the
+      // memory segment of %1 with the exact same data. (Effectively, there
+      // is no memory write here.)
+      if (uConflictingWrite == &insertSliceOp->getOpOperand(1) /*dest*/ &&
+          state.areEquivalentBufferizedValues(uRead->get(),
+                                              insertSliceOp.getSource()) &&
+          hasMatchingExtractSliceOp(state, insertSliceOp.getSource(),
+                                    insertSliceOp))
+        return true;
+
+    return false;
+  }
+};
+
} // namespace
} // namespace scf
} // namespace mlir
@@ -822,6 +1106,11 @@ void mlir::scf::registerBufferizableOpInterfaceExternalModels(
    ExecuteRegionOp::attachInterface<ExecuteRegionOpInterface>(*ctx);
    ForOp::attachInterface<ForOpInterface>(*ctx);
    IfOp::attachInterface<IfOpInterface>(*ctx);
+    ForeachThreadOp::attachInterface<ForeachThreadOpInterface>(*ctx);
+    ParallelInsertSliceOp::attachInterface<ParallelInsertSliceOpInterface>(
+        *ctx);
+    PerformConcurrentlyOp::attachInterface<PerformConcurrentlyOpInterface>(
+        *ctx);
    WhileOp::attachInterface<WhileOpInterface>(*ctx);
    YieldOp::attachInterface<YieldOpInterface>(*ctx);
  });