Skip to content

Commit 8f2d83d

Browse files
[mlir][bufferization] Add BufferizableOpInterface::hasTensorSemantics (#75273)
Add a new interface method to `BufferizableOpInterface`: `hasTensorSemantics`. This method returns "true" if the op has tensor semantics and should be bufferized. Until now, we assumed that an op has tensor semantics if it has tensor operands and/or tensor op results. However, there are ops like `ml_program.global` that do not have any results/operands but must still be bufferized (#75103). The new interface method can return "true" for such ops. This change also decouples `bufferization::bufferizeOp` a bit from the func dialect.
1 parent 4b0a76a commit 8f2d83d

File tree

7 files changed

+83
-32
lines changed

7 files changed

+83
-32
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,12 @@ FailureOr<BaseMemRefType> getBufferType(Value value,
601601
const BufferizationOptions &options,
602602
SmallVector<Value> &invocationStack);
603603

604+
/// Return "true" if the given op has tensor semantics and should be bufferized.
605+
/// If the op is bufferizable, the BufferizableOpInterface is queried.
606+
/// Otherwise, an op has tensor semantics if it has tensor operands, tensor
607+
/// op results and/or tensor block arguments.
608+
bool hasTensorSemantics(Operation *op);
609+
604610
/// Replace an op with replacement values. The op is deleted. Tensor OpResults
605611
/// must be replaced with memref values.
606612
void replaceOpWithBufferizedValues(RewriterBase &rewriter, Operation *op,
@@ -694,6 +700,10 @@ AliasingOpOperandList unknownGetAliasingOpOperands(Value value);
694700
/// This is the default implementation of getAliasingValues in case the owner
695701
/// op does not implement the BufferizableOpInterface.
696702
AliasingValueList unknownGetAliasingValues(OpOperand &opOperand);
703+
704+
/// This is the default implementation of
705+
/// BufferizableOpInterface::hasTensorSemantics
706+
bool defaultHasTensorSemantics(Operation *op);
697707
} // namespace detail
698708

699709
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,25 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
575575
return false;
576576
}]
577577
>,
578+
InterfaceMethod<
579+
/*desc=*/[{
580+
Return "true" if the this op has tensor semantics and should be
581+
bufferized. By default, ops with tensor operands, tensor op results
582+
and/or tensor block arguments have tensor semantics.
583+
584+
This interface methods can be implemented by ops that should be
585+
bufferized but do not have tensor semantics according to the above
586+
definition. E.g., this function can return "true" for symbols.
587+
}],
588+
/*retType=*/"bool",
589+
/*methodName=*/"hasTensorSemantics",
590+
/*args=*/(ins),
591+
/*methodBody=*/"",
592+
/*defaultImplementation=*/[{
593+
return ::mlir::bufferization::detail
594+
::defaultHasTensorSemantics($_op.getOperation());
595+
}]
596+
>,
578597
StaticInterfaceMethod<
579598
/*desc=*/[{
580599
Return `true` if the op and this interface implementation supports

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,6 +689,12 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options,
689689
*options.defaultMemorySpace);
690690
}
691691

692+
bool bufferization::hasTensorSemantics(Operation *op) {
693+
if (auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op))
694+
return bufferizableOp.hasTensorSemantics();
695+
return detail::defaultHasTensorSemantics(op);
696+
}
697+
692698
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
693699
Operation *op,
694700
ValueRange values) {
@@ -989,3 +995,20 @@ bufferization::detail::unknownGetAliasingValues(OpOperand &opOperand) {
989995
r.addAlias({bbArg, BufferRelation::Unknown, /*isDefinite=*/false});
990996
return r;
991997
}
998+
999+
bool bufferization::detail::defaultHasTensorSemantics(Operation *op) {
1000+
auto isaTensor = [](Type t) { return isa<TensorType>(t); };
1001+
bool hasTensorBlockArgument = any_of(op->getRegions(), [&](Region &r) {
1002+
return any_of(r.getBlocks(), [&](Block &b) {
1003+
return any_of(b.getArguments(), [&](BlockArgument bbArg) {
1004+
return isaTensor(bbArg.getType());
1005+
});
1006+
});
1007+
});
1008+
if (hasTensorBlockArgument)
1009+
return true;
1010+
1011+
if (any_of(op->getResultTypes(), isaTensor))
1012+
return true;
1013+
return any_of(op->getOperandTypes(), isaTensor);
1014+
}

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -350,31 +350,6 @@ mlir::bufferization::createFinalizingBufferizePass() {
350350
// BufferizableOpInterface-based Bufferization
351351
//===----------------------------------------------------------------------===//
352352

353-
static bool isaTensor(Type t) { return isa<TensorType>(t); }
354-
355-
/// Return true if the given op has a tensor result or a tensor operand.
356-
static bool hasTensorSemantics(Operation *op) {
357-
bool hasTensorBlockArgument = any_of(op->getRegions(), [](Region &r) {
358-
return any_of(r.getBlocks(), [](Block &b) {
359-
return any_of(b.getArguments(), [](BlockArgument bbArg) {
360-
return isaTensor(bbArg.getType());
361-
});
362-
});
363-
});
364-
if (hasTensorBlockArgument)
365-
return true;
366-
367-
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
368-
bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
369-
bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
370-
return hasTensorArg || hasTensorResult;
371-
}
372-
373-
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
374-
bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
375-
return hasTensorResult || hasTensorOperand;
376-
}
377-
378353
namespace {
379354
/// A rewriter that keeps track of extra information during bufferization.
380355
class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,28 @@ struct FuncOpInterface
325325

326326
static bool supportsUnstructuredControlFlow() { return true; }
327327

328+
bool hasTensorSemantics(Operation *op) const {
329+
auto isaTensor = [](Type type) { return isa<TensorType>(type); };
330+
331+
// A function has tensor semantics if it has tensor arguments/results.
332+
auto funcOp = cast<FuncOp>(op);
333+
bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor);
334+
bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor);
335+
if (hasTensorArg || hasTensorResult)
336+
return true;
337+
338+
// It also has tensor semantics if it has tensor block arguments.
339+
// TODO: Decouple bufferization of unstructured control flow from
340+
// BufferizableOpInterface implementations. We should only care about
341+
// region entry block arguments here (which are already covered by the
342+
// argument types of the function).
343+
for (Block &block : funcOp.getBody())
344+
if (any_of(block.getArgumentTypes(), isaTensor))
345+
return true;
346+
347+
return false;
348+
}
349+
328350
AliasingOpOperandList
329351
getAliasingOpOperands(Operation *op, Value value,
330352
const AnalysisState &state) const {

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,13 +1030,6 @@ OneShotAnalysisState::analyzeSingleOp(Operation *op,
10301030
return success();
10311031
}
10321032

1033-
/// Return true if the given op has a tensor result or a tensor operand.
1034-
static bool hasTensorSemantics(Operation *op) {
1035-
bool hasTensorResult = any_of(op->getResultTypes(), isaTensor);
1036-
bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor);
1037-
return hasTensorResult || hasTensorOperand;
1038-
}
1039-
10401033
/// Analyze equivalence of tied OpResult/OpOperand pairs of the given ops.
10411034
static void equivalenceAnalysis(SmallVector<Operation *> &ops,
10421035
OneShotAnalysisState &state) {

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,15 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
458458
foldMemRefCasts(funcOp);
459459
}
460460

461+
// Bufferize all other ops.
462+
for (Operation &op : moduleOp.getOps()) {
463+
// Functions were already bufferized.
464+
if (isa<func::FuncOp>(&op))
465+
continue;
466+
if (failed(bufferizeOp(&op, options, statistics)))
467+
return failure();
468+
}
469+
461470
// Post-pass cleanup of function argument attributes.
462471
removeBufferizationAttributesInModule(moduleOp);
463472

0 commit comments

Comments
 (0)