20
20
#include " mlir/Support/LLVM.h"
21
21
#include " mlir/Support/LogicalResult.h"
22
22
#include " llvm/ADT/SetVector.h"
23
+ #include " llvm/ADT/SmallVector.h"
24
+ #include " llvm/ADT/StringRef.h"
25
+ #include " llvm/ADT/StringSwitch.h"
26
+ #include " llvm/ADT/Twine.h"
23
27
#include " llvm/Support/Casting.h"
24
28
#include " llvm/Support/CommandLine.h"
25
29
#include " llvm/Support/FormatVariadic.h"
26
30
#include " llvm/Support/ToolOutputFile.h"
27
31
32
+ #include < map>
33
+
28
34
#define DEBUG_TYPE " linalg-ods-gen"
29
35
30
36
static llvm::cl::OptionCategory ODSGenCat (" Linalg ODS Gen" );
@@ -79,18 +85,22 @@ class Token {
79
85
gt,
80
86
l_brace,
81
87
l_paren,
88
+ l_square,
82
89
lt,
83
90
minus,
84
91
plus,
92
+ question,
85
93
r_brace,
86
94
r_paren,
95
+ r_square,
87
96
semicolon,
88
97
star,
89
98
90
99
// Keywords.
91
100
kw_def,
92
101
FIRST_KEYWORD = kw_def,
93
102
kw_ods_def,
103
+ kw_attr_def,
94
104
kw_floordiv,
95
105
kw_ceildiv,
96
106
kw_mod,
@@ -151,6 +161,10 @@ class Lexer {
151
161
Token emitError (llvm::SMLoc loc, const Twine &msg);
152
162
Token emitError (const char *loc, const Twine &msg);
153
163
164
+ // / Change the position of the lexer cursor. The next token we lex will start
165
+ // / at the designated point in the input.
166
+ void resetPointer (const char *newPtr) { curPtr = newPtr; }
167
+
154
168
private:
155
169
Token formToken (Token::Kind kind, const char *tokStart) {
156
170
return Token (kind, StringRef (tokStart, curPtr - tokStart));
@@ -247,10 +261,14 @@ Token Lexer::lexToken() {
247
261
return formToken (Token::Kind::l_brace, tokStart);
248
262
case ' (' :
249
263
return formToken (Token::Kind::l_paren, tokStart);
264
+ case ' [' :
265
+ return formToken (Token::Kind::l_square, tokStart);
250
266
case ' }' :
251
267
return formToken (Token::Kind::r_brace, tokStart);
252
268
case ' )' :
253
269
return formToken (Token::Kind::r_paren, tokStart);
270
+ case ' ]' :
271
+ return formToken (Token::Kind::r_square, tokStart);
254
272
case ' <' :
255
273
return formToken (Token::Kind::lt, tokStart);
256
274
case ' >' :
@@ -263,6 +281,8 @@ Token Lexer::lexToken() {
263
281
return formToken (Token::Kind::semicolon, tokStart);
264
282
case ' *' :
265
283
return formToken (Token::Kind::star, tokStart);
284
+ case ' ?' :
285
+ return formToken (Token::Kind::question, tokStart);
266
286
case ' /' :
267
287
if (*curPtr == ' /' ) {
268
288
skipComment ();
@@ -289,6 +309,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
289
309
// Check to see if this identifier is a keyword.
290
310
StringRef str (tokStart, curPtr - tokStart);
291
311
Token::Kind kind = StringSwitch<Token::Kind>(str)
312
+ .Case (" attr" , Token::Kind::kw_attr_def)
292
313
.Case (" def" , Token::Kind::kw_def)
293
314
.Case (" ods_def" , Token::Kind::kw_ods_def)
294
315
.Case (" floordiv" , Token::Kind::kw_floordiv)
@@ -352,29 +373,40 @@ class Parser {
352
373
" shouldn't advance past EOF or errors" );
353
374
curToken = lexer.lexToken ();
354
375
}
376
+
355
377
/// Advance to the next token. The caller states which kind of token it
/// believes is current; this is only checked by an assertion.
void consumeToken(Token::Kind kind) {
  assert(kind == curToken.getKind() && "unexpected token");
  curToken = lexer.lexToken();
}
381
+
359
382
/// Require the current token to be of the given `kind`. If it is, consume it
/// and return success; otherwise report `msg` at the current token's location
/// and return failure without consuming anything.
LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
  if (curToken.getKind() == kind) {
    consumeToken();
    return success();
  }
  return emitError(curToken.getLoc(), msg);
}
388
+
389
+ // / Parses an optional token and returns failure if failed to parse.
390
+ LogicalResult parseOptionalToken (Token::Kind kind) {
391
+ return success (consumeIf (kind));
392
+ }
393
+
365
394
/// Report `msg` at location `loc` through the lexer and return failure, so
/// callers can write `return emitError(...)`.
LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
  // The lexer hands back a token; only the failure result matters here.
  (void)lexer.emitError(loc, msg);
  return failure();
}
398
+
369
399
/// Report `msg` at the location of the current token and return failure.
LogicalResult emitError(const Twine &msg) {
  auto currentLoc = curToken.getLoc();
  return emitError(currentLoc, msg);
}
402
+
372
403
/// If the current token is of the given `kind`, consume it and return true.
/// Otherwise leave the token stream untouched and return false.
bool consumeIf(Token::Kind kind) {
  const bool matches = !curToken.isNot(kind);
  if (matches)
    consumeToken(kind);
  return matches;
}
409
+
378
410
LogicalResult
379
411
parseCommaSeparatedList (llvm::function_ref<ParseResult()> parseElement) {
380
412
// Non-empty case starts with an element.
@@ -388,6 +420,7 @@ class Parser {
388
420
}
389
421
return success ();
390
422
}
423
+
391
424
LogicalResult
392
425
parseCommaSeparatedListUntil (Token::Kind rightToken,
393
426
llvm::function_ref<ParseResult()> parseElement,
@@ -961,6 +994,8 @@ class TCParser {
961
994
LogicalResult parseTensorUse (TensorUse &result,
962
995
ComprehensionParsingState &state);
963
996
997
+ LogicalResult parseAttrDef ();
998
+
964
999
// / Parses a tensor expression.
965
1000
LogicalResult parseExpression (TensorUse currentDefinition,
966
1001
std::unique_ptr<Expression> &result,
@@ -1010,15 +1045,29 @@ class TCParser {
1010
1045
unsigned index;
1011
1046
};
1012
1047
1048
+ // ===--------------------------------------------------------------------===//
1049
+ // Internal bookkeeping of attributes.
1050
+ // ===--------------------------------------------------------------------===//
1051
+ struct RegisteredAttr {
1052
+ StringRef elementType;
1053
+ SmallVector<uint64_t , 4 > vectorDims;
1054
+ bool isArray;
1055
+ bool isOptional;
1056
+ };
1057
+
1013
1058
// ===--------------------------------------------------------------------===//
1014
1059
// Per-TC def state.
1015
1060
// ===--------------------------------------------------------------------===//
1016
1061
// / Symbols are per TC def.
1017
1062
AffineSymbolList symbols;
1063
+
1018
1064
// / Tensors are per TC def.
1019
1065
llvm::StringMap<RegisteredTensor> registeredTensors;
1020
1066
unsigned nextRegisteredTensorIndex;
1021
1067
1068
+ // / Attributes are per TC def.
1069
+ std::map<std::string, RegisteredAttr> registeredAttrs;
1070
+
1022
1071
Parser &parser;
1023
1072
};
1024
1073
} // namespace
@@ -1170,6 +1219,73 @@ LogicalResult TCParser::parseTensorUse(TensorUse &result,
1170
1219
return success ();
1171
1220
}
1172
1221
1222
+ // / Parse the information for an attribute def of the form:
1223
+ // /
1224
+ // / affine-expr-list ::= affine-expr (`,` affine-expr )*
1225
+ // / attr-id ::= bare-id (`?`)?
1226
+ // / dim-list ::= (integer-literal 'x')+
1227
+ // / attr-typedef ::= dim-list? type (`[` `]`)?
1228
+ // / attr-def ::= attr-id `:` attr-typedef
1229
+ LogicalResult TCParser::parseAttrDef () {
1230
+ auto attrLoc = parser.curToken .getLoc ();
1231
+ StringRef attrName = parser.curToken .getSpelling ();
1232
+ if (failed (parser.parseToken (Token::Kind::id, " expected an id" )))
1233
+ return failure ();
1234
+ bool isOptional = succeeded (parser.parseOptionalToken (Token::Kind::question));
1235
+ if (failed (parser.parseToken (Token::Kind::colon, " expected colon" )))
1236
+ return failure ();
1237
+
1238
+ // Parse the attribute's type. We don't expect the type to be arbitrary
1239
+ // complex, so just use this ad-hoc handling here.
1240
+
1241
+ // Parse potential dimension list
1242
+ SmallVector<uint64_t , 4 > vectorDims;
1243
+ while (parser.curToken .is (Token::Kind::integer)) {
1244
+ vectorDims.push_back (parser.curToken .getUInt64IntegerValue ().getValue ());
1245
+ parser.consumeToken ();
1246
+
1247
+ StringRef spelling = parser.curToken .getSpelling ();
1248
+ if (spelling[0 ] != ' x' )
1249
+ return parser.emitError (parser.curToken .getLoc (),
1250
+ " expected 'x' in dimension list" );
1251
+
1252
+ // If we had a prefix of 'x', lex the next token immediately after the 'x'.
1253
+ if (spelling.size () != 1 )
1254
+ parser.lexer .resetPointer (spelling.data () + 1 );
1255
+
1256
+ parser.consumeToken ();
1257
+ }
1258
+
1259
+ StringRef elementType = parser.curToken .getSpelling ();
1260
+ if (failed (parser.parseToken (Token::Kind::id, " expected an id" )))
1261
+ return failure ();
1262
+
1263
+ bool isArray = false ;
1264
+ auto arrayLoc = parser.curToken .getLoc ();
1265
+ if (succeeded (parser.parseOptionalToken (Token::Kind::l_square))) {
1266
+ isArray = true ;
1267
+ if (failed (parser.parseToken (Token::Kind::r_square, " expected ']'" )))
1268
+ return failure ();
1269
+ }
1270
+
1271
+ if (!vectorDims.empty () && isArray)
1272
+ return parser.emitError (arrayLoc, " unsupported vector array attribute" );
1273
+
1274
+ auto iterBoolPair = registeredAttrs.emplace (
1275
+ attrName.str (),
1276
+ RegisteredAttr{elementType, vectorDims, isArray, isOptional});
1277
+ if (!iterBoolPair.second )
1278
+ return parser.emitError (attrLoc,
1279
+ " Failed to register attribute '" + attrName + " '" );
1280
+
1281
+ LLVM_DEBUG (llvm::dbgs () << " Recorded: " << (isOptional ? " [optional]" : " " )
1282
+ << " " << attrName << " "
1283
+ << " with type: " << elementType
1284
+ << (isArray ? " []" : " " ) << " \n " );
1285
+
1286
+ return success ();
1287
+ }
1288
+
1173
1289
// / Parses a tensor expression of the form:
1174
1290
// /
1175
1291
// / op-spec ::= bare-id `<` reduction-dims-list `>`
@@ -1341,10 +1457,13 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
1341
1457
// / Parse and print the information for a ODS def.
1342
1458
// /
1343
1459
// / tensor-def-list ::= tensor-def (`,` tensor-def )*
1460
+ // / attr-def-list ::= attr-def (`,` attr-def )*
1344
1461
// /
1345
1462
// / comprehension-list ::= comprehension comprehension*
1346
1463
// /
1464
+ // / tc-attr-def ::= `attr` `(` attr-def-list `)`
1347
1465
// / tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
1466
+ // / (tc-attr-def)?
1348
1467
// / `{` comprehension-list `}`
1349
1468
// /
1350
1469
// / ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def
@@ -1353,6 +1472,7 @@ TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
1353
1472
// / contain only expressions involving symbols and constants), but can
1354
1473
// / otherwise contain arbitrary affine expressions.
1355
1474
LogicalResult TCParser::parseAndEmitODSDef (llvm::raw_ostream &os) {
1475
+ // Parse def header (including C++ op name)
1356
1476
if (failed (parser.parseToken (Token::Kind::kw_ods_def,
1357
1477
" expected 'ods_def' to define a TC ODS" )) ||
1358
1478
failed (parser.parseToken (Token::Kind::lt, " expected '<'" )))
@@ -1364,12 +1484,15 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
1364
1484
failed (parser.parseToken (Token::Kind::gt, " expected '>'" )) ||
1365
1485
failed (parser.parseToken (Token::Kind::colon, " expected ':'" )))
1366
1486
return failure ();
1487
+
1367
1488
if (failed (parser.parseToken (Token::Kind::kw_def,
1368
1489
" expected 'def' to define a TC" )))
1369
1490
return failure ();
1370
1491
1371
1492
StringRef tcName = parser.curToken .getSpelling ();
1372
1493
LLVM_DEBUG (llvm::dbgs () << " \n\n Start parsing TC: " << tcName << " \n " );
1494
+
1495
+ // Parse input/output tensor definitions
1373
1496
if (failed (parser.parseToken (Token::Kind::id, " expected id" )) ||
1374
1497
failed (parser.parseToken (Token::Kind::l_paren, " expected '('" )))
1375
1498
return failure ();
@@ -1392,6 +1515,16 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
1392
1515
Token::Kind::r_paren, parseOutputDef, /* allowEmptyList=*/ false )))
1393
1516
return failure ();
1394
1517
1518
+ // Parse optional attribute definitions
1519
+ if (succeeded (parser.parseOptionalToken (Token::Kind::kw_attr_def))) {
1520
+ if (failed (parser.parseToken (Token::Kind::l_paren, " expected '('" )))
1521
+ return failure ();
1522
+ if (failed (parser.parseCommaSeparatedListUntil (
1523
+ Token::Kind::r_paren, std::bind (&TCParser::parseAttrDef, this ),
1524
+ /* allowEmptyList=*/ false )))
1525
+ return failure ();
1526
+ }
1527
+
1395
1528
// Since we don't declare symbols separately, we discover them eagerly: each
1396
1529
// newly encountered id in a tensor shape expression is treated as a new
1397
1530
// symbolic. At this point, all tensors have been parsed and all the symbols
@@ -1450,12 +1583,52 @@ LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
1450
1583
void TCParser::printODS (llvm::raw_ostream &os, StringRef cppOpName,
1451
1584
StringRef linalgOpName,
1452
1585
ComprehensionParsingState &state) {
1586
+ SmallVector<std::string, 4 > attributes;
1587
+ for (const auto &attr : registeredAttrs) {
1588
+ llvm::StringRef name = attr.first ;
1589
+
1590
+ llvm::StringRef elementType = attr.second .elementType ;
1591
+ std::string odsType = llvm::StringSwitch<std::string>(elementType)
1592
+ .Case (" f32" , " F32" )
1593
+ .Case (" i32" , " I32" )
1594
+ .Default (" " );
1595
+ if (odsType.empty ()) {
1596
+ parser.emitError (" unimplemented support for attribute element type: " +
1597
+ elementType);
1598
+ return ;
1599
+ }
1600
+
1601
+ const auto &dims = attr.second .vectorDims ;
1602
+ if (!dims.empty ()) {
1603
+ SmallVector<std::string, 4 > dimStrs;
1604
+ for (uint64_t dim : dims)
1605
+ dimStrs.push_back (std::to_string (dim));
1606
+ odsType = llvm::formatv (" Ranked{0}ElementsAttr<[{1}]>" , odsType,
1607
+ llvm::join (dimStrs, " , " ));
1608
+ }
1609
+
1610
+ assert (dims.empty () || !attr.second .isArray );
1611
+ if (attr.second .isArray )
1612
+ odsType = llvm::formatv (" {0}ArrayAttr" , odsType);
1613
+
1614
+ if (attr.second .isOptional )
1615
+ odsType = llvm::formatv (" OptionalAttr<{0}>" , odsType);
1616
+
1617
+ attributes.push_back (llvm::formatv (" {0}:${1}" , odsType, name));
1618
+ }
1619
+
1620
+ std::string attrList = llvm::join (attributes, " ,\n " );
1621
+ if (!attrList.empty ())
1622
+ attrList = " ,\n " + attrList;
1623
+
1453
1624
const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [
1454
1625
AttrSizedOperandSegments,
1455
1626
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1456
1627
SingleBlockImplicitTerminator<"YieldOp">]> {
1457
- let arguments = (ins Variadic<AnyShaped>:$inputs,
1458
- Variadic<AnyShaped>:$outputs);
1628
+ let arguments = (ins
1629
+ Variadic<AnyShaped>:$inputs,
1630
+ Variadic<AnyShaped>:$outputs{4}
1631
+ );
1459
1632
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
1460
1633
let regions = (region AnyRegion:$region);
1461
1634
@@ -1515,7 +1688,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1515
1688
static std::function<void(Block &)> getRegionBuilder() {{ return regionBuilder; }
1516
1689
1517
1690
// Generic methods.
1518
- static unsigned getNumRegionArgs() {{ return {4 }; }
1691
+ static unsigned getNumRegionArgs() {{ return {5 }; }
1519
1692
std::string getLibraryCallName() {{
1520
1693
return generateLibraryCallName(getOperation());
1521
1694
}
@@ -1531,7 +1704,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
1531
1704
}
1532
1705
1533
1706
os << llvm::formatv (header, cppOpName, linalgOpName, nInputs, nOutputs,
1534
- state.orderedTensorArgs .size ());
1707
+ attrList, state.orderedTensorArgs .size ());
1535
1708
}
1536
1709
1537
1710
// / Print the C++ StructuredOpsInterface impl of `iterator_types`.
0 commit comments