Skip to content

[TableGen] Add a !listflatten operator to TableGen #109346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions llvm/docs/TableGen/ProgRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -223,12 +223,12 @@ TableGen provides "bang operators" that have a wide variety of uses:
: !div !empty !eq !exists !filter
: !find !foldl !foreach !ge !getdagarg
: !getdagname !getdagop !gt !head !if
: !interleave !isa !le !listconcat !listremove
: !listsplat !logtwo !lt !mul !ne
: !not !or !range !repr !setdagarg
: !setdagname !setdagop !shl !size !sra
: !srl !strconcat !sub !subst !substr
: !tail !tolower !toupper !xor
: !interleave !isa !le !listconcat !listflatten
: !listremove !listsplat !logtwo !lt !mul
: !ne !not !or !range !repr
: !setdagarg !setdagname !setdagop !shl !size
: !sra !srl !strconcat !sub !subst
: !substr !tail !tolower !toupper !xor

The ``!cond`` operator has a slightly different
syntax compared to other bang operators, so it is defined separately:
Expand Down Expand Up @@ -1832,6 +1832,12 @@ and non-0 as true.
This operator concatenates the list arguments *list1*, *list2*, etc., and
produces the resulting list. The lists must have the same element type.

``!listflatten(``\ *list*\ ``)``
This operator flattens a list of lists *list* and produces a list with all
elements of the constituent lists concatenated. If *list* is of type
``list<list<X>>`` the resulting list is of type ``list<X>``. If *list*'s
element type is not a list, the result is *list* itself.

``!listremove(``\ *list1*\ ``,`` *list2*\ ``)``
This operator returns a copy of *list1* removing all elements that also occur in
*list2*. The lists must have the same element type.
Expand Down
3 changes: 2 additions & 1 deletion llvm/include/llvm/TableGen/Record.h
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,8 @@ class UnOpInit : public OpInit, public FoldingSetNode {
EMPTY,
GETDAGOP,
LOG2,
REPR
REPR,
LISTFLATTEN,
};

private:
Expand Down
29 changes: 29 additions & 0 deletions llvm/lib/TableGen/Record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,32 @@ Init *UnOpInit::Fold(Record *CurRec, bool IsFinal) const {
}
}
break;

case LISTFLATTEN:
if (ListInit *LHSList = dyn_cast<ListInit>(LHS)) {
ListRecTy *InnerListTy = dyn_cast<ListRecTy>(LHSList->getElementType());
// list of non-lists, !listflatten() is a NOP.
if (!InnerListTy)
return LHS;

auto Flatten = [](ListInit *List) -> std::optional<std::vector<Init *>> {
std::vector<Init *> Flattened;
// Concatenate elements of all the inner lists.
for (Init *InnerInit : List->getValues()) {
ListInit *InnerList = dyn_cast<ListInit>(InnerInit);
if (!InnerList)
return std::nullopt;
for (Init *InnerElem : InnerList->getValues())
Flattened.push_back(InnerElem);
};
return Flattened;
};

auto Flattened = Flatten(LHSList);
if (Flattened)
return ListInit::get(*Flattened, InnerListTy->getElementType());
}
break;
}
return const_cast<UnOpInit *>(this);
}
Expand All @@ -1010,6 +1036,9 @@ std::string UnOpInit::getAsString() const {
case EMPTY: Result = "!empty"; break;
case GETDAGOP: Result = "!getdagop"; break;
case LOG2 : Result = "!logtwo"; break;
case LISTFLATTEN:
Result = "!listflatten";
break;
case REPR:
Result = "!repr";
break;
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/TableGen/TGLexer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ tgtok::TokKind TGLexer::LexExclaim() {
.Case("foreach", tgtok::XForEach)
.Case("filter", tgtok::XFilter)
.Case("listconcat", tgtok::XListConcat)
.Case("listflatten", tgtok::XListFlatten)
.Case("listsplat", tgtok::XListSplat)
.Case("listremove", tgtok::XListRemove)
.Case("range", tgtok::XRange)
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/TableGen/TGLexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ enum TokKind {
XSRL,
XSHL,
XListConcat,
XListFlatten,
XListSplat,
XStrConcat,
XInterleave,
Expand Down
32 changes: 27 additions & 5 deletions llvm/lib/TableGen/TGParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1190,6 +1190,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
case tgtok::XNOT:
case tgtok::XToLower:
case tgtok::XToUpper:
case tgtok::XListFlatten:
case tgtok::XLOG2:
case tgtok::XHead:
case tgtok::XTail:
Expand Down Expand Up @@ -1235,6 +1236,11 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
Code = UnOpInit::NOT;
Type = IntRecTy::get(Records);
break;
case tgtok::XListFlatten:
Lex.Lex(); // eat the operation.
Code = UnOpInit::LISTFLATTEN;
Type = IntRecTy::get(Records); // Bogus type used here.
break;
case tgtok::XLOG2:
Lex.Lex(); // eat the operation
Code = UnOpInit::LOG2;
Expand Down Expand Up @@ -1309,7 +1315,8 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
}
}

if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL) {
if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL ||
Code == UnOpInit::LISTFLATTEN) {
ListInit *LHSl = dyn_cast<ListInit>(LHS);
TypedInit *LHSt = dyn_cast<TypedInit>(LHS);
if (!LHSl && !LHSt) {
Expand All @@ -1328,19 +1335,34 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
TokError("empty list argument in unary operator");
return nullptr;
}
bool UseElementType =
Code == UnOpInit::HEAD || Code == UnOpInit::LISTFLATTEN;
if (LHSl) {
Init *Item = LHSl->getElement(0);
TypedInit *Itemt = dyn_cast<TypedInit>(Item);
if (!Itemt) {
TokError("untyped list element in unary operator");
return nullptr;
}
Type = (Code == UnOpInit::HEAD) ? Itemt->getType()
: ListRecTy::get(Itemt->getType());
Type = UseElementType ? Itemt->getType()
: ListRecTy::get(Itemt->getType());
} else {
assert(LHSt && "expected list type argument in unary operator");
ListRecTy *LType = dyn_cast<ListRecTy>(LHSt->getType());
Type = (Code == UnOpInit::HEAD) ? LType->getElementType() : LType;
Type = UseElementType ? LType->getElementType() : LType;
}

// for !listflatten, we expect a list of lists, but also support a list of
// non-lists, where !listflatten will be a NOP.
if (Code == UnOpInit::LISTFLATTEN) {
ListRecTy *InnerListTy = dyn_cast<ListRecTy>(Type);
if (InnerListTy) {
// listflatten will convert list<list<X>> to list<X>.
Type = ListRecTy::get(InnerListTy->getElementType());
} else {
// If its a list of non-lists, !listflatten will be a NOP.
Type = ListRecTy::get(Type);
}
}
}

Expand Down Expand Up @@ -1378,7 +1400,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {

case tgtok::XExists: {
// Value ::= !exists '<' Type '>' '(' Value ')'
Lex.Lex(); // eat the operation
Lex.Lex(); // eat the operation.

RecTy *Type = ParseOperatorType();
if (!Type)
Expand Down
6 changes: 6 additions & 0 deletions llvm/test/TableGen/listflatten-error.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// RUN: not llvm-tblgen %s 2>&1 | FileCheck %s -DFILE=%s

// CHECK: [[FILE]]:[[@LINE+2]]:33: error: expected list type argument in unary operator
class Flatten<int A> {
list<int> F = !listflatten(A);
}
32 changes: 32 additions & 0 deletions llvm/test/TableGen/listflatten.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: llvm-tblgen %s | FileCheck %s

class Flatten<list<int> A, list<int> B> {
list<int> Flat1 = !listflatten([A, B, [6], [7, 8]]);

list<list<int>> X = [A, B];
list<int> Flat2 = !listflatten(!listconcat(X, [[7]]));

// Generate a nested list of integers.
list<int> Y0 = [1, 2, 3, 4];
list<list<int>> Y1 = !foreach(elem, Y0, [elem]);
list<list<list<int>>> Y2 = !foreach(elem, Y1, [elem]);
list<list<list<list<int>>>> Y3 = !foreach(elem, Y2, [elem]);

// Flatten it completely.
list<int> Flat3=!listflatten(!listflatten(!listflatten(Y3)));

// Flatten it partially.
list<list<list<int>>> Flat4 = !listflatten(Y3);
list<list<int>> Flat5 = !listflatten(!listflatten(Y3));

// Test NOP flattening.
list<string> Flat6 = !listflatten(["a", "b"]);
}

// CHECK: list<int> Flat1 = [1, 2, 3, 4, 5, 6, 7, 8];
// CHECK: list<int> Flat2 = [1, 2, 3, 4, 5, 7];
// CHECK: list<int> Flat3 = [1, 2, 3, 4];
// CHECK{LITERAL}: list<list<list<int>>> Flat4 = [[[1]], [[2]], [[3]], [[4]]];
// CHECK: list<string> Flat6 = ["a", "b"];
def F : Flatten<[1,2], [3,4,5]>;

Loading