Skip to content

[mlir] Add the ability to define dialect-specific location attrs. #105584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlir/include/mlir/IR/AttrTypeBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ class AttrDef<Dialect dialect, string name, list<Trait> traits = [],
let predicate = CPred<"::llvm::isa<" # cppType # ">($_self)">;
}

// Provide a LocationAttrDef for dialects to provide their own locations
// that subclass LocationAttr.
class LocationAttrDef<Dialect dialect, string name, list<Trait> traits = []>
: AttrDef<dialect, name, traits # [NativeAttrTrait<"IsLocation">],
"::mlir::LocationAttr">;

// Define a new type, named `name`, belonging to `dialect` that inherits from
// the given C++ base class.
class TypeDef<Dialect dialect, string name, list<Trait> traits = [],
Expand Down
13 changes: 10 additions & 3 deletions mlir/include/mlir/IR/Attributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,19 @@ class AttributeInterface
// Core AttributeTrait
//===----------------------------------------------------------------------===//

/// This trait is used to determine if an attribute is mutable or not. It is
/// attached on an attribute if the corresponding ImplType defines a `mutate`
/// function with proper signature.
namespace AttributeTrait {
/// This trait is used to determine if an attribute is mutable or not. It is
/// attached on an attribute if the corresponding ConcreteType defines a
/// `mutate` function with proper signature.
template <typename ConcreteType>
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;

/// This trait is used to determine if an attribute is a location or not. It is
/// attached to an attribute by the user if they intend the attribute to be used
/// as a location.
template <typename ConcreteType>
struct IsLocation : public AttributeTrait::TraitBase<ConcreteType, IsLocation> {
};
} // namespace AttributeTrait

} // namespace mlir.
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/IR/BuiltinLocationAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ include "mlir/IR/BuiltinDialect.td"

// Base class for Builtin dialect location attributes.
class Builtin_LocationAttr<string name, list<Trait> traits = []>
: AttrDef<Builtin_Dialect, name, traits, "::mlir::LocationAttr"> {
: LocationAttrDef<Builtin_Dialect, name, traits> {
let cppClassName = name;
let mnemonic = ?;
}
Expand Down
5 changes: 4 additions & 1 deletion mlir/include/mlir/IR/Location.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ class LocationAttr : public Attribute {
public:
using Attribute::Attribute;

/// Walk all of the locations nested under, and including, the current.
/// Walk all of the locations nested directly under, and including, the
/// current. This means that if a location is nested under a non-location
/// attribute, it will *not* be walked by this method. This walk is performed
/// in pre-order to get this behavior.
WalkResult walk(function_ref<WalkResult(Location)> walkFn);

/// Return an instance of the given location type if one is nested under the
Expand Down
26 changes: 14 additions & 12 deletions mlir/lib/AsmParser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,8 @@ class OperationParser : public Parser {

/// Parse a location alias, that is a sequence looking like: #loc42
/// The alias may have already be defined or may be defined later, in which
/// case an OpaqueLoc is used a placeholder.
/// case an OpaqueLoc is used a placeholder. The caller must ensure that the
/// token is actually an alias, which means it must not contain a dot.
ParseResult parseLocationAlias(LocationAttr &loc);

/// This is the structure of a result specifier in the assembly syntax,
Expand Down Expand Up @@ -1917,9 +1918,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {

Token tok = parser.getToken();

// Check to see if we are parsing a location alias.
// Otherwise, we parse the location directly.
if (tok.is(Token::hash_identifier)) {
// Check to see if we are parsing a location alias. We are parsing a
// location alias if the token is a hash identifier *without* a dot in it -
// the dot signifies a dialect attribute. Otherwise, we parse the location
// directly.
if (tok.is(Token::hash_identifier) && !tok.getSpelling().contains('.')) {
if (parser.parseLocationAlias(directLoc))
return failure();
} else if (parser.parseLocationInstance(directLoc)) {
Expand Down Expand Up @@ -2086,11 +2089,9 @@ ParseResult OperationParser::parseLocationAlias(LocationAttr &loc) {
Token tok = getToken();
consumeToken(Token::hash_identifier);
StringRef identifier = tok.getSpelling().drop_front();
if (identifier.contains('.')) {
return emitError(tok.getLoc())
<< "expected location, but found dialect attribute: '#" << identifier
<< "'";
}
assert(!identifier.contains('.') &&
"unexpected dialect attribute token, expected alias");

if (state.asmState)
state.asmState->addAttrAliasUses(identifier, tok.getLocRange());

Expand Down Expand Up @@ -2120,10 +2121,11 @@ OperationParser::parseTrailingLocationSpecifier(OpOrArgument opOrArgument) {
return failure();
Token tok = getToken();

// Check to see if we are parsing a location alias.
// Otherwise, we parse the location directly.
// Check to see if we are parsing a location alias. We are parsing a location
// alias if the token is a hash identifier *without* a dot in it - the dot
// signifies a dialect attribute. Otherwise, we parse the location directly.
LocationAttr directLoc;
if (tok.is(Token::hash_identifier)) {
if (tok.is(Token::hash_identifier) && !tok.getSpelling().contains('.')) {
if (parseLocationAlias(directLoc))
return failure();
} else if (parseLocationInstance(directLoc)) {
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2061,6 +2061,11 @@ void AsmPrinter::Impl::printLocationInternal(LocationAttr loc, bool pretty,
[&](Location loc) { printLocationInternal(loc, pretty); },
[&]() { os << ", "; });
os << ']';
})
.Default([&](LocationAttr loc) {
// Assumes that this is a dialect-specific attribute and prints it
// directly.
printAttribute(loc);
});
}

Expand Down
34 changes: 10 additions & 24 deletions mlir/lib/IR/Location.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,34 +38,20 @@ void BuiltinDialect::registerLocationAttributes() {
//===----------------------------------------------------------------------===//

WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
if (walkFn(*this).wasInterrupted())
return WalkResult::interrupt();

return TypeSwitch<LocationAttr, WalkResult>(*this)
.Case([&](CallSiteLoc callLoc) -> WalkResult {
if (callLoc.getCallee()->walk(walkFn).wasInterrupted())
return WalkResult::interrupt();
return callLoc.getCaller()->walk(walkFn);
})
.Case([&](FusedLoc fusedLoc) -> WalkResult {
for (Location subLoc : fusedLoc.getLocations())
if (subLoc->walk(walkFn).wasInterrupted())
return WalkResult::interrupt();
return WalkResult::advance();
})
.Case([&](NameLoc nameLoc) -> WalkResult {
return nameLoc.getChildLoc()->walk(walkFn);
})
.Case([&](OpaqueLoc opaqueLoc) -> WalkResult {
return opaqueLoc.getFallbackLocation()->walk(walkFn);
})
.Default(WalkResult::advance());
AttrTypeWalker walker;
// Walk locations, but skip any other attribute.
walker.addWalk([&](Attribute attr) {
if (auto loc = llvm::dyn_cast<LocationAttr>(attr))
return walkFn(loc);

return WalkResult::skip();
});
return walker.walk<WalkOrder::PreOrder>(*this);
}

/// Methods for support type inquiry through isa, cast, and dyn_cast.
bool LocationAttr::classof(Attribute attr) {
return llvm::isa<CallSiteLoc, FileLineColLoc, FusedLoc, NameLoc, OpaqueLoc,
UnknownLoc>(attr);
return attr.hasTrait<AttributeTrait::IsLocation>();
}

//===----------------------------------------------------------------------===//
Expand Down
7 changes: 0 additions & 7 deletions mlir/test/IR/invalid-locations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,6 @@ func.func @location_fused_missing_r_square() {

// -----

func.func @location_invalid_alias() {
// expected-error@+1 {{expected location, but found dialect attribute: '#foo.loc'}}
return loc(#foo.loc)
}

// -----

func.func @location_invalid_alias() {
// expected-error@+1 {{operation location alias was never defined}}
return loc(#invalid_alias)
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/IR/locations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,10 @@ func.func @optional_location_specifier() {
test.attr_with_loc("foo" loc("foo_loc"))
return
}

// CHECK-LABEL: @dialect_location
// CHECK: test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir" * 32>))
func.func @dialect_location() {
test.attr_with_loc("dialectLoc" loc(#test.custom_location<"foo.mlir"*32>))
return
}
3 changes: 3 additions & 0 deletions mlir/test/IR/pretty-locations.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ func.func @inline_notation() -> i32 {
affine.if #set0(%2) {
} loc(fused<"myPass">["foo", "foo2"])

// CHECK: "foo.op"() : () -> () #test.custom_location<"foo.mlir" * 1234>
"foo.op"() : () -> () loc(#test.custom_location<"foo.mlir" * 1234>)

// CHECK: return %0 : i32 [unknown]
return %1 : i32 loc(unknown)
}
12 changes: 12 additions & 0 deletions mlir/test/lib/Dialect/Test/TestAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ include "mlir/IR/OpAsmInterface.td"
class Test_Attr<string name, list<Trait> traits = []>
: AttrDef<Test_Dialect, name, traits>;

class Test_LocAttr<string name> : LocationAttrDef<Test_Dialect, name, []>;

def SimpleAttrA : Test_Attr<"SimpleA"> {
let mnemonic = "smpla";
}
Expand Down Expand Up @@ -377,4 +379,14 @@ def NestedPolynomialAttr2 : Test_Attr<"NestedPolynomialAttr2"> {
}


// Test custom location handling.
def TestCustomLocationAttr : Test_LocAttr<"TestCustomLocation"> {
let mnemonic = "custom_location";
let parameters = (ins "mlir::StringAttr":$file, "unsigned":$line);

// Choose a silly separator token so we know it's hitting this code path
// and not another.
let assemblyFormat = "`<` $file `*` $line `>`";
}

#endif // TEST_ATTRDEFS
1 change: 1 addition & 0 deletions mlir/unittests/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_unittest(MLIRIRTests
InterfaceTest.cpp
IRMapping.cpp
InterfaceAttachmentTest.cpp
LocationTest.cpp
OperationSupportTest.cpp
PatternMatchTest.cpp
ShapedTypeTest.cpp
Expand Down
53 changes: 53 additions & 0 deletions mlir/unittests/IR/LocationTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//===- LocationTest.cpp - unit tests for affine map API -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Location.h"
#include "mlir/IR/Builders.h"
#include "gtest/gtest.h"

using namespace mlir;

// Check that we only walk *locations* and not non-location attributes.
TEST(LocationTest, Walk) {
MLIRContext ctx;
Builder builder(&ctx);
BoolAttr trueAttr = builder.getBoolAttr(true);

Location loc1 = FileLineColLoc::get(builder.getStringAttr("foo"), 1, 2);
Location loc2 = FileLineColLoc::get(builder.getStringAttr("foo"), 3, 4);
Location fused = builder.getFusedLoc({loc1, loc2}, trueAttr);

SmallVector<Attribute> visited;
fused->walk([&](Location l) {
visited.push_back(LocationAttr(l));
return WalkResult::advance();
});

EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({fused, loc1, loc2}));
}

// Check that we skip location attrs nested under a non-location attr.
TEST(LocationTest, SkipNested) {
MLIRContext ctx;
Builder builder(&ctx);

Location loc1 = FileLineColLoc::get(builder.getStringAttr("foo"), 1, 2);
Location loc2 = FileLineColLoc::get(builder.getStringAttr("foo"), 3, 4);
Location loc3 = FileLineColLoc::get(builder.getStringAttr("bar"), 1, 2);
Location loc4 = FileLineColLoc::get(builder.getStringAttr("bar"), 3, 4);
ArrayAttr arr = builder.getArrayAttr({loc3, loc4});
Location fused = builder.getFusedLoc({loc1, loc2}, arr);

SmallVector<Attribute> visited;
fused->walk([&](Location l) {
visited.push_back(LocationAttr(l));
return WalkResult::advance();
});

EXPECT_EQ(llvm::ArrayRef(visited), ArrayRef<Attribute>({fused, loc1, loc2}));
}
Loading