Skip to content

Commit 87a0479

Browse files
committed
[mlir][llvm] Fuse access_group & loop export (NFC)
This commit moves the access group translation into the LoopAnnotationTranslation class as these two metadata kinds only appear together. Drops the access group cleanup from `ModuleTranslation::forgetMapping` as this is only used on function regions. Access groups only appear in the region of a global metadata operation and will thus not be cleaned here. Analogous to https://reviews.llvm.org/D143577 Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D144253
1 parent cf4df61 commit 87a0479

File tree

4 files changed

+83
-66
lines changed

4 files changed

+83
-66
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,6 @@ class ModuleTranslation {
120120
/// in these blocks.
121121
void forgetMapping(Region &region);
122122

123-
/// Returns the LLVM metadata corresponding to a symbol reference to an mlir
124-
/// LLVM dialect access group operation.
125-
llvm::MDNode *getAccessGroup(Operation *op,
126-
SymbolRefAttr accessGroupRef) const;
127-
128123
/// Returns the LLVM metadata corresponding to a symbol reference to an mlir
129124
/// LLVM dialect alias scope operation
130125
llvm::MDNode *getAliasScope(Operation *op, SymbolRefAttr aliasScopeRef) const;
@@ -332,11 +327,6 @@ class ModuleTranslation {
332327
/// values after all operations are converted.
333328
DenseMap<Operation *, llvm::Instruction *> branchMapping;
334329

335-
/// Mapping from an access group metadata operation to its LLVM metadata.
336-
/// This map is populated on module entry and is used to annotate loops (as
337-
/// identified via their branches) and contained memory accesses.
338-
DenseMap<Operation *, llvm::MDNode *> accessGroupMetadataMapping;
339-
340330
/// Mapping from an alias scope metadata operation to its LLVM metadata.
341331
/// This map is populated on module entry.
342332
DenseMap<Operation *, llvm::MDNode *> aliasScopeMetadataMapping;

mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.cpp

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@ using namespace mlir::LLVM::detail;
1515
namespace {
1616
/// Helper class that keeps the state of one attribute to metadata conversion.
1717
struct LoopAnnotationConversion {
18-
LoopAnnotationConversion(LoopAnnotationAttr attr,
19-
ModuleTranslation &moduleTranslation, Operation *op,
20-
LoopAnnotationTranslation &loopAnnotationTranslation)
21-
: attr(attr), moduleTranslation(moduleTranslation), op(op),
22-
loopAnnotationTranslation(loopAnnotationTranslation),
23-
ctx(moduleTranslation.getLLVMContext()) {}
18+
LoopAnnotationConversion(LoopAnnotationAttr attr, Operation *op,
19+
LoopAnnotationTranslation &loopAnnotationTranslation,
20+
llvm::LLVMContext &ctx)
21+
: attr(attr), op(op),
22+
loopAnnotationTranslation(loopAnnotationTranslation), ctx(ctx) {}
2423

2524
/// Converts this struct's loop annotation into a corresponding LLVMIR
2625
/// metadata representation.
@@ -46,7 +45,6 @@ struct LoopAnnotationConversion {
4645
void convertLoopOptions(LoopUnswitchAttr options);
4746

4847
LoopAnnotationAttr attr;
49-
ModuleTranslation &moduleTranslation;
5048
Operation *op;
5149
LoopAnnotationTranslation &loopAnnotationTranslation;
5250
llvm::LLVMContext &ctx;
@@ -95,7 +93,8 @@ void LoopAnnotationConversion::convertFollowupNode(StringRef name,
9593
if (!attr)
9694
return;
9795

98-
llvm::MDNode *node = loopAnnotationTranslation.translate(attr, op);
96+
llvm::MDNode *node =
97+
loopAnnotationTranslation.translateLoopAnnotation(attr, op);
9998

10099
metadataNodes.push_back(
101100
llvm::MDNode::get(ctx, {llvm::MDString::get(ctx, name), node}));
@@ -225,7 +224,7 @@ llvm::MDNode *LoopAnnotationConversion::convert() {
225224
llvm::MDString::get(ctx, "llvm.loop.parallel_accesses"));
226225
for (SymbolRefAttr accessGroupRef : parallelAccessGroups)
227226
parallelAccess.push_back(
228-
moduleTranslation.getAccessGroup(op, accessGroupRef));
227+
loopAnnotationTranslation.getAccessGroup(op, accessGroupRef));
229228
metadataNodes.push_back(llvm::MDNode::get(ctx, parallelAccess));
230229
}
231230

@@ -236,7 +235,8 @@ llvm::MDNode *LoopAnnotationConversion::convert() {
236235
return loopMD;
237236
}
238237

239-
llvm::MDNode *LoopAnnotationTranslation::translate(LoopAnnotationAttr attr,
238+
llvm::MDNode *
239+
LoopAnnotationTranslation::translateLoopAnnotation(LoopAnnotationAttr attr,
240240
Operation *op) {
241241
if (!attr)
242242
return nullptr;
@@ -246,9 +246,47 @@ llvm::MDNode *LoopAnnotationTranslation::translate(LoopAnnotationAttr attr,
246246
return loopMD;
247247

248248
loopMD =
249-
LoopAnnotationConversion(attr, moduleTranslation, op, *this).convert();
249+
LoopAnnotationConversion(attr, op, *this, this->llvmModule.getContext())
250+
.convert();
250251
// Store a map from this Attribute to the LLVM metadata in case we
251252
// encounter it again.
252253
mapLoopMetadata(attr, loopMD);
253254
return loopMD;
254255
}
256+
257+
LogicalResult LoopAnnotationTranslation::createAccessGroupMetadata() {
258+
mlirModule->walk([&](LLVM::MetadataOp metadatas) {
259+
metadatas.walk([&](LLVM::AccessGroupMetadataOp op) {
260+
llvm::MDNode *accessGroup =
261+
llvm::MDNode::getDistinct(llvmModule.getContext(), {});
262+
accessGroupMetadataMapping.insert({op, accessGroup});
263+
});
264+
});
265+
return success();
266+
}
267+
268+
llvm::MDNode *
269+
LoopAnnotationTranslation::getAccessGroup(Operation *op,
270+
SymbolRefAttr accessGroupRef) const {
271+
auto metadataName = accessGroupRef.getRootReference();
272+
auto accessGroupName = accessGroupRef.getLeafReference();
273+
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
274+
op->getParentOp(), metadataName);
275+
auto *accessGroupOp =
276+
SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
277+
return accessGroupMetadataMapping.lookup(accessGroupOp);
278+
}
279+
280+
llvm::MDNode *
281+
LoopAnnotationTranslation::getAccessGroups(Operation *op,
282+
ArrayAttr accessGroupRefs) const {
283+
if (!accessGroupRefs || accessGroupRefs.empty())
284+
return nullptr;
285+
286+
SmallVector<llvm::Metadata *> groupMDs;
287+
for (SymbolRefAttr groupRef : accessGroupRefs.getAsRange<SymbolRefAttr>())
288+
groupMDs.push_back(getAccessGroup(op, groupRef));
289+
if (groupMDs.size() == 1)
290+
return llvm::cast<llvm::MDNode>(groupMDs.front());
291+
return llvm::MDNode::get(llvmModule.getContext(), groupMDs);
292+
}

mlir/lib/Target/LLVMIR/LoopAnnotationTranslation.h

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,28 @@ namespace mlir {
2121
namespace LLVM {
2222
namespace detail {
2323

24-
/// A helper class that converts a LoopAnnotationAttr into a corresponding
25-
/// llvm::MDNode.
24+
/// A helper class that converts LoopAnnotationAttrs and AccessGroupMetadataOps
25+
/// into a corresponding llvm::MDNodes.
2626
class LoopAnnotationTranslation {
2727
public:
28-
LoopAnnotationTranslation(LLVM::ModuleTranslation &moduleTranslation)
29-
: moduleTranslation(moduleTranslation) {}
28+
LoopAnnotationTranslation(Operation *mlirModule, llvm::Module &llvmModule)
29+
: mlirModule(mlirModule), llvmModule(llvmModule) {}
3030

31-
llvm::MDNode *translate(LoopAnnotationAttr attr, Operation *op);
31+
llvm::MDNode *translateLoopAnnotation(LoopAnnotationAttr attr, Operation *op);
32+
33+
/// Traverses the global access group metadata operation in the `mlirModule`
34+
/// and creates corresponding LLVM metadata nodes.
35+
LogicalResult createAccessGroupMetadata();
36+
37+
/// Returns the LLVM metadata corresponding to a symbol reference to an mlir
38+
/// LLVM dialect access group operation.
39+
llvm::MDNode *getAccessGroup(Operation *op,
40+
SymbolRefAttr accessGroupRef) const;
41+
42+
/// Returns the LLVM metadata corresponding to a list of symbol reference to
43+
/// an mlir LLVM dialect access group operation. Returns nullptr if
44+
/// `accessGroupRefs` is null or empty.
45+
llvm::MDNode *getAccessGroups(Operation *op, ArrayAttr accessGroupRefs) const;
3246

3347
private:
3448
/// Returns the LLVM metadata corresponding to a llvm loop metadata attribute.
@@ -47,7 +61,13 @@ class LoopAnnotationTranslation {
4761
/// The metadata is attached to Latch block branches with this attribute.
4862
DenseMap<Attribute, llvm::MDNode *> loopMetadataMapping;
4963

50-
LLVM::ModuleTranslation &moduleTranslation;
64+
/// Mapping from an access group metadata operation to its LLVM metadata.
65+
/// This map is populated on module entry and is used to annotate loops (as
66+
/// identified via their branches) and contained memory accesses.
67+
DenseMap<Operation *, llvm::MDNode *> accessGroupMetadataMapping;
68+
69+
Operation *mlirModule;
70+
llvm::Module &llvmModule;
5171
};
5272

5373
} // namespace detail

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -421,8 +421,8 @@ ModuleTranslation::ModuleTranslation(Operation *module,
421421
: mlirModule(module), llvmModule(std::move(llvmModule)),
422422
debugTranslation(
423423
std::make_unique<DebugTranslation>(module, *this->llvmModule)),
424-
loopAnnotationTranslation(
425-
std::make_unique<LoopAnnotationTranslation>(*this)),
424+
loopAnnotationTranslation(std::make_unique<LoopAnnotationTranslation>(
425+
module, *this->llvmModule)),
426426
typeTranslator(this->llvmModule->getContext()),
427427
iface(module->getContext()) {
428428
assert(satisfiesLLVMModule(mlirModule) &&
@@ -449,7 +449,6 @@ void ModuleTranslation::forgetMapping(Region &region) {
449449
branchMapping.erase(&op);
450450
if (isa<LLVM::GlobalOp>(op))
451451
globalsMapping.erase(&op);
452-
accessGroupMetadataMapping.erase(&op);
453452
llvm::append_range(
454453
toProcess,
455454
llvm::map_range(op.getRegions(), [](Region &r) { return &r; }));
@@ -994,47 +993,16 @@ LogicalResult ModuleTranslation::convertFunctions() {
994993
return success();
995994
}
996995

997-
llvm::MDNode *
998-
ModuleTranslation::getAccessGroup(Operation *op,
999-
SymbolRefAttr accessGroupRef) const {
1000-
auto metadataName = accessGroupRef.getRootReference();
1001-
auto accessGroupName = accessGroupRef.getLeafReference();
1002-
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
1003-
op->getParentOp(), metadataName);
1004-
auto *accessGroupOp =
1005-
SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
1006-
return accessGroupMetadataMapping.lookup(accessGroupOp);
1007-
}
1008-
1009996
LogicalResult ModuleTranslation::createAccessGroupMetadata() {
1010-
mlirModule->walk([&](LLVM::MetadataOp metadatas) {
1011-
metadatas.walk([&](LLVM::AccessGroupMetadataOp op) {
1012-
llvm::LLVMContext &ctx = llvmModule->getContext();
1013-
llvm::MDNode *accessGroup = llvm::MDNode::getDistinct(ctx, {});
1014-
accessGroupMetadataMapping.insert({op, accessGroup});
1015-
});
1016-
});
1017-
return success();
997+
return loopAnnotationTranslation->createAccessGroupMetadata();
1018998
}
1019999

10201000
void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
10211001
llvm::Instruction *inst) {
10221002
auto populateGroupsMetadata = [&](ArrayAttr groupRefs) {
1023-
if (!groupRefs || groupRefs.empty())
1024-
return;
1025-
1026-
llvm::Module *module = inst->getModule();
1027-
SmallVector<llvm::Metadata *> groupMDs;
1028-
for (SymbolRefAttr groupRef : groupRefs.getAsRange<SymbolRefAttr>())
1029-
groupMDs.push_back(getAccessGroup(op, groupRef));
1030-
1031-
llvm::MDNode *node = nullptr;
1032-
if (groupMDs.size() == 1)
1033-
node = llvm::cast<llvm::MDNode>(groupMDs.front());
1034-
else if (groupMDs.size() >= 2)
1035-
node = llvm::MDNode::get(module->getContext(), groupMDs);
1036-
1037-
inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
1003+
if (llvm::MDNode *node =
1004+
loopAnnotationTranslation->getAccessGroups(op, groupRefs))
1005+
inst->setMetadata(llvm::LLVMContext::MD_access_group, node);
10381006
};
10391007

10401008
auto groupRefs =
@@ -1250,7 +1218,8 @@ void ModuleTranslation::setLoopMetadata(Operation *op,
12501218
[](auto branchOp) { return branchOp.getLoopAnnotationAttr(); });
12511219
if (!attr)
12521220
return;
1253-
llvm::MDNode *loopMD = loopAnnotationTranslation->translate(attr, op);
1221+
llvm::MDNode *loopMD =
1222+
loopAnnotationTranslation->translateLoopAnnotation(attr, op);
12541223
inst->setMetadata(llvm::LLVMContext::MD_loop, loopMD);
12551224
}
12561225

0 commit comments

Comments
 (0)