@@ -23,9 +23,9 @@ using namespace mlir::bufferization;
23
23
// Helper functions
24
24
//===----------------------------------------------------------------------===//
25
25
26
- FailureOr<Value>
27
- mlir::bufferization::castOrReallocMemRefValue ( OpBuilder &b, Value value,
28
- MemRefType destType ) {
26
+ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue (
27
+ OpBuilder &b, Value value, MemRefType destType ,
28
+ const BufferizationOptions &options ) {
29
29
auto srcType = llvm::cast<MemRefType>(value.getType ());
30
30
31
31
// Element type, rank and memory space must match.
@@ -73,18 +73,21 @@ mlir::bufferization::castOrReallocMemRefValue(OpBuilder &b, Value value,
73
73
Value size = b.create <memref::DimOp>(loc, value, i);
74
74
dynamicOperands.push_back (size);
75
75
}
76
- // TODO: Use alloc/memcpy callback from BufferizationOptions if called via
77
- // BufferizableOpInterface impl of ToMemrefOp.
78
- Value copy = b.create <memref::AllocOp>(loc, destType, dynamicOperands);
79
- b.create <memref::CopyOp>(loc, value, copy);
76
+
77
+ FailureOr<Value> copy =
78
+ options.createAlloc (b, loc, destType, dynamicOperands);
79
+ if (failed (copy))
80
+ return failure ();
81
+ if (failed (options.createMemCpy (b, loc, value, *copy)))
82
+ return failure ();
80
83
return copy;
81
84
}
82
85
83
86
/// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
84
87
/// to_memref op are different, a memref.cast is needed.
85
- LogicalResult
86
- mlir::bufferization::foldToMemrefToTensorPair ( RewriterBase &rewriter,
87
- ToMemrefOp toMemref ) {
88
+ LogicalResult mlir::bufferization::foldToMemrefToTensorPair (
89
+ RewriterBase &rewriter, ToMemrefOp toMemref ,
90
+ const BufferizationOptions &options ) {
88
91
auto memrefToTensor = toMemref.getTensor ().getDefiningOp <ToTensorOp>();
89
92
if (!memrefToTensor)
90
93
return failure ();
@@ -105,7 +108,7 @@ mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
105
108
// Ranked memref -> Ranked memref cast.
106
109
if (rankedSrcType && rankedDestType) {
107
110
FailureOr<Value> replacement = castOrReallocMemRefValue (
108
- rewriter, memrefToTensor.getMemref (), rankedDestType);
111
+ rewriter, memrefToTensor.getMemref (), rankedDestType, options );
109
112
if (failed (replacement))
110
113
return failure ();
111
114
@@ -795,7 +798,9 @@ struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
795
798
796
799
LogicalResult matchAndRewrite (ToMemrefOp toMemref,
797
800
PatternRewriter &rewriter) const final {
798
- return foldToMemrefToTensorPair (rewriter, toMemref);
801
+ BufferizationOptions options;
802
+ options.bufferAlignment = 0 ;
803
+ return foldToMemrefToTensorPair (rewriter, toMemref, options);
799
804
}
800
805
};
801
806
@@ -843,7 +848,7 @@ void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
843
848
LogicalResult ToMemrefOp::bufferize (RewriterBase &rewriter,
844
849
const BufferizationOptions &options) {
845
850
// Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary.
846
- (void )foldToMemrefToTensorPair (rewriter, *this );
851
+ (void )foldToMemrefToTensorPair (rewriter, *this , options );
847
852
// Note: The return value of `bufferize` indicates whether there was an error
848
853
// or not. (And not whether the pattern matched or not.)
849
854
return success ();
0 commit comments