Skip to content

Commit 1cca0f3

Browse files
committed
[mlir] Refactor code out of BufferPlacement.cpp
Now BufferPlacement.cpp doesn't depend on Bufferize.h. Part of the refactor discussed in: https://llvm.discourse.group/t/what-is-the-strategy-for-tensor-memref-conversion-bufferization/1938/17 Differential Revision: https://reviews.llvm.org/D89268
1 parent 6b30fb7 commit 1cca0f3

File tree

4 files changed

+256
-242
lines changed

4 files changed

+256
-242
lines changed

mlir/include/mlir/Transforms/Bufferize.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,8 @@
1818
#ifndef MLIR_TRANSFORMS_BUFFERIZE_H
1919
#define MLIR_TRANSFORMS_BUFFERIZE_H
2020

21-
#include "mlir/Analysis/Liveness.h"
2221
#include "mlir/Dialect/StandardOps/IR/Ops.h"
2322
#include "mlir/IR/Builders.h"
24-
#include "mlir/IR/Dominance.h"
2523
#include "mlir/IR/Function.h"
2624
#include "mlir/IR/Operation.h"
2725
#include "mlir/Transforms/DialectConversion.h"

mlir/lib/Transforms/BufferPlacement.cpp

Lines changed: 4 additions & 240 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,13 @@
5656
//===----------------------------------------------------------------------===//
5757

5858
#include "PassDetail.h"
59+
#include "mlir/Analysis/Liveness.h"
5960
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
61+
#include "mlir/Dialect/StandardOps/IR/Ops.h"
62+
#include "mlir/IR/Dominance.h"
6063
#include "mlir/IR/Operation.h"
64+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
6165
#include "mlir/Pass/Pass.h"
62-
#include "mlir/Transforms/Bufferize.h"
6366
#include "mlir/Transforms/Passes.h"
6467
#include "llvm/ADT/SetOperations.h"
6568

@@ -809,245 +812,6 @@ struct BufferPlacementPass : BufferPlacementBase<BufferPlacementPass> {
809812

810813
} // end anonymous namespace
811814

812-
//===----------------------------------------------------------------------===//
813-
// BufferAssignmentTypeConverter
814-
//===----------------------------------------------------------------------===//
815-
816-
/// Registers conversions into BufferAssignmentTypeConverter
817-
BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
818-
// Keep all types unchanged.
819-
addConversion([](Type type) { return type; });
820-
// Convert RankedTensorType to MemRefType.
821-
addConversion([](RankedTensorType type) {
822-
return (Type)MemRefType::get(type.getShape(), type.getElementType());
823-
});
824-
// Convert UnrankedTensorType to UnrankedMemRefType.
825-
addConversion([](UnrankedTensorType type) {
826-
return (Type)UnrankedMemRefType::get(type.getElementType(), 0);
827-
});
828-
}
829-
830-
/// This method tries to decompose a value of a certain type using provided
831-
/// decompose callback functions. If it is unable to do so, the original value
832-
/// is returned.
833-
void BufferAssignmentTypeConverter::tryDecomposeValue(
834-
OpBuilder &builder, Location loc, Type type, Value value,
835-
SmallVectorImpl<Value> &results) {
836-
for (auto conversion : decomposeValueConversions)
837-
if (conversion(builder, loc, type, value, results) != llvm::None)
838-
return;
839-
results.push_back(value);
840-
}
841-
842-
/// This method tries to decompose a type using provided decompose callback
843-
/// functions. If it is unable to do so, the original type is returned.
844-
void BufferAssignmentTypeConverter::tryDecomposeType(
845-
Type type, SmallVectorImpl<Type> &types) {
846-
for (auto conversion : decomposeTypeConversions)
847-
if (conversion(type, types) != llvm::None)
848-
return;
849-
types.push_back(type);
850-
}
851-
852-
/// This method returns ResultConversionKind for the input type.
853-
BufferAssignmentTypeConverter::ResultConversionKind
854-
BufferAssignmentTypeConverter::getResultConversionKind(Type origin,
855-
Type converted) {
856-
for (auto conversion : resultTypeConversions) {
857-
auto res = conversion(origin, converted);
858-
if (res != llvm::None)
859-
return res.getValue();
860-
}
861-
return KeepAsFunctionResult;
862-
}
863-
864-
//===----------------------------------------------------------------------===//
865-
// BufferAssignmentFuncOpConverter
866-
//===----------------------------------------------------------------------===//
867-
868-
/// Performs the actual function signature rewriting step.
869-
LogicalResult BufferAssignmentFuncOpConverter::matchAndRewrite(
870-
mlir::FuncOp funcOp, ArrayRef<Value> operands,
871-
ConversionPatternRewriter &rewriter) const {
872-
auto funcType = funcOp.getType();
873-
874-
// Convert function arguments using the provided TypeConverter.
875-
TypeConverter::SignatureConversion conversion(funcType.getNumInputs());
876-
for (auto argType : llvm::enumerate(funcType.getInputs())) {
877-
SmallVector<Type, 2> decomposedTypes, convertedTypes;
878-
converter.tryDecomposeType(argType.value(), decomposedTypes);
879-
converter.convertTypes(decomposedTypes, convertedTypes);
880-
conversion.addInputs(argType.index(), convertedTypes);
881-
}
882-
883-
// Convert the result types of the function.
884-
SmallVector<Type, 2> newResultTypes;
885-
newResultTypes.reserve(funcOp.getNumResults());
886-
for (Type resultType : funcType.getResults()) {
887-
SmallVector<Type, 2> originTypes;
888-
converter.tryDecomposeType(resultType, originTypes);
889-
for (auto origin : originTypes) {
890-
Type converted = converter.convertType(origin);
891-
auto kind = converter.getResultConversionKind(origin, converted);
892-
if (kind == BufferAssignmentTypeConverter::AppendToArgumentsList)
893-
conversion.addInputs(converted);
894-
else
895-
// kind = BufferAssignmentTypeConverter::KeepAsFunctionResult
896-
newResultTypes.push_back(converted);
897-
}
898-
}
899-
900-
if (failed(rewriter.convertRegionTypes(&funcOp.getBody(), converter,
901-
&conversion)))
902-
return failure();
903-
904-
// Update the signature of the function.
905-
rewriter.updateRootInPlace(funcOp, [&] {
906-
funcOp.setType(rewriter.getFunctionType(conversion.getConvertedTypes(),
907-
newResultTypes));
908-
});
909-
return success();
910-
}
911-
912-
//===----------------------------------------------------------------------===//
913-
// BufferAssignmentCallOpConverter
914-
//===----------------------------------------------------------------------===//
915-
916-
namespace {
917-
// This class represents a mapping from a result to a list of values and some
918-
// results that have not yet constructed. Instead, the indices of these
919-
// results in the operation that will be constructed are known. They will be
920-
// replaced with the actual values when they are available. The order of
921-
// adding to this mapping is important.
922-
class CallOpResultMapping {
923-
public:
924-
CallOpResultMapping() { order = 0; };
925-
926-
/// Add an available value to the mapping.
927-
void addMapping(Value value) { toValuesMapping.push_back({order++, value}); }
928-
929-
/// Add the index of unavailble result value to the mapping.
930-
void addMapping(unsigned index) {
931-
toIndicesMapping.push_back({order++, index});
932-
}
933-
934-
/// This method returns the mapping values list. The unknown result values
935-
/// that only their indicies are available are replaced with their values.
936-
void getMappingValues(ValueRange valuesToReplaceIndices,
937-
SmallVectorImpl<Value> &values) {
938-
// Append available values to the list.
939-
SmallVector<std::pair<unsigned, Value>, 2> res(toValuesMapping.begin(),
940-
toValuesMapping.end());
941-
// Replace the indices with the actual values.
942-
llvm::for_each(
943-
toIndicesMapping, [&](const std::pair<unsigned, unsigned> &entry) {
944-
assert(entry.second < valuesToReplaceIndices.size() &&
945-
"The value index is out of range.");
946-
res.push_back({entry.first, valuesToReplaceIndices[entry.second]});
947-
});
948-
// Sort the values based on their adding orders.
949-
llvm::sort(res, [](const std::pair<unsigned, Value> &v1,
950-
const std::pair<unsigned, Value> &v2) {
951-
return v1.first < v2.first;
952-
});
953-
// Fill the values.
954-
llvm::for_each(res, [&](const std::pair<unsigned, Value> &entry) {
955-
values.push_back(entry.second);
956-
});
957-
}
958-
959-
private:
960-
/// Keeping the inserting order of mapping values.
961-
int order;
962-
963-
/// Containing the mapping values with their inserting orders.
964-
SmallVector<std::pair<unsigned, Value>, 2> toValuesMapping;
965-
966-
/// Containing the indices of result values with their inserting orders.
967-
SmallVector<std::pair<unsigned, unsigned>, 2> toIndicesMapping;
968-
};
969-
} // namespace
970-
971-
/// Performs the actual rewriting step.
972-
LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
973-
CallOp callOp, ArrayRef<Value> operands,
974-
ConversionPatternRewriter &rewriter) const {
975-
976-
Location loc = callOp.getLoc();
977-
OpBuilder builder(callOp);
978-
SmallVector<Value, 2> newOperands;
979-
980-
// TODO: if the CallOp references a FuncOp that only has a declaration (e.g.
981-
// to an externally defined symbol like an external library calls), only
982-
// convert if some special attribute is set.
983-
// This will allow more control of interop across ABI boundaries.
984-
985-
// Create the operands list of the new `CallOp`. It unpacks the decomposable
986-
// values if a decompose callback function has been provided by the user.
987-
for (auto operand : operands) {
988-
SmallVector<Value, 2> values;
989-
this->converter.tryDecomposeValue(builder, loc, operand.getType(), operand,
990-
values);
991-
newOperands.append(values.begin(), values.end());
992-
}
993-
994-
// Create the new result types for the new `CallOp` and a mapping from the old
995-
// result to new value(s).
996-
SmallVector<Type, 2> newResultTypes;
997-
SmallVector<CallOpResultMapping, 4> mappings;
998-
mappings.resize(callOp.getNumResults());
999-
for (auto result : llvm::enumerate(callOp.getResults())) {
1000-
SmallVector<Type, 2> originTypes;
1001-
converter.tryDecomposeType(result.value().getType(), originTypes);
1002-
auto &resultMapping = mappings[result.index()];
1003-
for (Type origin : originTypes) {
1004-
Type converted = converter.convertType(origin);
1005-
auto kind = converter.getResultConversionKind(origin, converted);
1006-
if (kind == BufferAssignmentTypeConverter::KeepAsFunctionResult) {
1007-
newResultTypes.push_back(converted);
1008-
// The result value is not yet available. Its index is kept and it is
1009-
// replaced with the actual value of the new `CallOp` later.
1010-
resultMapping.addMapping(newResultTypes.size() - 1);
1011-
} else {
1012-
// kind = BufferAssignmentTypeConverter::AppendToArgumentsList
1013-
MemRefType memref = converted.dyn_cast<MemRefType>();
1014-
if (!memref)
1015-
return callOp.emitError("Cannot allocate for a non-Memref type");
1016-
Value alloc = rewriter.create<AllocOp>(loc, memref);
1017-
newOperands.push_back(alloc);
1018-
resultMapping.addMapping(alloc);
1019-
}
1020-
}
1021-
}
1022-
1023-
CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
1024-
newResultTypes, newOperands);
1025-
1026-
// Build a replacing value for each result to replace its uses. If a result
1027-
// has multiple mapping values, it needs to be packed to a single value.
1028-
OpBuilder nextBuilder(callOp.getOperation()->getNextNode());
1029-
SmallVector<Value, 2> replacedValues;
1030-
replacedValues.reserve(callOp.getNumResults());
1031-
for (unsigned i = 0, e = callOp.getNumResults(); i < e; ++i) {
1032-
SmallVector<Value, 2> valuesToPack;
1033-
mappings[i].getMappingValues(newCallOp.getResults(), valuesToPack);
1034-
if (valuesToPack.empty()) {
1035-
// No replacement is required.
1036-
replacedValues.push_back(nullptr);
1037-
} else if (valuesToPack.size() == 1) {
1038-
replacedValues.push_back(valuesToPack.front());
1039-
} else {
1040-
// Values need to be packed using callback function. The same callback
1041-
// that is used for materializeArgumentConversion is used for packing.
1042-
Value packed = converter.materializeArgumentConversion(
1043-
nextBuilder, loc, callOp.getType(i), valuesToPack);
1044-
replacedValues.push_back(packed);
1045-
}
1046-
}
1047-
rewriter.replaceOp(callOp, replacedValues);
1048-
return success();
1049-
}
1050-
1051815
//===----------------------------------------------------------------------===//
1052816
// BufferPlacementPass construction
1053817
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)