|
56 | 56 | //===----------------------------------------------------------------------===//
|
57 | 57 |
|
58 | 58 | #include "PassDetail.h"
|
| 59 | +#include "mlir/Analysis/Liveness.h" |
59 | 60 | #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
| 61 | +#include "mlir/Dialect/StandardOps/IR/Ops.h" |
| 62 | +#include "mlir/IR/Dominance.h" |
60 | 63 | #include "mlir/IR/Operation.h"
|
| 64 | +#include "mlir/Interfaces/ControlFlowInterfaces.h" |
61 | 65 | #include "mlir/Pass/Pass.h"
|
62 |
| -#include "mlir/Transforms/Bufferize.h" |
63 | 66 | #include "mlir/Transforms/Passes.h"
|
64 | 67 | #include "llvm/ADT/SetOperations.h"
|
65 | 68 |
|
@@ -809,245 +812,6 @@ struct BufferPlacementPass : BufferPlacementBase<BufferPlacementPass> {
|
809 | 812 |
|
810 | 813 | } // end anonymous namespace
|
811 | 814 |
|
812 |
| -//===----------------------------------------------------------------------===// |
813 |
| -// BufferAssignmentTypeConverter |
814 |
| -//===----------------------------------------------------------------------===// |
815 |
| - |
816 |
| -/// Registers conversions into BufferAssignmentTypeConverter |
817 |
| -BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() { |
818 |
| - // Keep all types unchanged. |
819 |
| - addConversion([](Type type) { return type; }); |
820 |
| - // Convert RankedTensorType to MemRefType. |
821 |
| - addConversion([](RankedTensorType type) { |
822 |
| - return (Type)MemRefType::get(type.getShape(), type.getElementType()); |
823 |
| - }); |
824 |
| - // Convert UnrankedTensorType to UnrankedMemRefType. |
825 |
| - addConversion([](UnrankedTensorType type) { |
826 |
| - return (Type)UnrankedMemRefType::get(type.getElementType(), 0); |
827 |
| - }); |
828 |
| -} |
829 |
| - |
830 |
| -/// This method tries to decompose a value of a certain type using provided |
831 |
| -/// decompose callback functions. If it is unable to do so, the original value |
832 |
| -/// is returned. |
833 |
| -void BufferAssignmentTypeConverter::tryDecomposeValue( |
834 |
| - OpBuilder &builder, Location loc, Type type, Value value, |
835 |
| - SmallVectorImpl<Value> &results) { |
836 |
| - for (auto conversion : decomposeValueConversions) |
837 |
| - if (conversion(builder, loc, type, value, results) != llvm::None) |
838 |
| - return; |
839 |
| - results.push_back(value); |
840 |
| -} |
841 |
| - |
842 |
| -/// This method tries to decompose a type using provided decompose callback |
843 |
| -/// functions. If it is unable to do so, the original type is returned. |
844 |
| -void BufferAssignmentTypeConverter::tryDecomposeType( |
845 |
| - Type type, SmallVectorImpl<Type> &types) { |
846 |
| - for (auto conversion : decomposeTypeConversions) |
847 |
| - if (conversion(type, types) != llvm::None) |
848 |
| - return; |
849 |
| - types.push_back(type); |
850 |
| -} |
851 |
| - |
852 |
| -/// This method returns ResultConversionKind for the input type. |
853 |
| -BufferAssignmentTypeConverter::ResultConversionKind |
854 |
| -BufferAssignmentTypeConverter::getResultConversionKind(Type origin, |
855 |
| - Type converted) { |
856 |
| - for (auto conversion : resultTypeConversions) { |
857 |
| - auto res = conversion(origin, converted); |
858 |
| - if (res != llvm::None) |
859 |
| - return res.getValue(); |
860 |
| - } |
861 |
| - return KeepAsFunctionResult; |
862 |
| -} |
863 |
| - |
864 |
| -//===----------------------------------------------------------------------===// |
865 |
| -// BufferAssignmentFuncOpConverter |
866 |
| -//===----------------------------------------------------------------------===// |
867 |
| - |
868 |
| -/// Performs the actual function signature rewriting step. |
869 |
| -LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite( |
870 |
| - mlir::FuncOp funcOp, ArrayRef<Value> operands, |
871 |
| - ConversionPatternRewriter &rewriter) const { |
872 |
| - auto funcType = funcOp.getType(); |
873 |
| - |
874 |
| - // Convert function arguments using the provided TypeConverter. |
875 |
| - TypeConverter::SignatureConversion conversion(funcType.getNumInputs()); |
876 |
| - for (auto argType : llvm::enumerate(funcType.getInputs())) { |
877 |
| - SmallVector<Type, 2> decomposedTypes, convertedTypes; |
878 |
| - converter.tryDecomposeType(argType.value(), decomposedTypes); |
879 |
| - converter.convertTypes(decomposedTypes, convertedTypes); |
880 |
| - conversion.addInputs(argType.index(), convertedTypes); |
881 |
| - } |
882 |
| - |
883 |
| - // Convert the result types of the function. |
884 |
| - SmallVector<Type, 2> newResultTypes; |
885 |
| - newResultTypes.reserve(funcOp.getNumResults()); |
886 |
| - for (Type resultType : funcType.getResults()) { |
887 |
| - SmallVector<Type, 2> originTypes; |
888 |
| - converter.tryDecomposeType(resultType, originTypes); |
889 |
| - for (auto origin : originTypes) { |
890 |
| - Type converted = converter.convertType(origin); |
891 |
| - auto kind = converter.getResultConversionKind(origin, converted); |
892 |
| - if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList) |
893 |
| - conversion.addInputs(converted); |
894 |
| - else |
895 |
| - // kind = BufferAssignmentTypeConverter::KeepAsFunctionResult |
896 |
| - newResultTypes.push_back(converted); |
897 |
| - } |
898 |
| - } |
899 |
| - |
900 |
| - if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter, |
901 |
| - &conversion))) |
902 |
| - return failure(); |
903 |
| - |
904 |
| - // Update the signature of the function. |
905 |
| - rewriter.updateRootInPlace(funcOp, [&] { |
906 |
| - funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(), |
907 |
| - newResultTypes)); |
908 |
| - }); |
909 |
| - return success(); |
910 |
| -} |
911 |
| - |
912 |
| -//===----------------------------------------------------------------------===// |
913 |
| -// BufferAssignmentCallOpConverter |
914 |
| -//===----------------------------------------------------------------------===// |
915 |
| - |
916 |
| -namespace { |
917 |
| -// This class represents a mapping from a result to a list of values and some |
918 |
| -// results that have not yet constructed. Instead, the indices of these |
919 |
| -// results in the operation that will be constructed are known. They will be |
920 |
| -// replaced with the actual values when they are available. The order of |
921 |
| -// adding to this mapping is important. |
922 |
| -class CallOpResultMapping { |
923 |
| -public: |
924 |
| - CallOpResultMapping() { order = 0; }; |
925 |
| - |
926 |
| - /// Add an available value to the mapping. |
927 |
| - void addMapping(Value value) { toValuesMapping.push_back({order++, value}); } |
928 |
| - |
929 |
| - /// Add the index of unavailble result value to the mapping. |
930 |
| - void addMapping(unsigned index) { |
931 |
| - toIndicesMapping.push_back({order++, index}); |
932 |
| - } |
933 |
| - |
934 |
| - /// This method returns the mapping values list. The unknown result values |
935 |
| - /// that only their indicies are available are replaced with their values. |
936 |
| - void getMappingValues(ValueRange valuesToReplaceIndices, |
937 |
| - SmallVectorImpl<Value> &values) { |
938 |
| - // Append available values to the list. |
939 |
| - SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(), |
940 |
| - toValuesMapping.end()); |
941 |
| - // Replace the indices with the actual values. |
942 |
| - llvm::for_each( |
943 |
| - toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) { |
944 |
| - assert(entry.second < valuesToReplaceIndices.size() && |
945 |
| - "The value index is out of range."); |
946 |
| - res.push_back({entry.first, valuesToReplaceIndices[entry.second]}); |
947 |
| - }); |
948 |
| - // Sort the values based on their adding orders. |
949 |
| - llvm::sort(res, [](const std::pair<unsigned, Value> &v1, |
950 |
| - const std::pair<unsigned, Value> &v2) { |
951 |
| - return v1.first < v2.first; |
952 |
| - }); |
953 |
| - // Fill the values. |
954 |
| - llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) { |
955 |
| - values.push_back(entry.second); |
956 |
| - }); |
957 |
| - } |
958 |
| - |
959 |
| -private: |
960 |
| - /// Keeping the inserting order of mapping values. |
961 |
| - int order; |
962 |
| - |
963 |
| - /// Containing the mapping values with their inserting orders. |
964 |
| - SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping; |
965 |
| - |
966 |
| - /// Containing the indices of result values with their inserting orders. |
967 |
| - SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping; |
968 |
| -}; |
969 |
| -} // namespace |
970 |
| - |
971 |
| -/// Performs the actual rewriting step. |
972 |
| -LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite( |
973 |
| - CallOp callOp, ArrayRef<Value> operands, |
974 |
| - ConversionPatternRewriter &rewriter) const { |
975 |
| - |
976 |
| - Location loc = callOp.getLoc(); |
977 |
| - OpBuilder builder(callOp); |
978 |
| - SmallVector<Value, 2> newOperands; |
979 |
| - |
980 |
| - // TODO: if the CallOp references a FuncOp that only has a declaration (e.g. |
981 |
| - // to an externally defined symbol like an external library calls), only |
982 |
| - // convert if some special attribute is set. |
983 |
| - // This will allow more control of interop across ABI boundaries. |
984 |
| - |
985 |
| - // Create the operands list of the new `CallOp`. It unpacks the decomposable |
986 |
| - // values if a decompose callback function has been provided by the user. |
987 |
| - for (auto operand : operands) { |
988 |
| - SmallVector<Value, 2> values; |
989 |
| - this->converter.tryDecomposeValue(builder, loc, operand.getType(), operand, |
990 |
| - values); |
991 |
| - newOperands.append(values.begin(), values.end()); |
992 |
| - } |
993 |
| - |
994 |
| - // Create the new result types for the new `CallOp` and a mapping from the old |
995 |
| - // result to new value(s). |
996 |
| - SmallVector<Type, 2> newResultTypes; |
997 |
| - SmallVector<CallOpResultMapping, 4> mappings; |
998 |
| - mappings.resize(callOp.getNumResults()); |
999 |
| - for (auto result : llvm::enumerate(callOp.getResults())) { |
1000 |
| - SmallVector<Type, 2> originTypes; |
1001 |
| - converter.tryDecomposeType(result.value().getType(), originTypes); |
1002 |
| - auto &resultMapping = mappings[result.index()]; |
1003 |
| - for (Type origin : originTypes) { |
1004 |
| - Type converted = converter.convertType(origin); |
1005 |
| - auto kind = converter.getResultConversionKind(origin, converted); |
1006 |
| - if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) { |
1007 |
| - newResultTypes.push_back(converted); |
1008 |
| - // The result value is not yet available. Its index is kept and it is |
1009 |
| - // replaced with the actual value of the new `CallOp` later. |
1010 |
| - resultMapping.addMapping(newResultTypes.size() - 1); |
1011 |
| - } else { |
1012 |
| - // kind = BufferAssignmentTypeConverter::AppendToArgumentsList |
1013 |
| - MemRefType memref = converted.dyn_cast<MemRefType>(); |
1014 |
| - if (!memref) |
1015 |
| - return callOp.emitError("Cannot allocate for a non-Memref type"); |
1016 |
| - Value alloc = rewriter.create<AllocOp>(loc, memref); |
1017 |
| - newOperands.push_back(alloc); |
1018 |
| - resultMapping.addMapping(alloc); |
1019 |
| - } |
1020 |
| - } |
1021 |
| - } |
1022 |
| - |
1023 |
| - CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(), |
1024 |
| - newResultTypes, newOperands); |
1025 |
| - |
1026 |
| - // Build a replacing value for each result to replace its uses. If a result |
1027 |
| - // has multiple mapping values, it needs to be packed to a single value. |
1028 |
| - OpBuilder nextBuilder(callOp.getOperation()->getNextNode()); |
1029 |
| - SmallVector<Value, 2> replacedValues; |
1030 |
| - replacedValues.reserve(callOp.getNumResults()); |
1031 |
| - for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) { |
1032 |
| - SmallVector<Value, 2> valuesToPack; |
1033 |
| - mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack); |
1034 |
| - if (valuesToPack.empty()) { |
1035 |
| - // No replacement is required. |
1036 |
| - replacedValues.push_back(nullptr); |
1037 |
| - } else if (valuesToPack.size() == 1) { |
1038 |
| - replacedValues.push_back(valuesToPack.front()); |
1039 |
| - } else { |
1040 |
| - // Values need to be packed using callback function. The same callback |
1041 |
| - // that is used for materializeArgumentConversion is used for packing. |
1042 |
| - Value packed = converter.materializeArgumentConversion( |
1043 |
| - nextBuilder, loc, callOp.getType(i), valuesToPack); |
1044 |
| - replacedValues.push_back(packed); |
1045 |
| - } |
1046 |
| - } |
1047 |
| - rewriter.replaceOp(callOp, replacedValues); |
1048 |
| - return success(); |
1049 |
| -} |
1050 |
| - |
1051 | 815 | //===----------------------------------------------------------------------===//
|
1052 | 816 | // BufferPlacementPass construction
|
1053 | 817 | //===----------------------------------------------------------------------===//
|
|
0 commit comments