Skip to content

Commit 38e53fb

Browse files
author
git apple-llvm automerger
committed
Merge commit '71eeb5ec4d6e' from llvm.org/master into apple/main
2 parents 490fc19 + 71eeb5e commit 38e53fb

File tree

7 files changed

+127
-37
lines changed

7 files changed

+127
-37
lines changed

mlir/docs/Interfaces.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,4 +231,12 @@ format of the header for each interface section goes as follows:
231231

232232
##### SymbolInterfaces
233233

234-
* `SymbolOpInterface` - Used to represent [`Symbol`](SymbolsAndSymbolTables.md#symbol) operations which reside immediately within a region that defines a [`SymbolTable`](SymbolsAndSymbolTables.md#symbol-table).
234+
* `SymbolOpInterface` - Used to represent
235+
[`Symbol`](SymbolsAndSymbolTables.md#symbol) operations which reside
236+
immediately within a region that defines a
237+
[`SymbolTable`](SymbolsAndSymbolTables.md#symbol-table).
238+
239+
* `SymbolUserOpInterface` - Used to represent operations that reference
240+
[`Symbol`](SymbolsAndSymbolTables.md#symbol) operations. This provides the
241+
ability to perform safe and efficient verification of symbol uses, as well
242+
as additional functionality.

mlir/docs/SymbolsAndSymbolTables.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ See the `LangRef` definition of the
142142
[`SymbolRefAttr`](LangRef.md#symbol-reference-attribute) for more information
143143
about the structure of this attribute.
144144

145+
Operations that reference a `Symbol` and want to perform verification and
146+
general mutation of the symbol should implement the `SymbolUserOpInterface` to
147+
ensure that symbol accesses are legal and efficient.
148+
145149
### Manipulating a Symbol
146150

147151
As described above, `SymbolRefs` act as an auxiliary way of defining uses of

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td"
1717
include "mlir/IR/OpAsmInterface.td"
18+
include "mlir/IR/SymbolInterfaces.td"
1819
include "mlir/Interfaces/CallInterfaces.td"
1920
include "mlir/Interfaces/ControlFlowInterfaces.td"
2021
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -733,7 +734,9 @@ def BranchOp : Std_Op<"br",
733734
// CallOp
734735
//===----------------------------------------------------------------------===//
735736

736-
def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
737+
def CallOp : Std_Op<"call",
738+
[CallOpInterface, MemRefsNormalizable,
739+
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
737740
let summary = "call operation";
738741
let description = [{
739742
The `call` operation represents a direct call to a function that is within
@@ -788,6 +791,7 @@ def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
788791
let assemblyFormat = [{
789792
$callee `(` $operands `)` attr-dict `:` functional-type($operands, results)
790793
}];
794+
let verifier = ?;
791795
}
792796

793797
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/SymbolInterfaces.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,27 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
158158
}];
159159
}
160160

161+
//===----------------------------------------------------------------------===//
162+
// SymbolUserOpInterface
163+
//===----------------------------------------------------------------------===//
164+
165+
def SymbolUserOpInterface : OpInterface<"SymbolUserOpInterface"> {
166+
let description = [{
167+
This interface describes an operation that may use a `Symbol`. This
168+
interface allows for users of symbols to hook into verification and other
169+
symbol related utilities that are either costly or otherwise disallowed
170+
within a traditional operation.
171+
}];
172+
let cppNamespace = "::mlir";
173+
174+
let methods = [
175+
InterfaceMethod<"Verify the symbol uses held by this operation.",
176+
"LogicalResult", "verifySymbolUses",
177+
(ins "::mlir::SymbolTableCollection &":$symbolTable)
178+
>,
179+
];
180+
}
181+
161182
//===----------------------------------------------------------------------===//
162183
// Symbol Traits
163184
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/SymbolTable.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,21 @@ class SymbolTableCollection {
236236
LogicalResult lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name,
237237
SmallVectorImpl<Operation *> &symbols);
238238

239+
/// Returns the operation registered with the given symbol name within the
240+
/// closest parent operation of, or including, 'from' with the
241+
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
242+
/// found.
243+
Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
244+
Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
245+
template <typename T>
246+
T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
247+
return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
248+
}
249+
template <typename T>
250+
T lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol) {
251+
return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
252+
}
253+
239254
/// Lookup, or create, a symbol table for an operation.
240255
SymbolTable &getSymbolTable(Operation *op);
241256

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -740,34 +740,33 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
740740
// CallOp
741741
//===----------------------------------------------------------------------===//
742742

743-
static LogicalResult verify(CallOp op) {
743+
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
744744
// Check that the callee attribute was specified.
745-
auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
745+
auto fnAttr = getAttrOfType<FlatSymbolRefAttr>("callee");
746746
if (!fnAttr)
747-
return op.emitOpError("requires a 'callee' symbol reference attribute");
748-
auto fn =
749-
op.getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(fnAttr.getValue());
747+
return emitOpError("requires a 'callee' symbol reference attribute");
748+
FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
750749
if (!fn)
751-
return op.emitOpError() << "'" << fnAttr.getValue()
752-
<< "' does not reference a valid function";
750+
return emitOpError() << "'" << fnAttr.getValue()
751+
<< "' does not reference a valid function";
753752

754753
// Verify that the operand and result types match the callee.
755754
auto fnType = fn.getType();
756-
if (fnType.getNumInputs() != op.getNumOperands())
757-
return op.emitOpError("incorrect number of operands for callee");
755+
if (fnType.getNumInputs() != getNumOperands())
756+
return emitOpError("incorrect number of operands for callee");
758757

759758
for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
760-
if (op.getOperand(i).getType() != fnType.getInput(i))
761-
return op.emitOpError("operand type mismatch: expected operand type ")
759+
if (getOperand(i).getType() != fnType.getInput(i))
760+
return emitOpError("operand type mismatch: expected operand type ")
762761
<< fnType.getInput(i) << ", but provided "
763-
<< op.getOperand(i).getType() << " for operand number " << i;
762+
<< getOperand(i).getType() << " for operand number " << i;
764763

765-
if (fnType.getNumResults() != op.getNumResults())
766-
return op.emitOpError("incorrect number of results for callee");
764+
if (fnType.getNumResults() != getNumResults())
765+
return emitOpError("incorrect number of results for callee");
767766

768767
for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
769-
if (op.getResult(i).getType() != fnType.getResult(i))
770-
return op.emitOpError("result type mismatch");
768+
if (getResult(i).getType() != fnType.getResult(i))
769+
return emitOpError("result type mismatch");
771770

772771
return success();
773772
}

mlir/lib/IR/SymbolTable.cpp

Lines changed: 58 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,30 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
6868
return success();
6969
}
7070

71+
/// Walk all of the operations within the given set of regions, without
72+
/// traversing into any nested symbol tables. Stops walking if the result of the
73+
/// callback is anything other than `WalkResult::advance`.
74+
static Optional<WalkResult>
75+
walkSymbolTable(MutableArrayRef<Region> regions,
76+
function_ref<Optional<WalkResult>(Operation *)> callback) {
77+
SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
78+
while (!worklist.empty()) {
79+
for (Operation &op : worklist.pop_back_val()->getOps()) {
80+
Optional<WalkResult> result = callback(&op);
81+
if (result != WalkResult::advance())
82+
return result;
83+
84+
// If this op defines a new symbol table scope, we can't traverse. Any
85+
// symbol references nested within 'op' are different semantically.
86+
if (!op.hasTrait<OpTrait::SymbolTable>()) {
87+
for (Region &region : op.getRegions())
88+
worklist.push_back(&region);
89+
}
90+
}
91+
}
92+
return WalkResult::advance();
93+
}
94+
7195
//===----------------------------------------------------------------------===//
7296
// SymbolTable
7397
//===----------------------------------------------------------------------===//
@@ -347,7 +371,18 @@ LogicalResult detail::verifySymbolTable(Operation *op) {
347371
.append("see existing symbol definition here");
348372
}
349373
}
350-
return success();
374+
375+
// Verify any nested symbol user operations.
376+
SymbolTableCollection symbolTable;
377+
auto verifySymbolUserFn = [&](Operation *op) -> Optional<WalkResult> {
378+
if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(op))
379+
return WalkResult(user.verifySymbolUses(symbolTable));
380+
return WalkResult::advance();
381+
};
382+
383+
Optional<WalkResult> result =
384+
walkSymbolTable(op->getRegions(), verifySymbolUserFn);
385+
return success(result && !result->wasInterrupted());
351386
}
352387

353388
LogicalResult detail::verifySymbol(Operation *op) {
@@ -452,25 +487,13 @@ static WalkResult walkSymbolRefs(
452487
static Optional<WalkResult> walkSymbolUses(
453488
MutableArrayRef<Region> regions,
454489
function_ref<WalkResult(SymbolTable::SymbolUse, ArrayRef<int>)> callback) {
455-
SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
456-
while (!worklist.empty()) {
457-
for (Operation &op : worklist.pop_back_val()->getOps()) {
458-
if (walkSymbolRefs(&op, callback).wasInterrupted())
459-
return WalkResult::interrupt();
460-
461-
// Check that this isn't a potentially unknown symbol table.
462-
if (isPotentiallyUnknownSymbolTable(&op))
463-
return llvm::None;
490+
return walkSymbolTable(regions, [&](Operation *op) -> Optional<WalkResult> {
491+
// Check that this isn't a potentially unknown symbol table.
492+
if (isPotentiallyUnknownSymbolTable(op))
493+
return llvm::None;
464494

465-
// If this op defines a new symbol table scope, we can't traverse. Any
466-
// symbol references nested within 'op' are different semantically.
467-
if (!op.hasTrait<OpTrait::SymbolTable>()) {
468-
for (Region &region : op.getRegions())
469-
worklist.push_back(&region);
470-
}
471-
}
472-
}
473-
return WalkResult::advance();
495+
return walkSymbolRefs(op, callback);
496+
});
474497
}
475498
/// Walk all of the uses, for any symbol, that are nested within the given
476499
/// operation 'from', invoking the provided callback for each. This does not
@@ -927,6 +950,22 @@ SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
927950
return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
928951
}
929952

953+
/// Returns the operation registered with the given symbol name within the
954+
/// closest parent operation of, or including, 'from' with the
955+
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
956+
/// found.
957+
Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
958+
StringRef symbol) {
959+
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
960+
return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
961+
}
962+
Operation *
963+
SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
964+
SymbolRefAttr symbol) {
965+
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
966+
return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
967+
}
968+
930969
/// Lookup, or create, a symbol table for an operation.
931970
SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
932971
auto it = symbolTables.try_emplace(op, nullptr);

0 commit comments

Comments
 (0)