@@ -345,9 +345,10 @@ defaultFunctionArgTypeConverter(TensorType type, Attribute memorySpace,
345
345
}
346
346
// / Default unknown type converter: Use a fully dynamic layout map.
347
347
BaseMemRefType
348
- defaultUnknownTypeConverter (TensorType tensorType , Attribute memorySpace,
348
+ defaultUnknownTypeConverter (Value value , Attribute memorySpace,
349
349
const BufferizationOptions &options) {
350
- return getMemRefTypeWithFullyDynamicLayout (tensorType, memorySpace);
350
+ return getMemRefTypeWithFullyDynamicLayout (
351
+ llvm::cast<TensorType>(value.getType ()), memorySpace);
351
352
}
352
353
353
354
} // namespace
@@ -723,8 +724,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
723
724
if (!memSpace.has_value ())
724
725
return op->emitError (" could not infer memory space" );
725
726
726
- return getMemRefType (cast<TensorType>(value.getType ()), options,
727
- /* layout=*/ {}, *memSpace);
727
+ return getMemRefType (value, options, /* layout=*/ {}, *memSpace);
728
728
}
729
729
730
730
bool bufferization::hasTensorSemantics (Operation *op) {
@@ -797,10 +797,12 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc,
797
797
// Bufferization-specific IRMapping support with debugging.
798
798
// ===----------------------------------------------------------------------===//
799
799
800
- BaseMemRefType bufferization::getMemRefType (TensorType tensorType ,
800
+ BaseMemRefType bufferization::getMemRefType (Value value ,
801
801
const BufferizationOptions &options,
802
802
MemRefLayoutAttrInterface layout,
803
803
Attribute memorySpace) {
804
+ auto tensorType = llvm::cast<TensorType>(value.getType ());
805
+
804
806
// Case 1: Unranked memref type.
805
807
if (auto unrankedTensorType =
806
808
llvm::dyn_cast<UnrankedTensorType>(tensorType)) {
@@ -817,7 +819,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
817
819
memorySpace);
818
820
}
819
821
820
- return options.unknownTypeConverterFn (tensorType , memorySpace, options);
822
+ return options.unknownTypeConverterFn (value , memorySpace, options);
821
823
}
822
824
823
825
BaseMemRefType
@@ -953,11 +955,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
953
955
const BufferizationState &bufferizationState,
954
956
SmallVector<Value> &invocationStack) {
955
957
assert (llvm::isa<TensorType>(value.getType ()) && " expected tensor type" );
956
- auto tensorType = cast<TensorType>(value.getType ());
957
958
958
959
// No further analysis is possible for a block argument.
959
960
if (llvm::isa<BlockArgument>(value))
960
- return bufferization::getMemRefType (tensorType , options);
961
+ return bufferization::getMemRefType (value , options);
961
962
962
963
// Value is an OpResult.
963
964
Operation *op = getOwnerOfValue (value);
@@ -980,7 +981,7 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
980
981
if (!memSpace.has_value ())
981
982
return op->emitError (" could not infer memory space" );
982
983
983
- return getMemRefType (tensorType , options, /* layout=*/ {}, *memSpace);
984
+ return getMemRefType (value , options, /* layout=*/ {}, *memSpace);
984
985
}
985
986
986
987
bool bufferization::detail::defaultIsRepetitiveRegion (
0 commit comments