@@ -171,15 +171,6 @@ static void setInPlaceOpResult(OpResult opResult, bool inPlace) {
171
171
OpBuilder (op).getStrArrayAttr (inPlaceVector));
172
172
}
173
173
174
- // / Set the attribute that triggers inplace bufferization on a FuncOp argument
175
- // / `bbArg`.
176
- static void setInPlaceFuncArgument (BlockArgument bbArg, bool inPlace) {
177
- auto funcOp = cast<FuncOp>(bbArg.getOwner ()->getParentOp ());
178
- funcOp.setArgAttr (bbArg.getArgNumber (),
179
- BufferizableOpInterface::kInplaceableAttrName ,
180
- BoolAttr::get (bbArg.getContext (), inPlace));
181
- }
182
-
183
174
// ===----------------------------------------------------------------------===//
184
175
// Printing helpers.
185
176
// ===----------------------------------------------------------------------===//
@@ -258,25 +249,22 @@ static bool isInplaceMemoryWrite(OpOperand &opOperand,
258
249
// / Return true if, under current bufferization decisions, the buffer of `value`
259
250
// / is not writable.
260
251
static bool aliasesNonWritableBuffer (Value value,
261
- const BufferizationAliasInfo &aliasInfo) {
252
+ const BufferizationAliasInfo &aliasInfo,
253
+ BufferizationState &state) {
262
254
LDBG (" WRITABILITY ANALYSIS FOR " << printValueInfo (value) << " \n " );
263
255
bool foundNonWritableBuffer = false ;
264
256
aliasInfo.applyOnAliases (value, [&](Value v) {
265
- // Some values are known to be writable.
266
- if (aliasInfo.bufferizesToWritableMemory (v))
267
- return ;
268
-
269
257
// Query BufferizableOpInterface to see if the OpResult is writable.
270
258
// TODO: Out-of-place bufferized OpResult could be considered writable.
271
259
if (auto bufferizableOp = v.getDefiningOp <BufferizableOpInterface>())
272
- if (bufferizableOp && bufferizableOp.isWritable (v))
260
+ if (bufferizableOp && bufferizableOp.isWritable (v, state ))
273
261
return ;
274
262
275
263
// Query BufferizableOpInterface to see if the BlockArgument is writable.
276
264
if (auto bbArg = v.dyn_cast <BlockArgument>())
277
265
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(
278
266
bbArg.getOwner ()->getParentOp ()))
279
- if (bufferizableOp.isWritable (bbArg))
267
+ if (bufferizableOp.isWritable (bbArg, state ))
280
268
return ;
281
269
282
270
foundNonWritableBuffer = true ;
@@ -515,7 +503,8 @@ bool wouldCreateReadAfterWriteInterference(
515
503
// / a write to a non-writable buffer.
516
504
static bool
517
505
wouldCreateWriteToNonWritableBuffer (OpOperand &opOperand, OpResult opResult,
518
- const BufferizationAliasInfo &aliasInfo) {
506
+ const BufferizationAliasInfo &aliasInfo,
507
+ BufferizationState &state) {
519
508
#ifndef NDEBUG
520
509
SmallVector<OpOperand *> opOperands = getAliasingOpOperand (opResult);
521
510
assert (llvm::find (opOperands, &opOperand) != opOperands.end () &&
@@ -525,9 +514,10 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
525
514
// Certain buffers are not writeable:
526
515
// 1. A function bbArg that is not inplaceable or
527
516
// 2. A constant op.
528
- assert (!aliasesNonWritableBuffer (opResult, aliasInfo) &&
517
+ assert (!aliasesNonWritableBuffer (opResult, aliasInfo, state ) &&
529
518
" expected that opResult does not alias non-writable buffer" );
530
- bool nonWritable = aliasesNonWritableBuffer (opOperand.get (), aliasInfo);
519
+ bool nonWritable =
520
+ aliasesNonWritableBuffer (opOperand.get (), aliasInfo, state);
531
521
if (!nonWritable)
532
522
return false ;
533
523
@@ -547,10 +537,9 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
547
537
// ===----------------------------------------------------------------------===//
548
538
549
539
// / Determine if `operand` can be bufferized in-place with `result`.
550
- static LogicalResult
551
- bufferizableInPlaceAnalysisImpl (OpOperand &operand, OpResult result,
552
- BufferizationAliasInfo &aliasInfo,
553
- const DominanceInfo &domInfo) {
540
+ static LogicalResult bufferizableInPlaceAnalysisImpl (
541
+ OpOperand &operand, OpResult result, BufferizationAliasInfo &aliasInfo,
542
+ BufferizationState &state, const DominanceInfo &domInfo) {
554
543
#ifndef NDEBUG
555
544
SmallVector<OpOperand *> opOperands = getAliasingOpOperand (result);
556
545
assert (llvm::find (opOperands, &operand) != opOperands.end () &&
@@ -565,7 +554,7 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
565
554
<< printValueInfo (result) << ' \n ' );
566
555
567
556
bool foundInterference =
568
- wouldCreateWriteToNonWritableBuffer (operand, result, aliasInfo) ||
557
+ wouldCreateWriteToNonWritableBuffer (operand, result, aliasInfo, state ) ||
569
558
wouldCreateReadAfterWriteInterference (operand, result, domInfo,
570
559
aliasInfo);
571
560
@@ -599,6 +588,7 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
599
588
// / RaW dependence violations.
600
589
static LogicalResult inPlaceAnalysis (SmallVector<Operation *> &ops,
601
590
BufferizationAliasInfo &aliasInfo,
591
+ BufferizationState &state,
602
592
const DominanceInfo &domInfo,
603
593
unsigned analysisFuzzerSeed = 0 ) {
604
594
if (analysisFuzzerSeed) {
@@ -615,8 +605,8 @@ static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
615
605
if (opOperand.get ().getType ().isa <TensorType>())
616
606
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
617
607
if (OpResult opResult = bufferizableOp.getAliasingOpResult (opOperand))
618
- if (failed (bufferizableInPlaceAnalysisImpl (opOperand, opResult,
619
- aliasInfo, domInfo)))
608
+ if (failed (bufferizableInPlaceAnalysisImpl (
609
+ opOperand, opResult, aliasInfo, state , domInfo)))
620
610
return failure ();
621
611
622
612
return success ();
@@ -625,6 +615,7 @@ static LogicalResult inPlaceAnalysis(SmallVector<Operation *> &ops,
625
615
// / Analyze all ops that are contained in `op`.
626
616
static LogicalResult inPlaceAnalysis (Operation *op,
627
617
BufferizationAliasInfo &aliasInfo,
618
+ BufferizationState &state,
628
619
const DominanceInfo &domInfo,
629
620
unsigned analysisFuzzerSeed = 0 ) {
630
621
// Collect ops so we can build our own reverse traversal.
@@ -637,7 +628,7 @@ static LogicalResult inPlaceAnalysis(Operation *op,
637
628
ops.push_back (op);
638
629
});
639
630
640
- return inPlaceAnalysis (ops, aliasInfo, domInfo, analysisFuzzerSeed);
631
+ return inPlaceAnalysis (ops, aliasInfo, state, domInfo, analysisFuzzerSeed);
641
632
}
642
633
643
634
// / Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
@@ -712,15 +703,9 @@ static void
712
703
annotateOpsWithBufferizationMarkers (Operation *op,
713
704
const BufferizationAliasInfo &aliasInfo) {
714
705
op->walk ([&](Operation *op) {
715
- for (OpResult opResult : op->getResults ()) {
706
+ for (OpResult opResult : op->getResults ())
716
707
if (opResult.getType ().isa <TensorType>())
717
708
setInPlaceOpResult (opResult, aliasInfo.isInPlace (opResult));
718
- if (auto funcOp = dyn_cast<FuncOp>(op))
719
- for (BlockArgument bbArg : funcOp.getArguments ())
720
- if (bbArg.getType ().isa <TensorType>())
721
- setInPlaceFuncArgument (bbArg,
722
- aliasInfo.bufferizesToWritableMemory (bbArg));
723
- }
724
709
});
725
710
}
726
711
@@ -739,8 +724,8 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
739
724
740
725
// If the analysis fails, just return.
741
726
Operation *op = funcOp.getOperation ();
742
- if (failed (
743
- inPlaceAnalysis (op, aliasInfo, domInfo, options.analysisFuzzerSeed )))
727
+ if (failed (inPlaceAnalysis (op, aliasInfo, state, domInfo,
728
+ options.analysisFuzzerSeed )))
744
729
return failure ();
745
730
equivalenceAnalysis (op, aliasInfo);
746
731
@@ -750,7 +735,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
750
735
if (failed (step->run (funcOp, state, newOps)))
751
736
return failure ();
752
737
// Analyze ops that were created by the PostAnalysisStep.
753
- if (failed (inPlaceAnalysis (newOps, aliasInfo, domInfo)))
738
+ if (failed (inPlaceAnalysis (newOps, aliasInfo, state, domInfo)))
754
739
return failure ();
755
740
equivalenceAnalysis (newOps, aliasInfo);
756
741
}
0 commit comments