Skip to content

Commit e6260ad

Browse files
committed
[mlir] Simplify various pieces of code now that Identifier has access to the Context/Dialect
This also exposed a bug in Dialect loading where it was not correctly identifying identifiers that had the dialect namespace as a prefix. Differential Revision: https://reviews.llvm.org/D97431
1 parent 16abaca commit e6260ad

28 files changed

+74
-97
lines changed

mlir/examples/toy/Ch2/mlir/MLIRGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class MLIRGenImpl {
9393

9494
/// Helper conversion for a Toy AST location to an MLIR location.
9595
mlir::Location loc(Location loc) {
96-
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
96+
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
9797
loc.col);
9898
}
9999

mlir/examples/toy/Ch3/mlir/MLIRGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class MLIRGenImpl {
9393

9494
/// Helper conversion for a Toy AST location to an MLIR location.
9595
mlir::Location loc(Location loc) {
96-
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
96+
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
9797
loc.col);
9898
}
9999

mlir/examples/toy/Ch4/mlir/MLIRGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class MLIRGenImpl {
9393

9494
/// Helper conversion for a Toy AST location to an MLIR location.
9595
mlir::Location loc(Location loc) {
96-
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
96+
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
9797
loc.col);
9898
}
9999

mlir/examples/toy/Ch5/mlir/MLIRGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class MLIRGenImpl {
9393

9494
/// Helper conversion for a Toy AST location to an MLIR location.
9595
mlir::Location loc(Location loc) {
96-
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
96+
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
9797
loc.col);
9898
}
9999

mlir/examples/toy/Ch6/mlir/MLIRGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class MLIRGenImpl {
9393

9494
/// Helper conversion for a Toy AST location to an MLIR location.
9595
mlir::Location loc(Location loc) {
96-
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
96+
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
9797
loc.col);
9898
}
9999

mlir/examples/toy/Ch7/mlir/MLIRGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class MLIRGenImpl {
113113

114114
/// Helper conversion for a Toy AST location to an MLIR location.
115115
mlir::Location loc(Location loc) {
116-
return builder.getFileLineColLoc(builder.getIdentifier(*loc.file), loc.line,
116+
return mlir::FileLineColLoc::get(builder.getIdentifier(*loc.file), loc.line,
117117
loc.col);
118118
}
119119

mlir/include/mlir/IR/Builders.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@ class Builder {
5656

5757
// Locations.
5858
Location getUnknownLoc();
59-
Location getFileLineColLoc(Identifier filename, unsigned line,
60-
unsigned column);
6159
Location getFusedLoc(ArrayRef<Location> locs,
6260
Attribute metadata = Attribute());
6361

mlir/include/mlir/IR/BuiltinAttributes.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,7 @@ class OpaqueAttr : public Attribute::AttrBase<OpaqueAttr, Attribute,
296296
using Base::getChecked;
297297

298298
/// Get or create a new OpaqueAttr with the provided dialect and string data.
299-
static OpaqueAttr get(MLIRContext *context, Identifier dialect,
300-
StringRef attrData, Type type);
299+
static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type);
301300

302301
/// Get or create a new OpaqueAttr with the provided dialect and string data.
303302
/// If the given identifier is not a valid namespace for a dialect, then a

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,15 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
293293
"Identifier":$dialectNamespace,
294294
StringRefParameter<"">:$typeData
295295
);
296+
297+
let builders = [
298+
TypeBuilderWithInferredContext<(ins
299+
"Identifier":$dialectNamespace, CArg<"StringRef", "{}">:$typeData
300+
), [{
301+
return $_get(dialectNamespace.getContext(), dialectNamespace, typeData);
302+
}]>
303+
];
304+
let skipDefaultBuilders = 1;
296305
let genVerifyDecl = 1;
297306
}
298307

mlir/include/mlir/IR/Location.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,7 @@ class FileLineColLoc
129129
using Base::Base;
130130

131131
/// Return a uniqued FileLineCol location object.
132-
static Location get(Identifier filename, unsigned line, unsigned column,
133-
MLIRContext *context);
132+
static Location get(Identifier filename, unsigned line, unsigned column);
134133
static Location get(StringRef filename, unsigned line, unsigned column,
135134
MLIRContext *context);
136135

@@ -174,7 +173,7 @@ class NameLoc : public Attribute::AttrBase<NameLoc, LocationAttr,
174173
static Location get(Identifier name, Location child);
175174

176175
/// Return a uniqued name location object with an unknown child.
177-
static Location get(Identifier name, MLIRContext *context);
176+
static Location get(Identifier name);
178177

179178
/// Return the name identifier.
180179
Identifier getName() const;

mlir/include/mlir/IR/OpBase.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
491491
class OpaqueType<string dialect, string name, string summary>
492492
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
493493
summary, "::mlir::OpaqueType">,
494-
BuildableType<"::mlir::OpaqueType::get($_builder.getContext(), "
494+
BuildableType<"::mlir::OpaqueType::get("
495495
"$_builder.getIdentifier(\"" # dialect # "\"), \""
496496
# name # "\")">;
497497

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,15 @@ class OperationName {
314314
OperationName(StringRef name, MLIRContext *context);
315315

316316
/// Return the name of the dialect this operation is registered to.
317-
StringRef getDialect() const;
317+
StringRef getDialectNamespace() const;
318+
319+
/// Return the Dialect this operation is registered to if it is loaded in the
320+
/// context, or nullptr if the dialect isn't loaded.
321+
Dialect *getDialect() const {
322+
if (const auto *abstractOp = getAbstractOperation())
323+
return &abstractOp->dialect;
324+
return representation.get<Identifier>().getDialect();
325+
}
318326

319327
/// Return the operation name with dialect name stripped, if it has one.
320328
StringRef stripDialect() const;

mlir/lib/CAPI/IR/BuiltinAttributes.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,9 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) {
163163
MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
164164
intptr_t dataLength, const char *data,
165165
MlirType type) {
166-
return wrap(OpaqueAttr::get(
167-
unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
168-
StringRef(data, dataLength), unwrap(type)));
166+
return wrap(
167+
OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
168+
StringRef(data, dataLength), unwrap(type)));
169169
}
170170

171171
MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {

mlir/lib/IR/Builders.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,6 @@ Identifier Builder::getIdentifier(StringRef str) {
2929

3030
Location Builder::getUnknownLoc() { return UnknownLoc::get(context); }
3131

32-
Location Builder::getFileLineColLoc(Identifier filename, unsigned line,
33-
unsigned column) {
34-
return FileLineColLoc::get(filename, line, column, context);
35-
}
36-
3732
Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
3833
return FusedLoc::get(locs, metadata, context);
3934
}

mlir/lib/IR/BuiltinAttributes.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,8 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; }
382382
// OpaqueAttr
383383
//===----------------------------------------------------------------------===//
384384

385-
OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect,
386-
StringRef attrData, Type type) {
387-
return Base::get(context, dialect, attrData, type);
385+
OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type) {
386+
return Base::get(dialect.getContext(), dialect, attrData, type);
388387
}
389388

390389
OpaqueAttr OpaqueAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,

mlir/lib/IR/Dialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ Attribute Dialect::parseAttribute(DialectAsmParser &parser, Type type) const {
127127
Type Dialect::parseType(DialectAsmParser &parser) const {
128128
// If this dialect allows unknown types, then represent this with OpaqueType.
129129
if (allowsUnknownTypes()) {
130-
auto ns = Identifier::get(getNamespace(), getContext());
131-
return OpaqueType::get(getContext(), ns, parser.getFullSymbolSpec());
130+
Identifier ns = Identifier::get(getNamespace(), getContext());
131+
return OpaqueType::get(ns, parser.getFullSymbolSpec());
132132
}
133133

134134
parser.emitError(parser.getNameLoc())

mlir/lib/IR/Location.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ Location CallSiteLoc::getCaller() const { return getImpl()->caller; }
4848
//===----------------------------------------------------------------------===//
4949

5050
Location FileLineColLoc::get(Identifier filename, unsigned line,
51-
unsigned column, MLIRContext *context) {
52-
return Base::get(context, filename, line, column);
51+
unsigned column) {
52+
return Base::get(filename.getContext(), filename, line, column);
5353
}
5454

5555
Location FileLineColLoc::get(StringRef filename, unsigned line, unsigned column,
5656
MLIRContext *context) {
5757
return get(Identifier::get(filename.empty() ? "-" : filename, context), line,
58-
column, context);
58+
column);
5959
}
6060

6161
StringRef FileLineColLoc::getFilename() const { return getImpl()->filename; }
@@ -112,8 +112,8 @@ Location NameLoc::get(Identifier name, Location child) {
112112
return Base::get(child->getContext(), name, child);
113113
}
114114

115-
Location NameLoc::get(Identifier name, MLIRContext *context) {
116-
return get(name, UnknownLoc::get(context));
115+
Location NameLoc::get(Identifier name) {
116+
return get(name, UnknownLoc::get(name.getContext()));
117117
}
118118

119119
/// Return the name identifier.

mlir/lib/IR/MLIRContext.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -520,9 +520,11 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
520520
// Refresh all the identifiers dialect field, this catches cases where a
521521
// dialect may be loaded after identifier prefixed with this dialect name
522522
// were already created.
523+
llvm::SmallString<32> dialectPrefix(dialectNamespace);
524+
dialectPrefix.push_back('.');
523525
for (auto &identifierEntry : impl.identifiers)
524-
if (!identifierEntry.second &&
525-
identifierEntry.first().startswith(dialectNamespace))
526+
if (identifierEntry.second.is<MLIRContext *>() &&
527+
identifierEntry.first().startswith(dialectPrefix))
526528
identifierEntry.second = dialect.get();
527529

528530
// Actually register the interfaces with delayed registration.

mlir/lib/IR/Operation.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@ OperationName::OperationName(StringRef name, MLIRContext *context) {
3535
}
3636

3737
/// Return the name of the dialect this operation is registered to.
38-
StringRef OperationName::getDialect() const {
39-
return getStringRef().split('.').first;
38+
StringRef OperationName::getDialectNamespace() const {
39+
if (Dialect *dialect = getDialect())
40+
return dialect->getNamespace();
41+
return representation.get<Identifier>().strref().split('.').first;
4042
}
4143

4244
/// Return the operation name with dialect name stripped, if it has one.
@@ -213,14 +215,7 @@ MLIRContext *Operation::getContext() { return location->getContext(); }
213215

214216
/// Return the dialect this operation is associated with, or nullptr if the
215217
/// associated dialect is not registered.
216-
Dialect *Operation::getDialect() {
217-
if (auto *abstractOp = getAbstractOperation())
218-
return &abstractOp->dialect;
219-
220-
// If this operation hasn't been registered or doesn't have abstract
221-
// operation, try looking up the dialect name in the context.
222-
return getContext()->getLoadedDialect(getName().getDialect());
223-
}
218+
Dialect *Operation::getDialect() { return getName().getDialect(); }
224219

225220
Region *Operation::getParentRegion() {
226221
return block ? block->getParent() : nullptr;

mlir/lib/IR/Verifier.cpp

Lines changed: 14 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,6 @@ class OperationVerifier {
4646
/// Verify the given operation.
4747
LogicalResult verify(Operation &op);
4848

49-
/// Returns the registered dialect for a dialect-specific attribute.
50-
Dialect *getDialectForAttribute(const NamedAttribute &attr) {
51-
assert(attr.first.strref().contains('.') && "expected dialect attribute");
52-
auto dialectNamePair = attr.first.strref().split('.');
53-
return ctx->getLoadedDialect(dialectNamePair.first);
54-
}
55-
5649
private:
5750
/// Verify the given potentially nested region or block.
5851
LogicalResult verifyRegion(Region &region);
@@ -81,10 +74,6 @@ class OperationVerifier {
8174

8275
/// Dominance information for this operation, when checking dominance.
8376
DominanceInfo *domInfo = nullptr;
84-
85-
/// Mapping between dialect namespace and if that dialect supports
86-
/// unregistered operations.
87-
llvm::StringMap<bool> dialectAllowsUnknownOps;
8877
};
8978
} // end anonymous namespace
9079

@@ -170,15 +159,14 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
170159
/// Verify that all of the attributes are okay.
171160
for (auto attr : op.getAttrs()) {
172161
// Check for any optional dialect specific attributes.
173-
if (!attr.first.strref().contains('.'))
174-
continue;
175-
if (auto *dialect = getDialectForAttribute(attr))
162+
if (auto *dialect = attr.first.getDialect())
176163
if (failed(dialect->verifyOperationAttribute(&op, attr)))
177164
return failure();
178165
}
179166

180167
// If we can get operation info for this, check the custom hook.
181-
auto *opInfo = op.getAbstractOperation();
168+
OperationName opName = op.getName();
169+
auto *opInfo = opName.getAbstractOperation();
182170
if (opInfo && failed(opInfo->verifyInvariants(&op)))
183171
return failure();
184172

@@ -213,33 +201,21 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
213201
return success();
214202

215203
// Otherwise, verify that the parent dialect allows un-registered operations.
216-
auto dialectPrefix = op.getName().getDialect();
217-
218-
// Check for an existing answer for the operation dialect.
219-
auto it = dialectAllowsUnknownOps.find(dialectPrefix);
220-
if (it == dialectAllowsUnknownOps.end()) {
221-
// If the operation dialect is registered, query it directly.
222-
if (auto *dialect = ctx->getLoadedDialect(dialectPrefix))
223-
it = dialectAllowsUnknownOps
224-
.try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
225-
.first;
226-
// Otherwise, unregistered dialects (when allowed by the context)
227-
// conservatively allow unknown operations.
228-
else {
229-
if (!op.getContext()->allowsUnregisteredDialects() && !op.getDialect())
230-
return op.emitOpError()
231-
<< "created with unregistered dialect. If this is "
232-
"intended, please call allowUnregisteredDialects() on the "
233-
"MLIRContext, or use -allow-unregistered-dialect with "
234-
"mlir-opt";
235-
236-
it = dialectAllowsUnknownOps.try_emplace(dialectPrefix, true).first;
204+
Dialect *dialect = opName.getDialect();
205+
if (!dialect) {
206+
if (!ctx->allowsUnregisteredDialects()) {
207+
return op.emitOpError()
208+
<< "created with unregistered dialect. If this is "
209+
"intended, please call allowUnregisteredDialects() on the "
210+
"MLIRContext, or use -allow-unregistered-dialect with "
211+
"mlir-opt";
237212
}
213+
return success();
238214
}
239215

240-
if (!it->second) {
216+
if (!dialect->allowsUnknownOperations()) {
241217
return op.emitError("unregistered operation '")
242-
<< op.getName() << "' found in dialect ('" << dialectPrefix
218+
<< op.getName() << "' found in dialect ('" << dialect->getNamespace()
243219
<< "') that does not allow unknown operations";
244220
}
245221

mlir/lib/Parser/DialectSymbolParser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ Type Parser::parseExtendedType() {
563563

564564
// Otherwise, form a new opaque type.
565565
return OpaqueType::getChecked(
566-
getEncodedSourceLocation(loc), state.context,
566+
getEncodedSourceLocation(loc),
567567
Identifier::get(dialectName, state.context), symbolData);
568568
});
569569
}

mlir/lib/Parser/LocationParser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
145145
"expected ')' after child location of NameLoc"))
146146
return failure();
147147
} else {
148-
loc = NameLoc::get(Identifier::get(str, ctx), ctx);
148+
loc = NameLoc::get(Identifier::get(str, ctx));
149149
}
150150

151151
return success();

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,8 +1944,8 @@ Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
19441944
auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
19451945
if (fileName.empty())
19461946
fileName = "<unknown>";
1947-
return opBuilder.getFileLineColLoc(opBuilder.getIdentifier(fileName),
1948-
debugLine->line, debugLine->col);
1947+
return FileLineColLoc::get(opBuilder.getIdentifier(fileName), debugLine->line,
1948+
debugLine->col);
19491949
}
19501950

19511951
LogicalResult

mlir/lib/Transforms/LocationSnapshot.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ static void generateLocationsFromIR(raw_ostream &os, StringRef fileName,
4444
if (it == opToLineCol.end())
4545
return;
4646
const std::pair<unsigned, unsigned> &lineCol = it->second;
47-
auto newLoc =
48-
builder.getFileLineColLoc(file, lineCol.first, lineCol.second);
47+
auto newLoc = FileLineColLoc::get(file, lineCol.first, lineCol.second);
4948

5049
// If we don't have a tag, set the location directly
5150
if (!tagIdentifier) {

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,10 +2702,10 @@ auto ConversionTarget::getOpInfo(OperationName op) const
27022702
if (it != legalOperations.end())
27032703
return it->second;
27042704
// Check for info for the parent dialect.
2705-
auto dialectIt = legalDialects.find(op.getDialect());
2705+
auto dialectIt = legalDialects.find(op.getDialectNamespace());
27062706
if (dialectIt != legalDialects.end()) {
27072707
Optional<DynamicLegalityCallbackFn> callback;
2708-
auto dialectFn = dialectLegalityFns.find(op.getDialect());
2708+
auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace());
27092709
if (dialectFn != dialectLegalityFns.end())
27102710
callback = dialectFn->second;
27112711
return LegalizationInfo{dialectIt->second, /*isRecursivelyLegal=*/false,

0 commit comments

Comments
 (0)