Skip to content

Commit d026e75

Browse files
Address last remaining comments
1 parent c627d06 commit d026e75

File tree

3 files changed

+13
-15
lines changed

3 files changed

+13
-15
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@
1919
#include "mlir/Interfaces/SubsetOpInterface.h"
2020
#include "llvm/Support/Debug.h"
2121

22-
namespace mlir::bufferization::detail {
23-
bool tensorTypesMatchUpToEncoding(Type lhs, Type rhs);
24-
} // namespace mlir::bufferization::detail
25-
2622
//===----------------------------------------------------------------------===//
2723
// Bufferization Dialect
2824
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,6 @@ using namespace mlir::bufferization;
2323
// Helper functions
2424
//===----------------------------------------------------------------------===//
2525

26-
bool bufferization::detail::tensorTypesMatchUpToEncoding(Type lhs, Type rhs) {
27-
auto lhsType = cast<ShapedType>(lhs);
28-
auto rhsType = cast<ShapedType>(rhs);
29-
if (lhsType.getElementType() != rhsType.getElementType())
30-
return false;
31-
if (lhsType.hasRank() && rhsType.hasRank())
32-
return lhsType.getShape() == rhsType.getShape();
33-
return true;
34-
}
35-
3626
FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
3727
OpBuilder &b, Value value, MemRefType destType,
3828
const BufferizationOptions &options) {

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,31 @@ struct OneShotBufferizePass
147147
opt.dumpAliasSets = dumpAliasSets;
148148
opt.setFunctionBoundaryTypeConversion(
149149
parseLayoutMapOption(functionBoundaryTypeConversion));
150+
151+
if (mustInferMemorySpace && useEncodingForMemorySpace) {
152+
emitError(getOperation()->getLoc())
153+
<< "only one of 'must-infer-memory-space' and "
154+
"'use-encoding-for-memory-space' are allowed in "
155+
<< getArgument();
156+
return signalPassFailure();
157+
}
158+
150159
if (mustInferMemorySpace) {
151160
opt.defaultMemorySpaceFn =
152161
[](TensorType t) -> std::optional<Attribute> {
153162
return std::nullopt;
154163
};
155-
} else if (useEncodingForMemorySpace) {
164+
}
165+
166+
if (useEncodingForMemorySpace) {
156167
opt.defaultMemorySpaceFn =
157168
[](TensorType t) -> std::optional<Attribute> {
158169
if (auto rtt = dyn_cast<RankedTensorType>(t))
159170
return rtt.getEncoding();
160171
return std::nullopt;
161172
};
162173
}
174+
163175
opt.printConflicts = printConflicts;
164176
opt.bufferAlignment = bufferAlignment;
165177
opt.testAnalysisOnly = testAnalysisOnly;

0 commit comments

Comments
 (0)