Skip to content

Commit fef08da

Browse files
committed
[mlir][llvm] Store memory op metadata using op attributes.
The revision introduces operation attributes to store tbaa metadata on load and store operations rather than relying using dialect attributes. At the same time, the change also ensures the provided getters and setters instead are used instead of a string based lookup. The latter is done for the tbaa, access groups, and alias scope attributes. The goal of this change is to ensure the metadata attributes are only placed on operations that have the corresponding operation attributes. This is imported since only these operations later on translate these attributes to LLVM IR. Dialect attributes placed on other operations are lost during the translation. Reviewed By: vzakhari, Dinistro Differential Revision: https://reviews.llvm.org/D143654
1 parent 067a5c6 commit fef08da

File tree

12 files changed

+218
-176
lines changed

12 files changed

+218
-176
lines changed

flang/lib/Optimizer/CodeGen/TBAABuilder.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "TBAABuilder.h"
1414
#include "flang/Optimizer/Dialect/FIRType.h"
15+
#include "llvm/ADT/TypeSwitch.h"
1516
#include "llvm/Support/CommandLine.h"
1617
#include "llvm/Support/Debug.h"
1718

@@ -159,9 +160,13 @@ void TBAABuilder::attachTBAATag(Operation *op, Type baseFIRType,
159160
else
160161
tbaaTagSym = getDataAccessTag(baseFIRType, accessFIRType, gep);
161162

162-
if (tbaaTagSym)
163-
op->setAttr(LLVMDialect::getTBAAAttrName(),
164-
ArrayAttr::get(op->getContext(), tbaaTagSym));
163+
if (!tbaaTagSym)
164+
return;
165+
166+
auto tbaaAttr = ArrayAttr::get(op->getContext(), tbaaTagSym);
167+
llvm::TypeSwitch<Operation *>(op)
168+
.Case<LoadOp, StoreOp>([&](auto memOp) { memOp.setTbaaAttr(tbaaAttr); })
169+
.Default([](auto) { llvm_unreachable("expected LoadOp or StoreOp"); });
165170
}
166171

167172
} // namespace fir

flang/test/Fir/tbaa.fir

Lines changed: 29 additions & 29 deletions
Large diffs are not rendered by default.

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,7 @@ def LLVM_Dialect : Dialect {
3535
let extraClassDeclaration = [{
3636
/// Name of the data layout attributes.
3737
static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; }
38-
static StringRef getNoAliasScopesAttrName() { return "noalias_scopes"; }
39-
static StringRef getAliasScopesAttrName() { return "alias_scopes"; }
4038
static StringRef getLoopAttrName() { return "llvm.loop"; }
41-
static StringRef getAccessGroupsAttrName() { return "access_groups"; }
42-
static StringRef getTBAAAttrName() { return "llvm.tbaa"; }
4339

4440
/// Names of llvm parameter attributes.
4541
static StringRef getAlignAttrName() { return "llvm.align"; }

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpBase {
350350
OptionalAttr<SymbolRefArrayAttr>:$access_groups,
351351
OptionalAttr<SymbolRefArrayAttr>:$alias_scopes,
352352
OptionalAttr<SymbolRefArrayAttr>:$noalias_scopes,
353+
OptionalAttr<SymbolRefArrayAttr>:$tbaa,
353354
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
354355
UnitAttr:$nontemporal);
355356
let results = (outs LLVM_LoadableType:$res);
@@ -390,6 +391,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpBase {
390391
OptionalAttr<SymbolRefArrayAttr>:$access_groups,
391392
OptionalAttr<SymbolRefArrayAttr>:$alias_scopes,
392393
OptionalAttr<SymbolRefArrayAttr>:$noalias_scopes,
394+
OptionalAttr<SymbolRefArrayAttr>:$tbaa,
393395
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
394396
UnitAttr:$nontemporal);
395397
string llvmInstName = "Store";

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,15 +120,14 @@ class ModuleTranslation {
120120
/// in these blocks.
121121
void forgetMapping(Region &region);
122122

123-
/// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
124-
/// dialect access group operation.
125-
llvm::MDNode *getAccessGroup(Operation &opInst,
123+
/// Returns the LLVM metadata corresponding to a symbol reference to an mlir
124+
/// LLVM dialect access group operation.
125+
llvm::MDNode *getAccessGroup(Operation *op,
126126
SymbolRefAttr accessGroupRef) const;
127127

128-
/// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
129-
/// dialect alias scope operation
130-
llvm::MDNode *getAliasScope(Operation &opInst,
131-
SymbolRefAttr aliasScopeRef) const;
128+
/// Returns the LLVM metadata corresponding to a symbol reference to an mlir
129+
/// LLVM dialect alias scope operation
130+
llvm::MDNode *getAliasScope(Operation *op, SymbolRefAttr aliasScopeRef) const;
132131

133132
// Sets LLVM metadata for memory operations that are in a parallel loop.
134133
void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst);
@@ -287,9 +286,9 @@ class ModuleTranslation {
287286
/// metadata nodes for them and their domains.
288287
LogicalResult createAliasScopeMetadata();
289288

290-
/// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
291-
/// dialect TBAATagOp operation.
292-
llvm::MDNode *getTBAANode(Operation &memOp, SymbolRefAttr tagRef) const;
289+
/// Returns the LLVM metadata corresponding to a symbol reference to an mlir
290+
/// LLVM dialect TBAATagOp operation.
291+
llvm::MDNode *getTBAANode(Operation *op, SymbolRefAttr tagRef) const;
293292

294293
/// Process tbaa LLVM Metadata operations and create LLVM
295294
/// metadata nodes for them.

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 61 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -668,53 +668,60 @@ Type LLVM::GEPOp::getSourceElementType() {
668668
// Builder, printer and parser for for LLVM::LoadOp.
669669
//===----------------------------------------------------------------------===//
670670

671-
LogicalResult verifySymbolAttribute(
672-
Operation *op, StringRef attributeName,
671+
/// Verifies the given array attribute contains symbol references and checks the
672+
/// referenced symbol types using the provided verification function.
673+
LogicalResult verifyMemOpSymbolRefs(
674+
Operation *op, StringRef name, ArrayAttr symbolRefs,
673675
llvm::function_ref<LogicalResult(Operation *, SymbolRefAttr)>
674676
verifySymbolType) {
675-
if (Attribute attribute = op->getAttr(attributeName)) {
676-
// Verify that the attribute is a symbol ref array attribute,
677-
// because this constraint is not verified for all attribute
678-
// names processed here (e.g. 'tbaa'). This verification
679-
// is redundant in some cases.
680-
if (!(attribute.isa<ArrayAttr>() &&
681-
llvm::all_of(attribute.cast<ArrayAttr>(), [&](Attribute attr) {
682-
return attr && attr.isa<SymbolRefAttr>();
683-
})))
684-
return op->emitOpError("attribute '")
685-
<< attributeName
686-
<< "' failed to satisfy constraint: symbol ref array attribute";
687-
688-
for (SymbolRefAttr symbolRef :
689-
attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
690-
StringAttr metadataName = symbolRef.getRootReference();
691-
StringAttr symbolName = symbolRef.getLeafReference();
692-
// We want @metadata::@symbol, not just @symbol
693-
if (metadataName == symbolName) {
694-
return op->emitOpError() << "expected '" << symbolRef
695-
<< "' to specify a fully qualified reference";
696-
}
697-
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
698-
op->getParentOp(), metadataName);
699-
if (!metadataOp)
700-
return op->emitOpError()
701-
<< "expected '" << symbolRef << "' to reference a metadata op";
702-
Operation *symbolOp =
703-
SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
704-
if (!symbolOp)
705-
return op->emitOpError()
706-
<< "expected '" << symbolRef << "' to be a valid reference";
707-
if (failed(verifySymbolType(symbolOp, symbolRef))) {
708-
return failure();
709-
}
677+
assert(symbolRefs && "expected a non-null attribute");
678+
679+
// Verify that the attribute is a symbol ref array attribute,
680+
// because this constraint is not verified for all attribute
681+
// names processed here (e.g. 'tbaa'). This verification
682+
// is redundant in some cases.
683+
if (!llvm::all_of(symbolRefs, [](Attribute attr) {
684+
return attr && attr.isa<SymbolRefAttr>();
685+
}))
686+
return op->emitOpError("attribute '")
687+
<< name
688+
<< "' failed to satisfy constraint: symbol ref array attribute";
689+
690+
for (SymbolRefAttr symbolRef : symbolRefs.getAsRange<SymbolRefAttr>()) {
691+
StringAttr metadataName = symbolRef.getRootReference();
692+
StringAttr symbolName = symbolRef.getLeafReference();
693+
// We want @metadata::@symbol, not just @symbol
694+
if (metadataName == symbolName) {
695+
return op->emitOpError() << "expected '" << symbolRef
696+
<< "' to specify a fully qualified reference";
697+
}
698+
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
699+
op->getParentOp(), metadataName);
700+
if (!metadataOp)
701+
return op->emitOpError()
702+
<< "expected '" << symbolRef << "' to reference a metadata op";
703+
Operation *symbolOp =
704+
SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
705+
if (!symbolOp)
706+
return op->emitOpError()
707+
<< "expected '" << symbolRef << "' to be a valid reference";
708+
if (failed(verifySymbolType(symbolOp, symbolRef))) {
709+
return failure();
710710
}
711711
}
712+
712713
return success();
713714
}
714715

715-
// Verifies that metadata ops are wired up properly.
716+
/// Verifies the given array attribute contains symbol references that point to
717+
/// metadata operations of the given type.
716718
template <typename OpTy>
717-
static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
719+
static LogicalResult
720+
verifyMemOpSymbolRefsPointTo(Operation *op, StringRef name,
721+
std::optional<ArrayAttr> symbolRefs) {
722+
if (!symbolRefs)
723+
return success();
724+
718725
auto verifySymbolType = [op](Operation *symbolOp,
719726
SymbolRefAttr symbolRef) -> LogicalResult {
720727
if (!isa<OpTy>(symbolOp)) {
@@ -724,35 +731,33 @@ static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
724731
}
725732
return success();
726733
};
727-
728-
return verifySymbolAttribute(op, attributeName, verifySymbolType);
734+
return verifyMemOpSymbolRefs(op, name, *symbolRefs, verifySymbolType);
729735
}
730736

731-
static LogicalResult verifyMemoryOpMetadata(Operation *op) {
732-
// access_groups
733-
if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>(
734-
op, LLVMDialect::getAccessGroupsAttrName())))
737+
/// Verifies the types of the metadata operations referenced by aliasing and
738+
/// access group metadata.
739+
template <typename OpTy>
740+
LogicalResult verifyMemOpMetadata(OpTy memOp) {
741+
if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AccessGroupMetadataOp>(
742+
memOp, memOp.getAccessGroupsAttrName(), memOp.getAccessGroups())))
735743
return failure();
736744

737-
// alias_scopes
738-
if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
739-
op, LLVMDialect::getAliasScopesAttrName())))
745+
if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AliasScopeMetadataOp>(
746+
memOp, memOp.getAliasScopesAttrName(), memOp.getAliasScopes())))
740747
return failure();
741748

742-
// noalias_scopes
743-
if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
744-
op, LLVMDialect::getNoAliasScopesAttrName())))
749+
if (failed(verifyMemOpSymbolRefsPointTo<LLVM::AliasScopeMetadataOp>(
750+
memOp, memOp.getNoaliasScopesAttrName(), memOp.getNoaliasScopes())))
745751
return failure();
746752

747-
// tbaa
748-
if (failed(verifyOpMetadata<LLVM::TBAATagOp>(op,
749-
LLVMDialect::getTBAAAttrName())))
753+
if (failed(verifyMemOpSymbolRefsPointTo<LLVM::TBAATagOp>(
754+
memOp, memOp.getTbaaAttrName(), memOp.getTbaa())))
750755
return failure();
751756

752757
return success();
753758
}
754759

755-
LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); }
760+
LogicalResult LoadOp::verify() { return verifyMemOpMetadata(*this); }
756761

757762
void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
758763
Value addr, unsigned alignment, bool isVolatile,
@@ -828,7 +833,7 @@ ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
828833
// Builder, printer and parser for LLVM::StoreOp.
829834
//===----------------------------------------------------------------------===//
830835

831-
LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); }
836+
LogicalResult StoreOp::verify() { return verifyMemOpMetadata(*this); }
832837

833838
void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
834839
Value addr, unsigned alignment, bool isVolatile,

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.cpp

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,30 @@ static ArrayRef<unsigned> getSupportedMetadataImpl() {
7676
return convertibleMetadata;
7777
}
7878

79+
namespace {
80+
/// Helper class to attach metadata attributes to specific operation types. It
81+
/// specializes TypeSwitch to take an Operation and return a LogicalResult.
82+
template <typename... OpTys>
83+
struct AttributeSetter {
84+
AttributeSetter(Operation *op) : op(op) {}
85+
86+
/// Calls `attachFn` on the provided Operation if it has one of
87+
/// the given operation types. Returns failure otherwise.
88+
template <typename CallableT>
89+
LogicalResult apply(CallableT &&attachFn) {
90+
return llvm::TypeSwitch<Operation *, LogicalResult>(op)
91+
.Case<OpTys...>([&attachFn](auto concreteOp) {
92+
attachFn(concreteOp);
93+
return success();
94+
})
95+
.Default([&](auto) { return failure(); });
96+
}
97+
98+
private:
99+
Operation *op;
100+
};
101+
} // namespace
102+
79103
/// Converts the given profiling metadata `node` to an MLIR profiling attribute
80104
/// and attaches it to the imported operation if the translation succeeds.
81105
/// Returns failure otherwise.
@@ -129,16 +153,10 @@ static LogicalResult setProfilingAttr(OpBuilder &builder, llvm::MDNode *node,
129153
branchWeights.push_back(branchWeight->getZExtValue());
130154
}
131155

132-
// Attach the branch weights to the operations that support it.
133-
return llvm::TypeSwitch<Operation *, LogicalResult>(op)
134-
.Case<CondBrOp, SwitchOp, CallOp, InvokeOp>([&](auto branchWeightOp) {
156+
return AttributeSetter<CondBrOp, SwitchOp, CallOp, InvokeOp>(op).apply(
157+
[&](auto branchWeightOp) {
135158
branchWeightOp.setBranchWeightsAttr(
136159
builder.getI32VectorAttr(branchWeights));
137-
return success();
138-
})
139-
.Default([op](auto) {
140-
return op->emitWarning()
141-
<< op->getName() << " does not support branch weights";
142160
});
143161
}
144162

@@ -151,9 +169,9 @@ static LogicalResult setTBAAAttr(const llvm::MDNode *node, Operation *op,
151169
if (!tbaaTagSym)
152170
return failure();
153171

154-
op->setAttr(LLVMDialect::getTBAAAttrName(),
155-
ArrayAttr::get(op->getContext(), tbaaTagSym));
156-
return success();
172+
return AttributeSetter<LoadOp, StoreOp>(op).apply([&](auto memOp) {
173+
memOp.setTbaaAttr(ArrayAttr::get(memOp.getContext(), tbaaTagSym));
174+
});
157175
}
158176

159177
/// Looks up all the symbol references pointing to the access group operations
@@ -169,9 +187,10 @@ static LogicalResult setAccessGroupAttr(const llvm::MDNode *node, Operation *op,
169187

170188
SmallVector<Attribute> accessGroupAttrs(accessGroups->begin(),
171189
accessGroups->end());
172-
op->setAttr(LLVMDialect::getAccessGroupsAttrName(),
173-
ArrayAttr::get(op->getContext(), accessGroupAttrs));
174-
return success();
190+
return AttributeSetter<LoadOp, StoreOp>(op).apply([&](auto memOp) {
191+
memOp.setAccessGroupsAttr(
192+
ArrayAttr::get(memOp.getContext(), accessGroupAttrs));
193+
});
175194
}
176195

177196
/// Converts the given loop metadata node to an MLIR loop annotation attribute

mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ llvm::MDNode *LoopAnnotationConversion::convert() {
210210
llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
211211
for (SymbolRefAttr accessGroupRef : parallelAccessGroups)
212212
parallelAccess.push_back(
213-
moduleTranslation.getAccessGroup(*op, accessGroupRef));
213+
moduleTranslation.getAccessGroup(op, accessGroupRef));
214214
metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
215215
}
216216

0 commit comments

Comments
 (0)