Skip to content

Commit 4ec00fb

Browse files
[mlir][bufferize] Add a way for ops to fail the analysis
Add `BufferizableOpInterface::verifyAnalysis`. Ops can implement this method to check for expected invariants and limitations. The purpose of this change is to introduce a modular way of checking assertions such as `assertScfForAliasingProperties`. Differential Revision: https://reviews.llvm.org/D120189
1 parent 24bfa24 commit 4ec00fb

File tree

8 files changed

+65
-83
lines changed

8 files changed

+65
-83
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,23 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
290290
/*defaultImplementation=*/[{
291291
return false;
292292
}]
293+
>,
294+
InterfaceMethod<
295+
/*desc=*/[{
296+
Return `failure` if this op does not pass the analysis. This method
297+
is run during One-Shot Bufferize (after all post-analysis steps). If
298+
the op does not pass the analysis, bufferization is aborted.
299+
300+
This method can be used to check expected invariants and limitations
301+
of the current bufferization implementation.
302+
}],
303+
/*retType=*/"LogicalResult",
304+
/*methodName=*/"verifyAnalysis",
305+
/*args=*/(ins "const BufferizationState &":$state),
306+
/*methodBody=*/"",
307+
/*defaultImplementation=*/[{
308+
return success();
309+
}]
293310
>
294311
];
295312

mlir/include/mlir/Dialect/SCF/BufferizableOpInterfaceImpl.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,6 @@ class BufferizationAliasInfo;
2020
} // namespace bufferization
2121

2222
namespace scf {
23-
/// Assert that yielded values of an scf.for op are aliasing their corresponding
24-
/// bbArgs. This is required because the i-th OpResult of an scf.for op is
25-
/// currently assumed to alias with the i-th iter_arg (in the absence of
26-
/// conflicts).
27-
LogicalResult
28-
assertScfForAliasingProperties(Operation *op,
29-
bufferization::BufferizationState &state,
30-
bufferization::BufferizationAliasInfo &aliasInfo,
31-
SmallVector<Operation *> &newOps);
32-
3323
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
3424
} // namespace scf
3525
} // namespace mlir

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,19 @@ LogicalResult bufferization::analyzeOp(Operation *op,
778778
return failure();
779779
}
780780

781+
// Analysis verification: After setting up alias/equivalence sets, each op
782+
// can check for expected invariants/limitations and fail the analysis if
783+
// necessary.
784+
bool passedAnalysis = true;
785+
op->walk([&](Operation *op) {
786+
if (BufferizableOpInterface bufferizableOp =
787+
options.dynCastBufferizableOp(op))
788+
if (failed(bufferizableOp.verifyAnalysis(state)))
789+
passedAnalysis = false;
790+
});
791+
if (!passedAnalysis)
792+
return failure();
793+
781794
// Annotate operations if we only want to report the analysis.
782795
if (options.testAnalysisOnly)
783796
annotateOpsWithBufferizationMarkers(op, aliasInfo, state);

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,6 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
105105
opt = *options;
106106
}
107107

108-
// Only certain scf.for ops are supported by the analysis.
109-
opt.addPostAnalysisStep(scf::assertScfForAliasingProperties);
110-
111108
ModuleOp moduleOp = getOperation();
112109
applyEnablingTransformations(moduleOp);
113110

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,37 @@ struct ForOpInterface
385385

386386
return success();
387387
}
388+
389+
/// Assert that yielded values of an scf.for op are aliasing with their
390+
/// corresponding bbArgs. This is required because the i-th OpResult of an
391+
/// scf.for op is currently assumed to alias with the i-th iter_arg (in the
392+
/// absence of conflicts).
393+
LogicalResult verifyAnalysis(Operation *op,
394+
const BufferizationState &state) const {
395+
auto forOp = cast<scf::ForOp>(op);
396+
auto yieldOp =
397+
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
398+
for (OpOperand &operand : yieldOp->getOpOperands()) {
399+
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
400+
if (!tensorType)
401+
continue;
402+
403+
OpOperand &forOperand = forOp.getOpOperandForResult(
404+
forOp->getResult(operand.getOperandNumber()));
405+
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
406+
// Note: This is overly strict. We should check for aliasing bufferized
407+
// values. But we don't have a "must-alias" analysis yet.
408+
if (!state.areEquivalentBufferizedValues(operand.get(), bbArg))
409+
// TODO: this could get resolved with copies but it can also turn into
410+
// swaps so we need to be careful about order of copies.
411+
return yieldOp->emitError()
412+
<< "Yield operand #" << operand.getOperandNumber()
413+
<< " does not bufferize to a buffer that is aliasing the "
414+
"matching"
415+
<< " enclosing scf::for operand";
416+
}
417+
return success();
418+
}
388419
};
389420

390421
/// Bufferization of scf.yield. Bufferized as part of their enclosing ops, so
@@ -434,41 +465,6 @@ struct YieldOpInterface
434465
} // namespace scf
435466
} // namespace mlir
436467

437-
LogicalResult mlir::scf::assertScfForAliasingProperties(
438-
Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo,
439-
SmallVector<Operation *> &newOps) {
440-
LogicalResult status = success();
441-
442-
op->walk([&](scf::ForOp forOp) {
443-
auto yieldOp =
444-
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
445-
for (OpOperand &operand : yieldOp->getOpOperands()) {
446-
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
447-
if (!tensorType)
448-
continue;
449-
450-
OpOperand &forOperand = forOp.getOpOperandForResult(
451-
forOp->getResult(operand.getOperandNumber()));
452-
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
453-
// Note: This is overly strict. We should check for aliasing bufferized
454-
// values. But we don't have a "must-alias" analysis yet.
455-
if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
456-
// TODO: this could get resolved with copies but it can also turn into
457-
// swaps so we need to be careful about order of copies.
458-
status =
459-
yieldOp->emitError()
460-
<< "Yield operand #" << operand.getOperandNumber()
461-
<< " does not bufferize to a buffer that is aliasing the matching"
462-
<< " enclosing scf::for operand";
463-
return WalkResult::interrupt();
464-
}
465-
}
466-
return WalkResult::advance();
467-
});
468-
469-
return status;
470-
}
471-
472468
void mlir::scf::registerBufferizableOpInterfaceExternalModels(
473469
DialectRegistry &registry) {
474470
registry.addOpInterface<ExecuteRegionOp, ExecuteRegionOpInterface>();

mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ func @scf_for(%A : tensor<?xf32>,
8787
%B : tensor<?xf32> {linalg.inplaceable = true},
8888
%C : tensor<4xf32>,
8989
%lb : index, %ub : index, %step : index)
90-
-> (tensor<?xf32>, tensor<?xf32>)
90+
-> (f32, f32)
9191
{
9292
%r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
9393
-> (tensor<?xf32>, tensor<?xf32>)
@@ -102,7 +102,9 @@ func @scf_for(%A : tensor<?xf32>,
102102
scf.yield %ttB, %ttA : tensor<?xf32>, tensor<?xf32>
103103
}
104104

105-
return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
105+
%f0 = tensor.extract %r0#0[%step] : tensor<?xf32>
106+
%f1 = tensor.extract %r0#1[%step] : tensor<?xf32>
107+
return %f0, %f1: f32, f32
106108
}
107109

108110
// -----

mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -179,35 +179,6 @@ func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
179179

180180
// -----
181181

182-
// CHECK-SCF-LABEL: func @simple_scf_for(
183-
// CHECK-SCF-SAME: %[[t1:.*]]: tensor<?xf32>
184-
func @simple_scf_for(
185-
%t1: tensor<?xf32>, %sz: index, %step: index, %f: f32) -> tensor<?xf32> {
186-
%c0 = arith.constant 0 : index
187-
188-
// CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
189-
// CHECK-SCF: %[[alloc:.*]] = memref.alloc
190-
// CHECK-SCF: %[[casted:.*]] = memref.cast %[[alloc]]
191-
// CHECK-SCF: memref.copy %[[t1_memref]], %[[alloc]]
192-
// CHECK-SCF: %[[scf_for:.*]] = scf.for %[[iv:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[arg0:.*]] = %[[casted]]) -> ({{.*}}) {
193-
%0 = scf.for %iv = %c0 to %sz step %step iter_args(%arg0 = %t1) -> tensor<?xf32> {
194-
// CHECK-SCF: %[[arg0_tensor:.*]] = bufferization.to_tensor %[[arg0]]
195-
// CHECK-SCF: %[[insert:.*]] = tensor.insert %{{.*}} into %[[arg0_tensor]]
196-
%1 = tensor.insert %f into %arg0[%iv] : tensor<?xf32>
197-
198-
// CHECK-SCF: %[[insert_memref:.*]] = bufferization.to_memref %[[insert]]
199-
// CHECK-SCF: scf.yield %[[insert_memref]]
200-
scf.yield %1 : tensor<?xf32>
201-
}
202-
// CHECK-SCF: }
203-
204-
// CHECK-SCF: %[[scf_for_tensor:.*]] = bufferization.to_tensor %[[scf_for]]
205-
// CHECK-SCF: return %[[scf_for_tensor]]
206-
return %0 : tensor<?xf32>
207-
}
208-
209-
// -----
210-
211182
// CHECK-SCF-LABEL: func @simple_scf_if(
212183
// CHECK-SCF-SAME: %[[t1:.*]]: tensor<?xf32> {linalg.inplaceable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index
213184
func @simple_scf_if(%t1: tensor<?xf32> {linalg.inplaceable = true}, %c: i1, %pos: index, %f: f32)

mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,6 @@ struct TestComprehensiveFunctionBufferize
102102

103103
void TestComprehensiveFunctionBufferize::runOnOperation() {
104104
auto options = std::make_unique<AnalysisBufferizationOptions>();
105-
106-
if (!allowReturnMemref)
107-
options->addPostAnalysisStep(scf::assertScfForAliasingProperties);
108-
109105
options->allowReturnMemref = allowReturnMemref;
110106
options->allowUnknownOps = allowUnknownOps;
111107
options->testAnalysisOnly = testAnalysisOnly;

0 commit comments

Comments
 (0)