Skip to content

Commit 292f895

Browse files
committed
[mlir] Add the ability to define dialect-specific location attrs.
This patch adds the capability to define dialect-specific location attrs. This is useful in particular for defining location structure that doesn't necessarily fit within the core MLIR location hierarchy, but doesn't make sense to push upstream (i.e. a custom use case). This patch adds an AttributeTrait, `IsLocation`, which is tagged onto all the builtin location attrs, as well as the test location attribute. This is necessary because previously LocationAttr::classof only returned true if the attribute was one of the builtin location attributes, and well, the point of this patch is to allow dialects to define their own location attributes. There was an alternate implementation I considered wherein LocationAttr becomes an AttrInterface, but that was discarded because there are likely to be *many* locations in a single program, and I was concerned that forcing every MLIR user to pay the cost of the additional lookup/dispatch was unacceptable. It also would have been a *much* more invasive change. It would have allowed for more flexibility in terms of pretty printing, but it's unclear how useful/necessary that flexibility would be given how much customizability there already is for attribute definitions.
1 parent 716594d commit 292f895

File tree

9 files changed

+69
-6
lines changed

9 files changed

+69
-6
lines changed

mlir/include/mlir/IR/Attributes.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,19 @@ class AttributeInterface
322322
// Core AttributeTrait
323323
//===----------------------------------------------------------------------===//
324324

325-
/// This trait is used to determine if an attribute is mutable or not. It is
326-
/// attached on an attribute if the corresponding ImplType defines a `mutate`
327-
/// function with proper signature.
328325
namespace AttributeTrait {
326+
/// This trait is used to determine if an attribute is mutable or not. It is
327+
/// attached on an attribute if the corresponding ConcreteType defines a
328+
/// `mutate` function with proper signature.
329329
template <typename ConcreteType>
330330
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
331+
332+
/// This trait is used to determine if an attribute is a location or not. It is
333+
/// attached to an attribute by the user if they intend the attribute to be used
334+
/// as a location.
335+
template <typename ConcreteType>
336+
struct IsLocation : public AttributeTrait::TraitBase<ConcreteType, IsLocation> {
337+
};
331338
} // namespace AttributeTrait
332339

333340
} // namespace mlir.

mlir/include/mlir/IR/BuiltinLocationAttributes.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ include "mlir/IR/BuiltinDialect.td"
1818

1919
// Base class for Builtin dialect location attributes.
2020
class Builtin_LocationAttr<string name, list<Trait> traits = []>
21-
: AttrDef<Builtin_Dialect, name, traits, "::mlir::LocationAttr"> {
21+
: AttrDef<Builtin_Dialect, name, traits # [NativeAttrTrait<"IsLocation">],
22+
"::mlir::LocationAttr"> {
2223
let cppClassName = name;
2324
let mnemonic = ?;
2425
}

mlir/lib/AsmParser/LocationParser.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,29 @@ ParseResult Parser::parseNameOrFileLineColLocation(LocationAttr &loc) {
153153
return success();
154154
}
155155

156+
ParseResult Parser::parseDialectLocation(LocationAttr &loc) {
157+
consumeToken(Token::bare_identifier);
158+
159+
if (parseToken(Token::less,
160+
"expected `<` to start dialect location attribute"))
161+
return failure();
162+
163+
Attribute locAttr = parseAttribute(Type{});
164+
// No attribute parsed, someone else has returned an error already.
165+
if (!locAttr)
166+
return failure();
167+
168+
loc = llvm::dyn_cast<LocationAttr>(locAttr);
169+
if (!loc)
170+
return emitError() << "expected a LocationAttr subclass";
171+
172+
if (parseToken(Token::greater,
173+
"expected `>` to end dialect location attribute"))
174+
return failure();
175+
176+
return success();
177+
}
178+
156179
ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
157180
// Handle aliases.
158181
if (getToken().is(Token::hash_identifier)) {
@@ -187,5 +210,8 @@ ParseResult Parser::parseLocationInstance(LocationAttr &loc) {
187210
return success();
188211
}
189212

213+
if (getToken().getSpelling() == "dialect")
214+
return parseDialectLocation(loc);
215+
190216
return emitWrongTokenError("expected location instance");
191217
}

mlir/lib/AsmParser/Parser.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,9 @@ class Parser {
303303
/// Parse a name or FileLineCol location instance.
304304
ParseResult parseNameOrFileLineColLocation(LocationAttr &loc);
305305

306+
/// Parse a dialect-specific location.
307+
ParseResult parseDialectLocation(LocationAttr &loc);
308+
306309
//===--------------------------------------------------------------------===//
307310
// Affine Parsing
308311
//===--------------------------------------------------------------------===//

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2061,6 +2061,12 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
20612061
[&](Location loc) { printLocationInternal(loc, pretty); },
20622062
[&]() { os << ", "; });
20632063
os << ']';
2064+
})
2065+
.Default([&](LocationAttr loc) {
2066+
// Assumes that this is a dialect-specific attribute.
2067+
os << "dialect<";
2068+
printAttribute(loc);
2069+
os << ">";
20642070
});
20652071
}
20662072

mlir/lib/IR/Location.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
6464

6565
/// Methods for support type inquiry through isa, cast, and dyn_cast.
6666
bool LocationAttr::classof(Attribute attr) {
67-
return llvm::isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
68-
UnknownLoc>(attr);
67+
return attr.hasTrait<AttributeTrait::IsLocation>();
6968
}
7069

7170
//===----------------------------------------------------------------------===//

mlir/test/IR/locations.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,10 @@ func.func @optional_location_specifier() {
8989
test.attr_with_loc("foo" loc("foo_loc"))
9090
return
9191
}
92+
93+
// CHECK-LABEL: @dialect_location
94+
// CHECK: test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir" * 32>>))
95+
func.func @dialect_location() {
96+
test.attr_with_loc("dialectLoc" loc(dialect<#test.custom_location<"foo.mlir"*32>>))
97+
return
98+
}

mlir/test/IR/pretty-locations.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ func.func @inline_notation() -> i32 {
2424
affine.if #set0(%2) {
2525
} loc(fused<"myPass">["foo", "foo2"])
2626

27+
// CHECK: "foo.op"() : () -> () dialect<#test.custom_location<"foo.mlir" * 1234>>
28+
"foo.op"() : () -> () loc(dialect<#test.custom_location<"foo.mlir" * 1234>>)
29+
2730
// CHECK: return %0 : i32 [unknown]
2831
return %1 : i32 loc(unknown)
2932
}

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,4 +377,15 @@ def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
377377
}
378378

379379

380+
// Test custom location handling.
381+
def TestCustomLocationAttr
382+
: Test_Attr<"TestCustomLocation", [NativeAttrTrait<"IsLocation">]> {
383+
let mnemonic = "custom_location";
384+
let parameters = (ins "mlir::StringAttr":$file, "unsigned":$line);
385+
386+
// Choose a silly separator token so we know it's hitting this code path
387+
// and not another.
388+
let assemblyFormat = "`<` $file `*` $line `>`";
389+
}
390+
380391
#endif // TEST_ATTRDEFS

0 commit comments

Comments
 (0)