Skip to content

Commit 7a8e240

Browse files
bjacoblialan
authored andcommitted
Revert "[mlir][bufferization] Use Type instead of Value in unknown conversion (llvm#144658)"
This reverts commit a1c2a71.
1 parent 0d09249 commit 7a8e240

File tree

4 files changed

+19
-19
lines changed

4 files changed

+19
-19
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,9 @@ struct BufferizationOptions {
265265
std::function<BaseMemRefType(TensorType, Attribute memorySpace,
266266
func::FuncOp, const BufferizationOptions &)>;
267267
/// Tensor -> MemRef type converter.
268-
/// Parameters: tensor type, memory space, bufferization options
268+
/// Parameters: Value, memory space, bufferization options
269269
using UnknownTypeConverterFn = std::function<BaseMemRefType(
270-
TensorType, Attribute memorySpace, const BufferizationOptions &)>;
270+
Value, Attribute memorySpace, const BufferizationOptions &)>;
271271
// Produce a MemorySpace attribute from a tensor type
272272
using DefaultMemorySpaceFn =
273273
std::function<std::optional<Attribute>(TensorType t)>;
@@ -655,7 +655,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
655655
return newOp;
656656
}
657657

658-
/// Return a MemRefType to which the TensorType can be bufferized.
658+
/// Return a MemRefType to which the type of the given value can be bufferized.
659659
///
660660
/// If possible, op bufferization implementations should not use this function
661661
/// and instead infer precise memref types for tensor results by themselves.
@@ -667,8 +667,7 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
667667
/// Note: Canonicalization patterns could clean up layout maps and infer more
668668
/// precise layout maps after bufferization. However, many possible
669669
/// canonicalizations are currently not implemented.
670-
BaseMemRefType getMemRefType(TensorType tensorType,
671-
const BufferizationOptions &options,
670+
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
672671
MemRefLayoutAttrInterface layout = {},
673672
Attribute memorySpace = nullptr);
674673

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -345,9 +345,10 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
345345
}
346346
/// Default unknown type converter: Use a fully dynamic layout map.
347347
BaseMemRefType
348-
defaultUnknownTypeConverter(TensorType tensorType, Attribute memorySpace,
348+
defaultUnknownTypeConverter(Value value, Attribute memorySpace,
349349
const BufferizationOptions &options) {
350-
return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
350+
return getMemRefTypeWithFullyDynamicLayout(
351+
llvm::cast<TensorType>(value.getType()), memorySpace);
351352
}
352353

353354
} // namespace
@@ -723,8 +724,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
723724
if (!memSpace.has_value())
724725
return op->emitError("could not infer memory space");
725726

726-
return getMemRefType(cast<TensorType>(value.getType()), options,
727-
/*layout=*/{}, *memSpace);
727+
return getMemRefType(value, options, /*layout=*/{}, *memSpace);
728728
}
729729

730730
bool bufferization::hasTensorSemantics(Operation *op) {
@@ -797,10 +797,12 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
797797
// Bufferization-specific IRMapping support with debugging.
798798
//===----------------------------------------------------------------------===//
799799

800-
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
800+
BaseMemRefType bufferization::getMemRefType(Value value,
801801
const BufferizationOptions &options,
802802
MemRefLayoutAttrInterface layout,
803803
Attribute memorySpace) {
804+
auto tensorType = llvm::cast<TensorType>(value.getType());
805+
804806
// Case 1: Unranked memref type.
805807
if (auto unrankedTensorType =
806808
llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -817,7 +819,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
817819
memorySpace);
818820
}
819821

820-
return options.unknownTypeConverterFn(tensorType, memorySpace, options);
822+
return options.unknownTypeConverterFn(value, memorySpace, options);
821823
}
822824

823825
BaseMemRefType
@@ -953,11 +955,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
953955
const BufferizationState &bufferizationState,
954956
SmallVector<Value> &invocationStack) {
955957
assert(llvm::isa<TensorType>(value.getType()) && "expected tensor type");
956-
auto tensorType = cast<TensorType>(value.getType());
957958

958959
// No further analysis is possible for a block argument.
959960
if (llvm::isa<BlockArgument>(value))
960-
return bufferization::getMemRefType(tensorType, options);
961+
return bufferization::getMemRefType(value, options);
961962

962963
// Value is an OpResult.
963964
Operation *op = getOwnerOfValue(value);
@@ -980,7 +981,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
980981
if (!memSpace.has_value())
981982
return op->emitError("could not infer memory space");
982983

983-
return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
984+
return getMemRefType(value, options, /*layout=*/{}, *memSpace);
984985
}
985986

986987
bool bufferization::detail::defaultIsRepetitiveRegion(

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,9 @@ struct OneShotBufferizePass
109109
"'unknown-type-conversion'");
110110
return signalPassFailure();
111111
}
112-
opt.unknownTypeConverterFn = [=](TensorType tensorType,
113-
Attribute memorySpace,
112+
opt.unknownTypeConverterFn = [=](Value value, Attribute memorySpace,
114113
const BufferizationOptions &options) {
114+
auto tensorType = cast<TensorType>(value.getType());
115115
if (unknownTypeConversionOption == LayoutMapOption::IdentityLayoutMap)
116116
return bufferization::getMemRefTypeWithStaticIdentityLayout(
117117
tensorType, memorySpace);

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,10 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
223223
OneShotBufferizationOptions options;
224224
options.bufferizeFunctionBoundaries = true;
225225
options.setFunctionBoundaryTypeConversion(LayoutMapOption::IdentityLayoutMap);
226-
options.unknownTypeConverterFn = [](TensorType tensorType,
227-
Attribute memorySpace,
226+
options.unknownTypeConverterFn = [](Value value, Attribute memorySpace,
228227
const BufferizationOptions &options) {
229-
return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
228+
return getMemRefTypeWithStaticIdentityLayout(
229+
cast<TensorType>(value.getType()), memorySpace);
230230
};
231231
if (analysisOnly) {
232232
options.testAnalysisOnly = true;

0 commit comments

Comments
 (0)