Skip to content

Commit 759a7b5

Browse files
authored
[mlir] Add the ability to define dialect-specific location attrs. (#105584)
This patch adds the capability to define dialect-specific location attrs. This is useful in particular for defining location structure that doesn't necessarily fit within the core MLIR location hierarchy, but doesn't make sense to push upstream (i.e. a custom use case). This patch adds an AttributeTrait, `IsLocation`, which is tagged onto all the builtin location attrs, as well as the test location attribute. This is necessary because previously LocationAttr::classof only returned true if the attribute was one of the builtin location attributes, and well, the point of this patch is to allow dialects to define their own location attributes. There was an alternate implementation I considered wherein LocationAttr becomes an AttrInterface, but that was discarded because there are likely to be *many* locations in a single program, and I was concerned that forcing every MLIR user to pay the cost of the additional lookup/dispatch was unacceptable. It also would have been a *much* more invasive change. It would have allowed for more flexibility in terms of pretty printing, but it's unclear how useful/necessary that flexibility would be given how much customizability there already is for attribute definitions.
1 parent c098435 commit 759a7b5

File tree

13 files changed

+126
-48
lines changed

13 files changed

+126
-48
lines changed

mlir/include/mlir/IR/AttrTypeBase.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,12 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
281281
let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
282282
}
283283

284+
// Provide a LocationAttrDef for dialects to provide their own locations
285+
// that subclass LocationAttr.
286+
class LocationAttrDef<Dialect dialect, string name, list<Trait> traits = []>
287+
: AttrDef<dialect, name, traits # [NativeAttrTrait<"IsLocation">],
288+
"::mlir::LocationAttr">;
289+
284290
// Define a new type, named `name`, belonging to `dialect` that inherits from
285291
// the given C++ base class.
286292
class TypeDef<Dialect dialect, string name, list<Trait> traits = [],

mlir/include/mlir/IR/Attributes.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,19 @@ class AttributeInterface
322322
// Core AttributeTrait
323323
//===----------------------------------------------------------------------===//
324324

325-
/// This trait is used to determine if an attribute is mutable or not. It is
326-
/// attached on an attribute if the corresponding ImplType defines a `mutate`
327-
/// function with proper signature.
328325
namespace AttributeTrait {
326+
/// This trait is used to determine if an attribute is mutable or not. It is
327+
/// attached on an attribute if the corresponding ConcreteType defines a
328+
/// `mutate` function with proper signature.
329329
template <typename ConcreteType>
330330
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
331+
332+
/// This trait is used to determine if an attribute is a location or not. It is
333+
/// attached to an attribute by the user if they intend the attribute to be used
334+
/// as a location.
335+
template <typename ConcreteType>
336+
struct IsLocation : public AttributeTrait::TraitBase<ConcreteType, IsLocation> {
337+
};
331338
} // namespace AttributeTrait
332339

333340
} // namespace mlir.

mlir/include/mlir/IR/BuiltinLocationAttributes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ include "mlir/IR/BuiltinDialect.td"
1818

1919
// Base class for Builtin dialect location attributes.
2020
class Builtin_LocationAttr<string name, list<Trait> traits = []>
21-
: AttrDef<Builtin_Dialect, name, traits, "::mlir::LocationAttr"> {
21+
: LocationAttrDef<Builtin_Dialect, name, traits> {
2222
let cppClassName = name;
2323
let mnemonic = ?;
2424
}

mlir/include/mlir/IR/Location.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ class LocationAttr : public Attribute {
3232
public:
3333
using Attribute::Attribute;
3434

35-
/// Walk all of the locations nested under, and including, the current.
35+
/// Walk all of the locations nested directly under, and including, the
36+
/// current. This means that if a location is nested under a non-location
37+
/// attribute, it will *not* be walked by this method. This walk is performed
38+
/// in pre-order to get this behavior.
3639
WalkResult walk(function_ref<WalkResult(Location)> walkFn);
3740

3841
/// Return an instance of the given location type if one is nested under the

mlir/lib/AsmParser/Parser.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,8 @@ class OperationParser : public Parser {
631631

632632
/// Parse a location alias, that is a sequence looking like: #loc42
633633
/// The alias may have already be defined or may be defined later, in which
634-
/// case an OpaqueLoc is used a placeholder.
634+
/// case an OpaqueLoc is used a placeholder. The caller must ensure that the
635+
/// token is actually an alias, which means it must not contain a dot.
635636
ParseResult parseLocationAlias(LocationAttr &loc);
636637

637638
/// This is the structure of a result specifier in the assembly syntax,
@@ -1917,9 +1918,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
19171918

19181919
Token tok = parser.getToken();
19191920

1920-
// Check to see if we are parsing a location alias.
1921-
// Otherwise, we parse the location directly.
1922-
if (tok.is(Token::hash_identifier)) {
1921+
// Check to see if we are parsing a location alias. We are parsing a
1922+
// location alias if the token is a hash identifier *without* a dot in it -
1923+
// the dot signifies a dialect attribute. Otherwise, we parse the location
1924+
// directly.
1925+
if (tok.is(Token::hash_identifier) && !tok.getSpelling().contains('.')) {
19231926
if (parser.parseLocationAlias(directLoc))
19241927
return failure();
19251928
} else if (parser.parseLocationInstance(directLoc)) {
@@ -2086,11 +2089,9 @@ ParseResult OperationParser::parseLocationAlias(LocationAttr &loc) {
20862089
Token tok = getToken();
20872090
consumeToken(Token::hash_identifier);
20882091
StringRef identifier = tok.getSpelling().drop_front();
2089-
if (identifier.contains('.')) {
2090-
return emitError(tok.getLoc())
2091-
<< "expected location, but found dialect attribute: '#" << identifier
2092-
<< "'";
2093-
}
2092+
assert(!identifier.contains('.') &&
2093+
"unexpected dialect attribute token, expected alias");
2094+
20942095
if (state.asmState)
20952096
state.asmState->addAttrAliasUses(identifier, tok.getLocRange());
20962097

@@ -2120,10 +2121,11 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
21202121
return failure();
21212122
Token tok = getToken();
21222123

2123-
// Check to see if we are parsing a location alias.
2124-
// Otherwise, we parse the location directly.
2124+
// Check to see if we are parsing a location alias. We are parsing a location
2125+
// alias if the token is a hash identifier *without* a dot in it - the dot
2126+
// signifies a dialect attribute. Otherwise, we parse the location directly.
21252127
LocationAttr directLoc;
2126-
if (tok.is(Token::hash_identifier)) {
2128+
if (tok.is(Token::hash_identifier) && !tok.getSpelling().contains('.')) {
21272129
if (parseLocationAlias(directLoc))
21282130
return failure();
21292131
} else if (parseLocationInstance(directLoc)) {

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2064,6 +2064,11 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
20642064
[&](Location loc) { printLocationInternal(loc, pretty); },
20652065
[&]() { os << ", "; });
20662066
os << ']';
2067+
})
2068+
.Default([&](LocationAttr loc) {
2069+
// Assumes that this is a dialect-specific attribute and prints it
2070+
// directly.
2071+
printAttribute(loc);
20672072
});
20682073
}
20692074

mlir/lib/IR/Location.cpp

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,34 +38,20 @@ void BuiltinDialect::registerLocationAttributes() {
3838
//===----------------------------------------------------------------------===//
3939

4040
WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
41-
if (walkFn(*this).wasInterrupted())
42-
return WalkResult::interrupt();
43-
44-
return TypeSwitch<LocationAttr, WalkResult>(*this)
45-
.Case([&](CallSiteLoc callLoc) -> WalkResult {
46-
if (callLoc.getCallee()->walk(walkFn).wasInterrupted())
47-
return WalkResult::interrupt();
48-
return callLoc.getCaller()->walk(walkFn);
49-
})
50-
.Case([&](FusedLoc fusedLoc) -> WalkResult {
51-
for (Location subLoc : fusedLoc.getLocations())
52-
if (subLoc->walk(walkFn).wasInterrupted())
53-
return WalkResult::interrupt();
54-
return WalkResult::advance();
55-
})
56-
.Case([&](NameLoc nameLoc) -> WalkResult {
57-
return nameLoc.getChildLoc()->walk(walkFn);
58-
})
59-
.Case([&](OpaqueLoc opaqueLoc) -> WalkResult {
60-
return opaqueLoc.getFallbackLocation()->walk(walkFn);
61-
})
62-
.Default(WalkResult::advance());
41+
AttrTypeWalker walker;
42+
// Walk locations, but skip any other attribute.
43+
walker.addWalk([&](Attribute attr) {
44+
if (auto loc = llvm::dyn_cast<LocationAttr>(attr))
45+
return walkFn(loc);
46+
47+
return WalkResult::skip();
48+
});
49+
return walker.walk<WalkOrder::PreOrder>(*this);
6350
}
6451

6552
/// Methods for support type inquiry through isa, cast, and dyn_cast.
6653
bool LocationAttr::classof(Attribute attr) {
67-
return llvm::isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
68-
UnknownLoc>(attr);
54+
return attr.hasTrait<AttributeTrait::IsLocation>();
6955
}
7056

7157
//===----------------------------------------------------------------------===//

mlir/test/IR/invalid-locations.mlir

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,6 @@ func.func @location_fused_missing_r_square() {
9494

9595
// -----
9696

97-
func.func @location_invalid_alias() {
98-
// expected-error@+1 {{expected location, but found dialect attribute: '#foo.loc'}}
99-
return loc(#foo.loc)
100-
}
101-
102-
// -----
103-
10497
func.func @location_invalid_alias() {
10598
// expected-error@+1 {{operation location alias was never defined}}
10699
return loc(#invalid_alias)

mlir/test/IR/locations.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,10 @@ func.func @optional_location_specifier() {
8989
test.attr_with_loc("foo" loc("foo_loc"))
9090
return
9191
}
92+
93+
// CHECK-LABEL: @dialect_location
94+
// CHECK: test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir" * 32>))
95+
func.func @dialect_location() {
96+
test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir"*32>))
97+
return
98+
}

mlir/test/IR/pretty-locations.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ func.func @inline_notation() -> i32 {
2424
affine.if #set0(%2) {
2525
} loc(fused<"myPass">["foo", "foo2"])
2626

27+
// CHECK: "foo.op"() : () -> () #test.custom_location<"foo.mlir" * 1234>
28+
"foo.op"() : () -> () loc(#test.custom_location<"foo.mlir" * 1234>)
29+
2730
// CHECK: return %0 : i32 [unknown]
2831
return %1 : i32 loc(unknown)
2932
}

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ include "mlir/IR/OpAsmInterface.td"
2727
class Test_Attr<string name, list<Trait> traits = []>
2828
: AttrDef<Test_Dialect, name, traits>;
2929

30+
class Test_LocAttr<string name> : LocationAttrDef<Test_Dialect, name, []>;
31+
3032
def SimpleAttrA : Test_Attr<"SimpleA"> {
3133
let mnemonic = "smpla";
3234
}
@@ -377,4 +379,14 @@ def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
377379
}
378380

379381

382+
// Test custom location handling.
383+
def TestCustomLocationAttr : Test_LocAttr<"TestCustomLocation"> {
384+
let mnemonic = "custom_location";
385+
let parameters = (ins "mlir::StringAttr":$file, "unsigned":$line);
386+
387+
// Choose a silly separator token so we know it's hitting this code path
388+
// and not another.
389+
let assemblyFormat = "`<` $file `*` $line `>`";
390+
}
391+
380392
#endif // TEST_ATTRDEFS

mlir/unittests/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_unittest(MLIRIRTests
88
InterfaceTest.cpp
99
IRMapping.cpp
1010
InterfaceAttachmentTest.cpp
11+
LocationTest.cpp
1112
OperationSupportTest.cpp
1213
PatternMatchTest.cpp
1314
ShapedTypeTest.cpp

mlir/unittests/IR/LocationTest.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//===- LocationTest.cpp - unit tests for affine map API -------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/IR/Location.h"
10+
#include "mlir/IR/Builders.h"
11+
#include "gtest/gtest.h"
12+
13+
using namespace mlir;
14+
15+
// Check that we only walk *locations* and not non-location attributes.
16+
TEST(LocationTest, Walk) {
17+
MLIRContext ctx;
18+
Builder builder(&ctx);
19+
BoolAttr trueAttr = builder.getBoolAttr(true);
20+
21+
Location loc1 = FileLineColLoc::get(builder.getStringAttr("foo"), 1, 2);
22+
Location loc2 = FileLineColLoc::get(builder.getStringAttr("foo"), 3, 4);
23+
Location fused = builder.getFusedLoc({loc1, loc2}, trueAttr);
24+
25+
SmallVector<Attribute> visited;
26+
fused->walk([&](Location l) {
27+
visited.push_back(LocationAttr(l));
28+
return WalkResult::advance();
29+
});
30+
31+
EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({fused, loc1, loc2}));
32+
}
33+
34+
// Check that we skip location attrs nested under a non-location attr.
35+
TEST(LocationTest, SkipNested) {
36+
MLIRContext ctx;
37+
Builder builder(&ctx);
38+
39+
Location loc1 = FileLineColLoc::get(builder.getStringAttr("foo"), 1, 2);
40+
Location loc2 = FileLineColLoc::get(builder.getStringAttr("foo"), 3, 4);
41+
Location loc3 = FileLineColLoc::get(builder.getStringAttr("bar"), 1, 2);
42+
Location loc4 = FileLineColLoc::get(builder.getStringAttr("bar"), 3, 4);
43+
ArrayAttr arr = builder.getArrayAttr({loc3, loc4});
44+
Location fused = builder.getFusedLoc({loc1, loc2}, arr);
45+
46+
SmallVector<Attribute> visited;
47+
fused->walk([&](Location l) {
48+
visited.push_back(LocationAttr(l));
49+
return WalkResult::advance();
50+
});
51+
52+
EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({fused, loc1, loc2}));
53+
}

0 commit comments

Comments
 (0)