Skip to content

Commit 4086072

Browse files
committed
Reland "[mlir][linalg] Support parsing attributes in named op spec"
With this, we can now specify a list of attributes on named ops generated from the spec. The format is defined as:

```
attr-id ::= bare-id (`?`)?
attr-typedef ::= type (`[` `]`)?
attr-def ::= attr-id `:` attr-typedef
tc-attr-def ::= `attr` `(` attr-def-list `)`
tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)` (tc-attr-def)?
```

For example,

```
ods_def<SomeCppOp>
def some_op(...) -> (...)
attr(
  f32_attr: f32,
  i32_attr: i32,
  array_attr : f32[],
  optional_attr? : f32
)
```

where `?` means optional attribute and `[]` means array type.

Reviewed By: hanchung, nicolasvasilache

Differential Revision: https://reviews.llvm.org/D94240
1 parent 3f7b4ce commit 4086072

File tree

2 files changed

+199
-4
lines changed

2 files changed

+199
-4
lines changed

mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,25 @@ ods_def<Test3Op> :
7272
def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) {
7373
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
7474
}
75+
76+
// Test attribute definitions
77+
// ODS-LABEL: def Test4Op
78+
// ODS: F32ArrayAttr:$array_attr,
79+
// ODS: F32:$f32_attr,
80+
// ODS: RankedF32ElementsAttr<[4]>:$fvec_attr,
81+
// ODS: I32:$i32_attr,
82+
// ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr,
83+
// ODS: OptionalAttr<F32>:$optional_attr
84+
//
85+
ods_def<Test4Op> :
86+
def test4(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N))
87+
attr(
88+
f32_attr: f32,
89+
i32_attr: i32,
90+
fvec_attr: 4xf32,
91+
ivec_attr: 5x6xi32,
92+
array_attr : f32[],
93+
optional_attr? : f32
94+
) {
95+
C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n)));
96+
}

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

Lines changed: 177 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,17 @@
2020
#include "mlir/Support/LLVM.h"
2121
#include "mlir/Support/LogicalResult.h"
2222
#include "llvm/ADT/SetVector.h"
23+
#include "llvm/ADT/SmallVector.h"
24+
#include "llvm/ADT/StringRef.h"
25+
#include "llvm/ADT/StringSwitch.h"
26+
#include "llvm/ADT/Twine.h"
2327
#include "llvm/Support/Casting.h"
2428
#include "llvm/Support/CommandLine.h"
2529
#include "llvm/Support/FormatVariadic.h"
2630
#include "llvm/Support/ToolOutputFile.h"
2731

32+
#include <map>
33+
2834
#define DEBUG_TYPE "linalg-ods-gen"
2935

3036
static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen");
@@ -79,18 +85,22 @@ class Token {
7985
gt,
8086
l_brace,
8187
l_paren,
88+
l_square,
8289
lt,
8390
minus,
8491
plus,
92+
question,
8593
r_brace,
8694
r_paren,
95+
r_square,
8796
semicolon,
8897
star,
8998

9099
// Keywords.
91100
kw_def,
92101
FIRST_KEYWORD = kw_def,
93102
kw_ods_def,
103+
kw_attr_def,
94104
kw_floordiv,
95105
kw_ceildiv,
96106
kw_mod,
@@ -151,6 +161,10 @@ class Lexer {
151161
Token emitError(llvm::SMLoc loc, const Twine &msg);
152162
Token emitError(const char *loc, const Twine &msg);
153163

164+
/// Change the position of the lexer cursor. The next token we lex will start
165+
/// at the designated point in the input.
166+
void resetPointer(const char *newPtr) { curPtr = newPtr; }
167+
154168
private:
155169
Token formToken(Token::Kind kind, const char *tokStart) {
156170
return Token(kind, StringRef(tokStart, curPtr - tokStart));
@@ -247,10 +261,14 @@ Token Lexer::lexToken() {
247261
return formToken(Token::Kind::l_brace, tokStart);
248262
case '(':
249263
return formToken(Token::Kind::l_paren, tokStart);
264+
case '[':
265+
return formToken(Token::Kind::l_square, tokStart);
250266
case '}':
251267
return formToken(Token::Kind::r_brace, tokStart);
252268
case ')':
253269
return formToken(Token::Kind::r_paren, tokStart);
270+
case ']':
271+
return formToken(Token::Kind::r_square, tokStart);
254272
case '<':
255273
return formToken(Token::Kind::lt, tokStart);
256274
case '>':
@@ -263,6 +281,8 @@ Token Lexer::lexToken() {
263281
return formToken(Token::Kind::semicolon, tokStart);
264282
case '*':
265283
return formToken(Token::Kind::star, tokStart);
284+
case '?':
285+
return formToken(Token::Kind::question, tokStart);
266286
case '/':
267287
if (*curPtr == '/') {
268288
skipComment();
@@ -289,6 +309,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
289309
// Check to see if this identifier is a keyword.
290310
StringRef str(tokStart, curPtr - tokStart);
291311
Token::Kind kind = StringSwitch<Token::Kind>(str)
312+
.Case("attr", Token::Kind::kw_attr_def)
292313
.Case("def", Token::Kind::kw_def)
293314
.Case("ods_def", Token::Kind::kw_ods_def)
294315
.Case("floordiv", Token::Kind::kw_floordiv)
@@ -352,29 +373,40 @@ class Parser {
352373
"shouldn't advance past EOF or errors");
353374
curToken = lexer.lexToken();
354375
}
376+
355377
/// Advances to the next token produced by the lexer. The current token is
/// expected to be of `kind`; this is only checked by an assertion.
void consumeToken(Token::Kind kind) {
356378
assert(curToken.getKind() == kind && "unexpected token");
357379
curToken = lexer.lexToken();
358380
}
381+
359382
/// Consumes the current token if it is of `kind` and returns success.
/// Otherwise emits `msg` as an error at the current token's location and
/// returns failure without consuming anything.
LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
360383
if (curToken.getKind() != kind)
361384
return emitError(curToken.getLoc(), msg);
362385
consumeToken();
363386
return success();
364387
}
388+
389+
/// Consumes the current token if it is of `kind`. Returns success when the
/// token was consumed and failure otherwise. No diagnostic is emitted, so a
/// failure only means that the optional token was absent.
390+
LogicalResult parseOptionalToken(Token::Kind kind) {
391+
return success(consumeIf(kind));
392+
}
393+
365394
/// Reports `msg` at `loc` through the lexer and returns failure, so that
/// callers can write `return emitError(loc, msg);`.
LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
366395
lexer.emitError(loc, msg);
367396
return failure();
368397
}
398+
369399
/// Reports `msg` at the location of the current token and returns failure.
LogicalResult emitError(const Twine &msg) {
370400
return emitError(curToken.getLoc(), msg);
371401
}
402+
372403
/// Consumes the current token and returns true if it is of `kind`; otherwise
/// leaves the current token in place and returns false.
bool consumeIf(Token::Kind kind) {
373404
if (curToken.isNot(kind))
374405
return false;
375406
consumeToken(kind);
376407
return true;
377408
}
409+
378410
LogicalResult
379411
parseCommaSeparatedList(llvm::function_ref<ParseResult()> parseElement) {
380412
// Non-empty case starts with an element.
@@ -388,6 +420,7 @@ class Parser {
388420
}
389421
return success();
390422
}
423+
391424
LogicalResult
392425
parseCommaSeparatedListUntil(Token::Kind rightToken,
393426
llvm::function_ref<ParseResult()> parseElement,
@@ -961,6 +994,8 @@ class TCParser {
961994
LogicalResult parseTensorUse(TensorUse &result,
962995
ComprehensionParsingState &state);
963996

997+
LogicalResult parseAttrDef();
998+
964999
/// Parses a tensor expression.
9651000
LogicalResult parseExpression(TensorUse currentDefinition,
9661001
std::unique_ptr<Expression> &result,
@@ -1010,15 +1045,29 @@ class TCParser {
10101045
unsigned index;
10111046
};
10121047

1048+
//===--------------------------------------------------------------------===//
1049+
// Internal bookkeeping of attributes.
1050+
//===--------------------------------------------------------------------===//
1051+
/// Description of one attribute declared in the `attr(...)` clause of a TC
/// def, as recorded by `parseAttrDef`.
struct RegisteredAttr {
1052+
/// Spelling of the element type as written in the spec (e.g. `f32`). This is
/// a non-owning view of the type token's spelling.
StringRef elementType;
1053+
/// Sizes from the optional `<int>x<int>x...` prefix of the type; empty when
/// no dimension list was given.
SmallVector<uint64_t, 4> vectorDims;
1054+
/// True when the type is followed by `[` `]`.
bool isArray;
1055+
/// True when the attribute name is followed by `?`.
bool isOptional;
1056+
};
1057+
10131058
//===--------------------------------------------------------------------===//
10141059
// Per-TC def state.
10151060
//===--------------------------------------------------------------------===//
10161061
/// Symbols are per TC def.
10171062
AffineSymbolList symbols;
1063+
10181064
/// Tensors are per TC def.
10191065
llvm::StringMap<RegisteredTensor> registeredTensors;
10201066
unsigned nextRegisteredTensorIndex;
10211067

1068+
/// Attributes are per TC def.
1069+
std::map<std::string, RegisteredAttr> registeredAttrs;
1070+
10221071
Parser &parser;
10231072
};
10241073
} // namespace
@@ -1170,6 +1219,73 @@ LogicalResult TCParser::parseTensorUse(TensorUse &result,
11701219
return success();
11711220
}
11721221

1222+
/// Parse the information for an attribute def of the form:
1223+
///
1224+
/// affine-expr-list ::= affine-expr (`,` affine-expr )*
1225+
/// attr-id ::= bare-id (`?`)?
1226+
/// dim-list ::= (integer-literal 'x')+
1227+
/// attr-typedef ::= dim-list? type (`[` `]`)?
1228+
/// attr-def ::= attr-id `:` attr-typedef
1229+
LogicalResult TCParser::parseAttrDef() {
1230+
auto attrLoc = parser.curToken.getLoc();
1231+
StringRef attrName = parser.curToken.getSpelling();
1232+
if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1233+
return failure();
1234+
bool isOptional = succeeded(parser.parseOptionalToken(Token::Kind::question));
1235+
if (failed(parser.parseToken(Token::Kind::colon, "expected colon")))
1236+
return failure();
1237+
1238+
// Parse the attribute's type. We don't expect the type to be arbitrary
1239+
// complex, so just use this ad-hoc handling here.
1240+
1241+
// Parse potential dimension list
1242+
SmallVector<uint64_t, 4> vectorDims;
1243+
while (parser.curToken.is(Token::Kind::integer)) {
1244+
vectorDims.push_back(parser.curToken.getUInt64IntegerValue().getValue());
1245+
parser.consumeToken();
1246+
1247+
StringRef spelling = parser.curToken.getSpelling();
1248+
if (spelling[0] != 'x')
1249+
return parser.emitError(parser.curToken.getLoc(),
1250+
"expected 'x' in dimension list");
1251+
1252+
// If we had a prefix of 'x', lex the next token immediately after the 'x'.
1253+
if (spelling.size() != 1)
1254+
parser.lexer.resetPointer(spelling.data() + 1);
1255+
1256+
parser.consumeToken();
1257+
}
1258+
1259+
StringRef elementType = parser.curToken.getSpelling();
1260+
if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
1261+
return failure();
1262+
1263+
bool isArray = false;
1264+
auto arrayLoc = parser.curToken.getLoc();
1265+
if (succeeded(parser.parseOptionalToken(Token::Kind::l_square))) {
1266+
isArray = true;
1267+
if (failed(parser.parseToken(Token::Kind::r_square, "expected ']'")))
1268+
return failure();
1269+
}
1270+
1271+
if (!vectorDims.empty() && isArray)
1272+
return parser.emitError(arrayLoc, "unsupported vector array attribute");
1273+
1274+
auto iterBoolPair = registeredAttrs.emplace(
1275+
attrName.str(),
1276+
RegisteredAttr{elementType, vectorDims, isArray, isOptional});
1277+
if (!iterBoolPair.second)
1278+
return parser.emitError(attrLoc,
1279+
"Failed to register attribute '" + attrName + "'");
1280+
1281+
LLVM_DEBUG(llvm::dbgs() << "Recorded: " << (isOptional ? "[optional]" : "")
1282+
<< " " << attrName << " "
1283+
<< "with type: " << elementType
1284+
<< (isArray ? "[]" : "") << "\n");
1285+
1286+
return success();
1287+
}
1288+
11731289
/// Parses a tensor expression of the form:
11741290
///
11751291
/// op-spec ::= bare-id `<` reduction-dims-list `>`
@@ -1341,10 +1457,13 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
13411457
/// Parse and print the information for a ODS def.
13421458
///
13431459
/// tensor-def-list ::= tensor-def (`,` tensor-def )*
1460+
/// attr-def-list ::= attr-def (`,` attr-def )*
13441461
///
13451462
/// comprehension-list ::= comprehension comprehension*
13461463
///
1464+
/// tc-attr-def ::= `attr` `(` attr-def-list `)`
13471465
/// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
1466+
/// (tc-attr-def)?
13481467
/// `{` comprehension-list `}`
13491468
///
13501469
/// ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def
@@ -1353,6 +1472,7 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
13531472
/// contain only expressions involving symbols and constants), but can
13541473
/// otherwise contain arbitrary affine expressions.
13551474
LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
1475+
// Parse def header (including C++ op name)
13561476
if (failed(parser.parseToken(Token::Kind::kw_ods_def,
13571477
"expected 'ods_def' to define a TC ODS")) ||
13581478
failed(parser.parseToken(Token::Kind::lt, "expected '<'")))
@@ -1364,12 +1484,15 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
13641484
failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
13651485
failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
13661486
return failure();
1487+
13671488
if (failed(parser.parseToken(Token::Kind::kw_def,
13681489
"expected 'def' to define a TC")))
13691490
return failure();
13701491

13711492
StringRef tcName = parser.curToken.getSpelling();
13721493
LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n");
1494+
1495+
// Parse input/output tensor definitions
13731496
if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
13741497
failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
13751498
return failure();
@@ -1392,6 +1515,16 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
13921515
Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false)))
13931516
return failure();
13941517

1518+
// Parse optional attribute definitions
1519+
if (succeeded(parser.parseOptionalToken(Token::Kind::kw_attr_def))) {
1520+
if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
1521+
return failure();
1522+
if (failed(parser.parseCommaSeparatedListUntil(
1523+
Token::Kind::r_paren, std::bind(&TCParser::parseAttrDef, this),
1524+
/*allowEmptyList=*/false)))
1525+
return failure();
1526+
}
1527+
13951528
// Since we don't declare symbols separately, we discover them eagerly: each
13961529
// newly encountered id in a tensor shape expression is treated as a new
13971530
// symbolic. At this point, all tensors have been parsed and all the symbols
@@ -1450,12 +1583,52 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
14501583
void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
14511584
StringRef linalgOpName,
14521585
ComprehensionParsingState &state) {
1586+
SmallVector<std::string, 4> attributes;
1587+
for (const auto &attr : registeredAttrs) {
1588+
llvm::StringRef name = attr.first;
1589+
1590+
llvm::StringRef elementType = attr.second.elementType;
1591+
std::string odsType = llvm::StringSwitch<std::string>(elementType)
1592+
.Case("f32", "F32")
1593+
.Case("i32", "I32")
1594+
.Default("");
1595+
if (odsType.empty()) {
1596+
parser.emitError("unimplemented support for attribute element type: " +
1597+
elementType);
1598+
return;
1599+
}
1600+
1601+
const auto &dims = attr.second.vectorDims;
1602+
if (!dims.empty()) {
1603+
SmallVector<std::string, 4> dimStrs;
1604+
for (uint64_t dim : dims)
1605+
dimStrs.push_back(std::to_string(dim));
1606+
odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType,
1607+
llvm::join(dimStrs, ", "));
1608+
}
1609+
1610+
assert(dims.empty() || !attr.second.isArray);
1611+
if (attr.second.isArray)
1612+
odsType = llvm::formatv("{0}ArrayAttr", odsType);
1613+
1614+
if (attr.second.isOptional)
1615+
odsType = llvm::formatv("OptionalAttr<{0}>", odsType);
1616+
1617+
attributes.push_back(llvm::formatv("{0}:${1}", odsType, name));
1618+
}
1619+
1620+
std::string attrList = llvm::join(attributes, ",\n");
1621+
if (!attrList.empty())
1622+
attrList = ",\n" + attrList;
1623+
14531624
const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [
14541625
AttrSizedOperandSegments,
14551626
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
14561627
SingleBlockImplicitTerminator<"YieldOp">]> {
1457-
let arguments = (ins Variadic<AnyShaped>:$inputs,
1458-
Variadic<AnyShaped>:$outputs);
1628+
let arguments = (ins
1629+
Variadic<AnyShaped>:$inputs,
1630+
Variadic<AnyShaped>:$outputs{4}
1631+
);
14591632
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
14601633
let regions = (region AnyRegion:$region);
14611634
@@ -1515,7 +1688,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
15151688
static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
15161689
15171690
// Generic methods.
1518-
static unsigned getNumRegionArgs() {{ return {4}; }
1691+
static unsigned getNumRegionArgs() {{ return {5}; }
15191692
std::string getLibraryCallName() {{
15201693
return generateLibraryCallName(getOperation());
15211694
}
@@ -1531,7 +1704,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
15311704
}
15321705

15331706
os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs,
1534-
state.orderedTensorArgs.size());
1707+
attrList, state.orderedTensorArgs.size());
15351708
}
15361709

15371710
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.

0 commit comments

Comments
 (0)