Skip to content

Commit d25e91d

Browse files
nimiwioftynse
authored andcommitted
Support alias.scope and noalias metadata
Introduces new Ops to represent 1. alias.scope metadata in LLVM, and 2. domains for these scopes. These correspond to the metadata described in https://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata. Lists of scopes are modeled the same way as access groups - as an ArrayAttr on the Op (added in https://reviews.llvm.org/D97944). Lowering 'noalias' attributes on function parameters is already supported. However, lowering `noalias` metadata on individual Ops is not, which is added in this change. LLVM uses the same keyword for these, but this change introduces a separate attribute name 'noalias_scopes' to represent this distinct concept. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D107870
1 parent 87dd519 commit d25e91d

File tree

7 files changed

+345
-23
lines changed

7 files changed

+345
-23
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def LLVM_Dialect : Dialect {
3232
static StringRef getDataLayoutAttrName() { return "llvm.data_layout"; }
3333
static StringRef getAlignAttrName() { return "llvm.align"; }
3434
static StringRef getNoAliasAttrName() { return "llvm.noalias"; }
35+
static StringRef getNoAliasScopesAttrName() { return "noalias_scopes"; }
36+
static StringRef getAliasScopesAttrName() { return "alias_scopes"; }
3537
static StringRef getLoopAttrName() { return "llvm.loop"; }
3638
static StringRef getParallelAccessAttrName() { return "parallel_access"; }
3739
static StringRef getLoopOptionsAttrName() { return "options"; }

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,10 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase {
286286
code setAccessGroupsMetadataCode = [{
287287
moduleTranslation.setAccessGroupsMetadata(op, inst);
288288
}];
289+
290+
code setAliasScopeMetadataCode = [{
291+
moduleTranslation.setAliasScopeMetadata(op, inst);
292+
}];
289293
}
290294

291295
// Memory-related operations.
@@ -329,13 +333,19 @@ def LLVM_GEPOp
329333
def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
330334
let arguments = (ins LLVM_PointerTo<LLVM_LoadableType>:$addr,
331335
OptionalAttr<SymbolRefArrayAttr>:$access_groups,
336+
OptionalAttr<SymbolRefArrayAttr>:$alias_scopes,
337+
OptionalAttr<SymbolRefArrayAttr>:$noalias_scopes,
332338
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
333339
UnitAttr:$nontemporal);
334340
let results = (outs LLVM_LoadableType:$res);
335341
string llvmBuilder = [{
336342
auto *inst = builder.CreateLoad(
337343
$addr->getType()->getPointerElementType(), $addr, $volatile_);
338-
}] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode # [{
344+
}] # setAlignmentCode
345+
# setNonTemporalMetadataCode
346+
# setAccessGroupsMetadataCode
347+
# setAliasScopeMetadataCode
348+
# [{
339349
$res = inst;
340350
}];
341351
let builders = [
@@ -357,11 +367,16 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
357367
let arguments = (ins LLVM_LoadableType:$value,
358368
LLVM_PointerTo<LLVM_LoadableType>:$addr,
359369
OptionalAttr<SymbolRefArrayAttr>:$access_groups,
370+
OptionalAttr<SymbolRefArrayAttr>:$alias_scopes,
371+
OptionalAttr<SymbolRefArrayAttr>:$noalias_scopes,
360372
OptionalAttr<I64Attr>:$alignment, UnitAttr:$volatile_,
361373
UnitAttr:$nontemporal);
362374
string llvmBuilder = [{
363375
auto *inst = builder.CreateStore($value, $addr, $volatile_);
364-
}] # setAlignmentCode # setNonTemporalMetadataCode # setAccessGroupsMetadataCode;
376+
}] # setAlignmentCode
377+
# setNonTemporalMetadataCode
378+
# setAccessGroupsMetadataCode
379+
# setAliasScopeMetadataCode;
365380
let builders = [
366381
OpBuilder<(ins "Value":$value, "Value":$addr,
367382
CArg<"unsigned", "0">:$alignment, CArg<"bool", "false">:$isVolatile,
@@ -876,8 +891,7 @@ def LLVM_MetadataOp : LLVM_Op<"metadata", [
876891
);
877892
let summary = "LLVM dialect metadata.";
878893
let description = [{
879-
llvm.metadata op defines one or more metadata nodes. Currently the
880-
llvm.access_group metadata op is supported.
894+
llvm.metadata op defines one or more metadata nodes.
881895

882896
Example:
883897
llvm.metadata @metadata {
@@ -890,6 +904,66 @@ def LLVM_MetadataOp : LLVM_Op<"metadata", [
890904
let assemblyFormat = "$sym_name attr-dict-with-keyword $body";
891905
}
892906

907+
def LLVM_AliasScopeDomainMetadataOp : LLVM_Op<"alias_scope_domain", [
908+
HasParent<"MetadataOp">, Symbol
909+
]> {
910+
let arguments = (ins
911+
SymbolNameAttr:$sym_name,
912+
OptionalAttr<StrAttr>:$description
913+
);
914+
let summary = "LLVM dialect alias.scope domain metadata.";
915+
let description = [{
916+
Defines a domain that may be associated with an alias scope.
917+
918+
See the following link for more details:
919+
https://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata
920+
}];
921+
let assemblyFormat = "$sym_name attr-dict";
922+
}
923+
924+
def LLVM_AliasScopeMetadataOp : LLVM_Op<"alias_scope", [
925+
HasParent<"MetadataOp">, Symbol
926+
]> {
927+
let arguments = (ins
928+
SymbolNameAttr:$sym_name,
929+
FlatSymbolRefAttr:$domain,
930+
OptionalAttr<StrAttr>:$description
931+
);
932+
let summary = "LLVM dialect alias.scope metadata.";
933+
let description = [{
934+
Defines an alias scope that can be attached to a memory-accessing operation.
935+
Such scopes can be used in combination with `noalias` metadata to indicate
936+
that sets of memory-affecting operations in one scope do not alias with
937+
memory-affecting operations in another scope.
938+
939+
Example:
940+
module {
941+
llvm.func @foo(%ptr1 : !llvm.ptr<i32>) {
942+
%c0 = llvm.mlir.constant(0 : i32) : i32
943+
%c4 = llvm.mlir.constant(4 : i32) : i32
944+
%1 = llvm.ptrtoint %ptr1 : !llvm.ptr<i32> to i32
945+
%2 = llvm.add %1, %c1 : i32
946+
%ptr2 = llvm.inttoptr %2 : i32 to !llvm.ptr<i32>
947+
llvm.store %c0, %ptr1 { alias_scopes = [@metadata::@scope1], llvm.noalias = [@metadata::@scope2] } : !llvm.ptr<i32>
948+
llvm.store %c4, %ptr2 { alias_scopes = [@metadata::@scope2], llvm.noalias = [@metadata::@scope1] } : !llvm.ptr<i32>
949+
llvm.return
950+
}
951+
952+
llvm.metadata @metadata {
953+
llvm.alias_scope_domain @unused_domain
954+
llvm.alias_scope_domain @domain { description = "Optional domain description"}
955+
llvm.alias_scope @scope1 { domain = @domain }
956+
llvm.alias_scope @scope2 { domain = @domain, description = "Optional scope description" }
957+
llvm.return
958+
}
959+
}
960+
961+
See the following link for more details:
962+
https://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata
963+
}];
964+
let assemblyFormat = "$sym_name attr-dict";
965+
}
966+
893967
def LLVM_AccessGroupMetadataOp : LLVM_Op<"access_group", [
894968
HasParent<"MetadataOp">, Symbol
895969
]> {

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ class ModuleTranslation {
115115
llvm::MDNode *getAccessGroup(Operation &opInst,
116116
SymbolRefAttr accessGroupRef) const;
117117

118+
/// Returns the LLVM metadata corresponding to a reference to an mlir LLVM
119+
/// dialect alias scope operation
120+
llvm::MDNode *getAliasScope(Operation &opInst,
121+
SymbolRefAttr aliasScopeRef) const;
122+
118123
/// Returns the LLVM metadata corresponding to a llvm loop's codegen
119124
/// options attribute.
120125
llvm::MDNode *lookupLoopOptionsMetadata(Attribute options) const {
@@ -131,6 +136,9 @@ class ModuleTranslation {
131136
// Sets LLVM metadata for memory operations that are in a parallel loop.
132137
void setAccessGroupsMetadata(Operation *op, llvm::Instruction *inst);
133138

139+
// Sets LLVM metadata for memory operations that have alias scope information.
140+
void setAliasScopeMetadata(Operation *op, llvm::Instruction *inst);
141+
134142
/// Converts the type from MLIR LLVM dialect to LLVM.
135143
llvm::Type *convertType(Type type);
136144

@@ -268,6 +276,10 @@ class ModuleTranslation {
268276
/// metadata nodes.
269277
LogicalResult createAccessGroupMetadata();
270278

279+
/// Process alias.scope LLVM Metadata operations and create LLVM
280+
/// metadata nodes for them and their domains.
281+
LogicalResult createAliasScopeMetadata();
282+
271283
/// Translates dialect attributes attached to the given operation.
272284
LogicalResult convertDialectAttributes(Operation *op);
273285

@@ -300,7 +312,7 @@ class ModuleTranslation {
300312
/// values after all operations are converted.
301313
DenseMap<Operation *, llvm::Instruction *> branchMapping;
302314

303-
/// Mapping from an access group metadata optation to its LLVM metadata.
315+
/// Mapping from an access group metadata operation to its LLVM metadata.
304316
/// This map is populated on module entry and is used to annotate loops (as
305317
/// identified via their branches) and contained memory accesses.
306318
DenseMap<Operation *, llvm::MDNode *> accessGroupMetadataMapping;
@@ -310,6 +322,10 @@ class ModuleTranslation {
310322
/// attribute.
311323
DenseMap<Attribute, llvm::MDNode *> loopOptionsMetadataMapping;
312324

325+
/// Mapping from an access scope metadata operation to its LLVM metadata.
326+
/// This map is populated on module entry.
327+
DenseMap<Operation *, llvm::MDNode *> aliasScopeMetadataMapping;
328+
313329
/// Stack of user-specified state elements, useful when translating operations
314330
/// with regions.
315331
SmallVector<std::unique_ptr<StackFrame>> stack;

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -335,32 +335,76 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
335335
// Builder, printer and parser for for LLVM::LoadOp.
336336
//===----------------------------------------------------------------------===//
337337

338-
static LogicalResult verifyAccessGroups(Operation *op) {
339-
if (Attribute attribute =
340-
op->getAttr(LLVMDialect::getAccessGroupsAttrName())) {
338+
LogicalResult verifySymbolAttribute(
339+
Operation *op, StringRef attributeName,
340+
std::function<LogicalResult(Operation *, SymbolRefAttr)> verifySymbolType) {
341+
if (Attribute attribute = op->getAttr(attributeName)) {
341342
// The attribute is already verified to be a symbol ref array attribute via
342343
// a constraint in the operation definition.
343-
for (SymbolRefAttr accessGroupRef :
344+
for (SymbolRefAttr symbolRef :
344345
attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
345-
StringRef metadataName = accessGroupRef.getRootReference();
346+
StringRef metadataName = symbolRef.getRootReference();
347+
StringRef symbolName = symbolRef.getLeafReference();
348+
// We want @metadata::@symbol, not just @symbol
349+
if (metadataName == symbolName) {
350+
return op->emitOpError() << "expected '" << symbolRef
351+
<< "' to specify a fully qualified reference";
352+
}
346353
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
347354
op->getParentOp(), metadataName);
348355
if (!metadataOp)
349-
return op->emitOpError() << "expected '" << accessGroupRef
350-
<< "' to reference a metadata op";
351-
StringRef accessGroupName = accessGroupRef.getLeafReference();
352-
Operation *accessGroupOp =
353-
SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
354-
if (!accessGroupOp)
355-
return op->emitOpError() << "expected '" << accessGroupRef
356-
<< "' to reference an access_group op";
356+
return op->emitOpError()
357+
<< "expected '" << symbolRef << "' to reference a metadata op";
358+
Operation *symbolOp =
359+
SymbolTable::lookupNearestSymbolFrom(metadataOp, symbolName);
360+
if (!symbolOp)
361+
return op->emitOpError()
362+
<< "expected '" << symbolRef << "' to be a valid reference";
363+
if (failed(verifySymbolType(symbolOp, symbolRef))) {
364+
return failure();
365+
}
357366
}
358367
}
359368
return success();
360369
}
361370

371+
// Verifies that metadata ops are wired up properly.
372+
template <typename OpTy>
373+
static LogicalResult verifyOpMetadata(Operation *op, StringRef attributeName) {
374+
auto verifySymbolType = [op](Operation *symbolOp,
375+
SymbolRefAttr symbolRef) -> LogicalResult {
376+
if (!isa<OpTy>(symbolOp)) {
377+
return op->emitOpError()
378+
<< "expected '" << symbolRef << "' to resolve to a "
379+
<< OpTy::getOperationName();
380+
}
381+
return success();
382+
};
383+
384+
return verifySymbolAttribute(op, attributeName, verifySymbolType);
385+
}
386+
387+
static LogicalResult verifyMemoryOpMetadata(Operation *op) {
388+
// access_groups
389+
if (failed(verifyOpMetadata<LLVM::AccessGroupMetadataOp>(
390+
op, LLVMDialect::getAccessGroupsAttrName())))
391+
return failure();
392+
393+
// alias_scopes
394+
if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
395+
op, LLVMDialect::getAliasScopesAttrName())))
396+
return failure();
397+
398+
// noalias_scopes
399+
if (failed(verifyOpMetadata<LLVM::AliasScopeMetadataOp>(
400+
op, LLVMDialect::getNoAliasScopesAttrName())))
401+
return failure();
402+
403+
return success();
404+
}
405+
362406
static LogicalResult verify(LoadOp op) {
363-
return verifyAccessGroups(op.getOperation());
407+
return verifyMemoryOpMetadata(op.getOperation());
364408
}
365409

366410
void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
@@ -422,7 +466,7 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
422466
//===----------------------------------------------------------------------===//
423467

424468
static LogicalResult verify(StoreOp op) {
425-
return verifyAccessGroups(op.getOperation());
469+
return verifyMemoryOpMetadata(op.getOperation());
426470
}
427471

428472
void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,78 @@ void ModuleTranslation::setAccessGroupsMetadata(Operation *op,
774774
}
775775
}
776776

777+
LogicalResult ModuleTranslation::createAliasScopeMetadata() {
778+
mlirModule->walk([&](LLVM::MetadataOp metadatas) {
779+
// Create the domains first, so they can be reference below in the scopes.
780+
DenseMap<Operation *, llvm::MDNode *> aliasScopeDomainMetadataMapping;
781+
metadatas.walk([&](LLVM::AliasScopeDomainMetadataOp op) {
782+
llvm::LLVMContext &ctx = llvmModule->getContext();
783+
llvm::SmallVector<llvm::Metadata *, 2> operands;
784+
operands.push_back({}); // Placeholder for self-reference
785+
if (Optional<StringRef> description = op.description())
786+
operands.push_back(llvm::MDString::get(ctx, description.getValue()));
787+
llvm::MDNode *domain = llvm::MDNode::get(ctx, operands);
788+
domain->replaceOperandWith(0, domain); // Self-reference for uniqueness
789+
aliasScopeDomainMetadataMapping.insert({op, domain});
790+
});
791+
792+
// Now create the scopes, referencing the domains created above.
793+
metadatas.walk([&](LLVM::AliasScopeMetadataOp op) {
794+
llvm::LLVMContext &ctx = llvmModule->getContext();
795+
assert(isa<LLVM::MetadataOp>(op->getParentOp()));
796+
auto metadataOp = dyn_cast<LLVM::MetadataOp>(op->getParentOp());
797+
Operation *domainOp =
798+
SymbolTable::lookupNearestSymbolFrom(metadataOp, op.domainAttr());
799+
llvm::MDNode *domain = aliasScopeDomainMetadataMapping.lookup(domainOp);
800+
assert(domain && "Scope's domain should already be valid");
801+
llvm::SmallVector<llvm::Metadata *, 3> operands;
802+
operands.push_back({}); // Placeholder for self-reference
803+
operands.push_back(domain);
804+
if (Optional<StringRef> description = op.description())
805+
operands.push_back(llvm::MDString::get(ctx, description.getValue()));
806+
llvm::MDNode *scope = llvm::MDNode::get(ctx, operands);
807+
scope->replaceOperandWith(0, scope); // Self-reference for uniqueness
808+
aliasScopeMetadataMapping.insert({op, scope});
809+
});
810+
});
811+
return success();
812+
}
813+
814+
llvm::MDNode *
815+
ModuleTranslation::getAliasScope(Operation &opInst,
816+
SymbolRefAttr aliasScopeRef) const {
817+
StringRef metadataName = aliasScopeRef.getRootReference();
818+
StringRef scopeName = aliasScopeRef.getLeafReference();
819+
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
820+
opInst.getParentOp(), metadataName);
821+
Operation *aliasScopeOp =
822+
SymbolTable::lookupNearestSymbolFrom(metadataOp, scopeName);
823+
return aliasScopeMetadataMapping.lookup(aliasScopeOp);
824+
}
825+
826+
void ModuleTranslation::setAliasScopeMetadata(Operation *op,
827+
llvm::Instruction *inst) {
828+
auto populateScopeMetadata = [this, op, inst](StringRef attrName,
829+
StringRef llvmMetadataName) {
830+
auto scopes = op->getAttrOfType<ArrayAttr>(attrName);
831+
if (!scopes || scopes.empty())
832+
return;
833+
llvm::Module *module = inst->getModule();
834+
SmallVector<llvm::Metadata *> scopeMDs;
835+
for (SymbolRefAttr scopeRef : scopes.getAsRange<SymbolRefAttr>())
836+
scopeMDs.push_back(getAliasScope(*op, scopeRef));
837+
llvm::MDNode *unionMD = nullptr;
838+
if (scopeMDs.size() == 1)
839+
unionMD = llvm::cast<llvm::MDNode>(scopeMDs.front());
840+
else if (scopeMDs.size() >= 2)
841+
unionMD = llvm::MDNode::get(module->getContext(), scopeMDs);
842+
inst->setMetadata(module->getMDKindID(llvmMetadataName), unionMD);
843+
};
844+
845+
populateScopeMetadata(LLVMDialect::getAliasScopesAttrName(), "alias.scope");
846+
populateScopeMetadata(LLVMDialect::getNoAliasScopesAttrName(), "noalias");
847+
}
848+
777849
llvm::Type *ModuleTranslation::convertType(Type type) {
778850
return typeTranslator.translateType(type);
779851
}
@@ -842,6 +914,8 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
842914
return nullptr;
843915
if (failed(translator.createAccessGroupMetadata()))
844916
return nullptr;
917+
if (failed(translator.createAliasScopeMetadata()))
918+
return nullptr;
845919
if (failed(translator.convertFunctions()))
846920
return nullptr;
847921
if (llvm::verifyModule(*translator.llvmModule, &llvm::errs()))

0 commit comments

Comments
 (0)