Index: llvm/docs/TableGen/ProgRef.rst =================================================================== --- llvm/docs/TableGen/ProgRef.rst +++ llvm/docs/TableGen/ProgRef.rst @@ -210,12 +210,12 @@ .. productionlist:: BangOperator: one of : !add !and !cast !con !dag - : !empty !eq !foldl !foreach !ge - : !getdagop !gt !head !if !interleave - : !isa !le !listconcat !listsplat !lt - : !mul !ne !not !or !setdagop - : !shl !size !sra !srl !strconcat - : !sub !subst !tail !xor + : !empty !eq !filter !foldl !foreach + : !ge !getdagop !gt !head !if + : !interleave !isa !le !listconcat !listsplat + : !lt !mul !ne !not !or + : !setdagop !shl !size !sra !srl + : !strconcat !sub !subst !tail !xor The ``!cond`` operator has a slightly different syntax compared to other bang operators, so it is defined separately: @@ -1563,6 +1563,17 @@ The arguments must be ``bit``, ``bits``, ``int``, or ``string`` values. Use ``!cast`` to compare other types of objects. +``!filter(``\ *var*\ ``,`` *list*\ ``,`` *predicate*\ ``)`` + + This operator creates a new ``list`` by filtering the elements in + *list*. To perform the filtering, TableGen binds the variable *var* to each + element and then evaluates the *predicate* expression, which presumably + refers to *var*. The predicate must + produce a boolean value (``bit``, ``bits``, or ``int``). The value is + interpreted as with ``!if``: + if the value is 0, the element is not included in the new list. If the value + is anything else, the element is included. + ``!foldl(``\ *init*\ ``,`` *list*\ ``,`` *acc*\ ``,`` *var*\ ``,`` *expr*\ ``)`` This operator performs a left-fold over the items in *list*. The variable *acc* acts as the accumulator and is initialized to *init*. @@ -1577,6 +1588,9 @@ int x = !foldl(0, RecList, total, rec, !add(total, rec.Number)); + If your goal is to filter the list and produce a new list that includes only + some of the elements, see ``!filter``.
+ ``!foreach(``\ *var*\ ``,`` *sequence*\ ``,`` *expr*\ ``)`` This operator creates a new ``list``/``dag`` in which each element is a function of the corresponding element in the *sequence* ``list``/``dag``. Index: llvm/include/llvm/IR/IntrinsicsNVVM.td =================================================================== --- llvm/include/llvm/IR/IntrinsicsNVVM.td +++ llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -225,10 +225,8 @@ ldst_bit_ab_ops, ldst_subint_cd_ops); // Separate A/B/C fragments (loads) from D (stores). - list all_ld_ops = !foldl([], all_ldst_ops, a, b, - !listconcat(a, !if(!eq(b.frag,"d"), [],[b]))); - list all_st_ops = !foldl([], all_ldst_ops, a, b, - !listconcat(a, !if(!eq(b.frag,"d"), [b],[]))); + list all_ld_ops = !filter(op, all_ldst_ops, !ne(op.frag, "d")); + list all_st_ops = !filter(op, all_ldst_ops, !eq(op.frag, "d")); } def NVVM_MMA_OPS : NVVM_MMA_OPS; Index: llvm/include/llvm/TableGen/Record.h =================================================================== --- llvm/include/llvm/TableGen/Record.h +++ llvm/include/llvm/TableGen/Record.h @@ -865,7 +865,7 @@ /// !op (X, Y, Z) - Combine two inits.
class TernOpInit : public OpInit, public FoldingSetNode { public: - enum TernaryOp : uint8_t { SUBST, FOREACH, IF, DAG }; + enum TernaryOp : uint8_t { SUBST, FOREACH, FILTER, IF, DAG }; private: Init *LHS, *MHS, *RHS; Index: llvm/lib/TableGen/Record.cpp =================================================================== --- llvm/lib/TableGen/Record.cpp +++ llvm/lib/TableGen/Record.cpp @@ -1162,7 +1162,7 @@ ProfileTernOpInit(ID, getOpcode(), getLHS(), getMHS(), getRHS(), getType()); } -static Init *ForeachApply(Init *LHS, Init *MHSe, Init *RHS, Record *CurRec) { +static Init *ItemApply(Init *LHS, Init *MHSe, Init *RHS, Record *CurRec) { MapResolver R(CurRec); R.set(LHS, MHSe); return RHS->resolveReferences(R); @@ -1171,7 +1171,7 @@ static Init *ForeachDagApply(Init *LHS, DagInit *MHSd, Init *RHS, Record *CurRec) { bool Change = false; - Init *Val = ForeachApply(LHS, MHSd->getOperator(), RHS, CurRec); + Init *Val = ItemApply(LHS, MHSd->getOperator(), RHS, CurRec); if (Val != MHSd->getOperator()) Change = true; @@ -1184,7 +1184,7 @@ if (DagInit *Argd = dyn_cast(Arg)) NewArg = ForeachDagApply(LHS, Argd, RHS, CurRec); else - NewArg = ForeachApply(LHS, Arg, RHS, CurRec); + NewArg = ItemApply(LHS, Arg, RHS, CurRec); NewArgs.push_back(std::make_pair(NewArg, ArgName)); if (Arg != NewArg) @@ -1206,7 +1206,7 @@ SmallVector NewList(MHSl->begin(), MHSl->end()); for (Init *&Item : NewList) { - Init *NewItem = ForeachApply(LHS, Item, RHS, CurRec); + Init *NewItem = ItemApply(LHS, Item, RHS, CurRec); if (NewItem != Item) Item = NewItem; } @@ -1216,6 +1216,31 @@ return nullptr; } +// Evaluates RHS for each element of MHS, binding LHS to the element. Returns +// the elements whose RHS is a nonzero int, or nullptr if it cannot be folded.
+static Init *FilterHelper(Init *LHS, Init *MHS, Init *RHS, RecTy *Type, + Record *CurRec) { + if (ListInit *MHSl = dyn_cast(MHS)) { + SmallVector NewList; + + for (Init *Item : MHSl->getValues()) { + Init *Include = ItemApply(LHS, Item, RHS, CurRec); + if (!Include) + return nullptr; + if (IntInit *IncludeInt = dyn_cast_or_null( + Include->convertInitializerTo(IntRecTy::get()))) { + if (IncludeInt->getValue()) + NewList.push_back(Item); + } else { + return nullptr; + } + } + return ListInit::get(NewList, cast(Type)->getElementType()); + } + + return nullptr; +} + Init *TernOpInit::Fold(Record *CurRec) const { switch (getOpcode()) { case SUBST: { @@ -1268,6 +1293,12 @@ break; } + case FILTER: { + if (Init *Result = FilterHelper(LHS, MHS, RHS, getType(), CurRec)) + return Result; + break; + } + case IF: { if (IntInit *LHSi = dyn_cast_or_null( LHS->convertInitializerTo(IntRecTy::get()))) { @@ -1322,7 +1353,7 @@ Init *mhs = MHS->resolveReferences(R); Init *rhs; - if (getOpcode() == FOREACH) { + if (getOpcode() == FOREACH || getOpcode() == FILTER) { ShadowResolver SR(R); SR.addShadow(lhs); rhs = RHS->resolveReferences(SR); @@ -1342,6 +1373,7 @@ switch (getOpcode()) { case SUBST: Result = "!subst"; break; case FOREACH: Result = "!foreach"; UnquotedLHS = true; break; + case FILTER: Result = "!filter"; UnquotedLHS = true; break; case IF: Result = "!if"; break; case DAG: Result = "!dag"; break; } Index: llvm/lib/TableGen/TGLexer.h =================================================================== --- llvm/lib/TableGen/TGLexer.h +++ llvm/lib/TableGen/TGLexer.h @@ -54,8 +54,8 @@ // Bang operators.
XConcat, XADD, XSUB, XMUL, XNOT, XAND, XOR, XXOR, XSRA, XSRL, XSHL, XListConcat, XListSplat, XStrConcat, XInterleave, XCast, XSubst, XForEach, - XFoldl, XHead, XTail, XSize, XEmpty, XIf, XCond, XEq, XIsA, XDag, XNe, - XLe, XLt, XGe, XGt, XSetDagOp, XGetDagOp, + XFilter, XFoldl, XHead, XTail, XSize, XEmpty, XIf, XCond, XEq, XIsA, + XDag, XNe, XLe, XLt, XGe, XGt, XSetDagOp, XGetDagOp, // Boolean literals. TrueVal, FalseVal, Index: llvm/lib/TableGen/TGLexer.cpp =================================================================== --- llvm/lib/TableGen/TGLexer.cpp +++ llvm/lib/TableGen/TGLexer.cpp @@ -584,6 +584,7 @@ .Case("subst", tgtok::XSubst) .Case("foldl", tgtok::XFoldl) .Case("foreach", tgtok::XForEach) + .Case("filter", tgtok::XFilter) .Case("listconcat", tgtok::XListConcat) .Case("listsplat", tgtok::XListSplat) .Case("strconcat", tgtok::XStrConcat) Index: llvm/lib/TableGen/TGParser.h =================================================================== --- llvm/lib/TableGen/TGParser.h +++ llvm/lib/TableGen/TGParser.h @@ -254,6 +254,7 @@ TypedInit *FirstItem = nullptr); RecTy *ParseType(); Init *ParseOperation(Record *CurRec, RecTy *ItemType); + Init *ParseOperationForEachFilter(Record *CurRec, RecTy *ItemType); Init *ParseOperationCond(Record *CurRec, RecTy *ItemType); RecTy *ParseOperatorType(); Init *ParseObjectName(MultiClass *CurMultiClass); Index: llvm/lib/TableGen/TGParser.cpp =================================================================== --- llvm/lib/TableGen/TGParser.cpp +++ llvm/lib/TableGen/TGParser.cpp @@ -1343,114 +1343,9 @@ return nullptr; } - case tgtok::XForEach: { - // Value ::= !foreach '(' Id ',' Value ',' Value ')' - SMLoc OpLoc = Lex.getLoc(); - Lex.Lex(); // eat the operation - if (Lex.getCode() != tgtok::l_paren) { - TokError("expected '(' after !foreach"); - return nullptr; - } - - if (Lex.Lex() != tgtok::Id) { // eat the '(' - TokError("first argument of !foreach must be an identifier"); - return nullptr; - } - - Init *LHS =
StringInit::get(Lex.getCurStrVal()); - Lex.Lex(); - - if (CurRec && CurRec->getValue(LHS)) { - TokError((Twine("iteration variable '") + LHS->getAsString() + - "' already defined") - .str()); - return nullptr; - } - - if (!consume(tgtok::comma)) { // eat the id - TokError("expected ',' in ternary operator"); - return nullptr; - } - - Init *MHS = ParseValue(CurRec); - if (!MHS) - return nullptr; - - if (!consume(tgtok::comma)) { - TokError("expected ',' in ternary operator"); - return nullptr; - } - - TypedInit *MHSt = dyn_cast(MHS); - if (!MHSt) { - TokError("could not get type of !foreach input"); - return nullptr; - } - - RecTy *InEltType = nullptr; - RecTy *OutEltType = nullptr; - bool IsDAG = false; - - if (ListRecTy *InListTy = dyn_cast(MHSt->getType())) { - InEltType = InListTy->getElementType(); - if (ItemType) { - if (ListRecTy *OutListTy = dyn_cast(ItemType)) { - OutEltType = OutListTy->getElementType(); - } else { - Error(OpLoc, - "expected value of type '" + Twine(ItemType->getAsString()) + - "', but got !foreach of list type"); - return nullptr; - } - } - } else if (DagRecTy *InDagTy = dyn_cast(MHSt->getType())) { - InEltType = InDagTy; - if (ItemType && !isa(ItemType)) { - Error(OpLoc, - "expected value of type '" + Twine(ItemType->getAsString()) + - "', but got !foreach of dag type"); - return nullptr; - } - IsDAG = true; - } else { - TokError("!foreach must have list or dag input"); - return nullptr; - } - - // We need to create a temporary record to provide a scope for the - // iteration variable.
- std::unique_ptr ParseRecTmp; - Record *ParseRec = CurRec; - if (!ParseRec) { - ParseRecTmp = std::make_unique(".parse", ArrayRef{}, Records); - ParseRec = ParseRecTmp.get(); - } - - ParseRec->addValue(RecordVal(LHS, InEltType, false)); - Init *RHS = ParseValue(ParseRec, OutEltType); - ParseRec->removeValue(LHS); - if (!RHS) - return nullptr; - - if (!consume(tgtok::r_paren)) { - TokError("expected ')' in binary operator"); - return nullptr; - } - - RecTy *OutType; - if (IsDAG) { - OutType = InEltType; - } else { - TypedInit *RHSt = dyn_cast(RHS); - if (!RHSt) { - TokError("could not get type of !foreach result"); - return nullptr; - } - OutType = RHSt->getType()->getListTy(); - } - - return (TernOpInit::get(TernOpInit::FOREACH, LHS, MHS, RHS, OutType)) - ->Fold(CurRec); + case tgtok::XForEach: + case tgtok::XFilter: { + return ParseOperationForEachFilter(CurRec, ItemType); } case tgtok::XDag: @@ -1746,6 +1641,130 @@ return Type; } +/// Parse the !foreach and !filter operations. Return null on error. +/// +/// ForEach ::= !foreach(ID, list-or-dag, expr) => list-or-dag +/// Filter ::= !filter(ID, list, predicate) => list +Init *TGParser::ParseOperationForEachFilter(Record *CurRec, RecTy *ItemType) { + SMLoc OpLoc = Lex.getLoc(); + tgtok::TokKind Operation = Lex.getCode(); + Lex.Lex(); // eat the operation + if (Lex.getCode() != tgtok::l_paren) { + TokError("expected '(' after !foreach/!filter"); + return nullptr; + } + + if (Lex.Lex() != tgtok::Id) { // eat the '(' + TokError("first argument of !foreach/!filter must be an identifier"); + return nullptr; + } + + Init *LHS = StringInit::get(Lex.getCurStrVal()); + Lex.Lex(); // eat the ID.
+ + if (CurRec && CurRec->getValue(LHS)) { + TokError((Twine("iteration variable '") + LHS->getAsString() + + "' is already defined") + .str()); + return nullptr; + } + + if (!consume(tgtok::comma)) { + TokError("expected ',' in !foreach/!filter"); + return nullptr; + } + + Init *MHS = ParseValue(CurRec); + if (!MHS) + return nullptr; + + if (!consume(tgtok::comma)) { + TokError("expected ',' in !foreach/!filter"); + return nullptr; + } + + TypedInit *MHSt = dyn_cast(MHS); + if (!MHSt) { + TokError("could not get type of !foreach/!filter list or dag"); + return nullptr; + } + + RecTy *InEltType = nullptr; + RecTy *ExprEltType = nullptr; + bool IsDAG = false; + + if (ListRecTy *InListTy = dyn_cast(MHSt->getType())) { + InEltType = InListTy->getElementType(); + if (ItemType) { + if (ListRecTy *OutListTy = dyn_cast(ItemType)) { + ExprEltType = (Operation == tgtok::XForEach) + ? OutListTy->getElementType() : IntRecTy::get(); + } else { + Error(OpLoc, + "expected value of type '" + Twine(ItemType->getAsString()) + + "', but got list type"); + return nullptr; + } + } + } else if (DagRecTy *InDagTy = dyn_cast(MHSt->getType())) { + if (Operation == tgtok::XFilter) { + TokError("!filter must have a list argument"); + return nullptr; + } + InEltType = InDagTy; + if (ItemType && !isa(ItemType)) { + Error(OpLoc, + "expected value of type '" + Twine(ItemType->getAsString()) + + "', but got dag type"); + return nullptr; + } + IsDAG = true; + } else { + if (Operation == tgtok::XForEach) + TokError("!foreach must have a list or dag argument"); + else + TokError("!filter must have a list argument"); + return nullptr; + } + + // We need to create a temporary record to provide a scope for the + // iteration variable.
+ std::unique_ptr ParseRecTmp; + Record *ParseRec = CurRec; + if (!ParseRec) { + ParseRecTmp = std::make_unique(".parse", ArrayRef{}, Records); + ParseRec = ParseRecTmp.get(); + } + + ParseRec->addValue(RecordVal(LHS, InEltType, false)); + Init *RHS = ParseValue(ParseRec, ExprEltType); + ParseRec->removeValue(LHS); + if (!RHS) + return nullptr; + + if (!consume(tgtok::r_paren)) { + TokError("expected ')' in !foreach/!filter"); + return nullptr; + } + + RecTy *OutType = InEltType; + if (Operation == tgtok::XForEach && !IsDAG) { + TypedInit *RHSt = dyn_cast(RHS); + if (!RHSt) { + TokError("could not get type of !foreach result expression"); + return nullptr; + } + OutType = RHSt->getType()->getListTy(); + } else if (Operation == tgtok::XFilter) { + OutType = InEltType->getListTy(); + } + + return (TernOpInit::get((Operation == tgtok::XForEach) ? TernOpInit::FOREACH + : TernOpInit::FILTER, + LHS, MHS, RHS, OutType)) + ->Fold(CurRec); +} + Init *TGParser::ParseOperationCond(Record *CurRec, RecTy *ItemType) { Lex.Lex(); // eat the operation 'cond' @@ -2169,6 +2188,7 @@ case tgtok::XCond: case tgtok::XFoldl: case tgtok::XForEach: + case tgtok::XFilter: case tgtok::XSubst: { // Value ::= !ternop '(' Value ',' Value ',' Value ')' return ParseOperation(CurRec, ItemType); } Index: llvm/lib/Target/AMDGPU/AMDGPUInstructions.td =================================================================== --- llvm/lib/Target/AMDGPU/AMDGPUInstructions.td +++ llvm/lib/Target/AMDGPU/AMDGPUInstructions.td @@ -83,9 +83,8 @@ // Add a predicate to the list if does not already exist to deduplicate it.
class PredConcat lst, Predicate pred> { list ret = - !foldl([pred], lst, acc, cur, - !listconcat(acc, !if(!eq(!cast(cur),!cast(pred)), - [], [cur]))); + !listconcat([pred], !filter(item, lst, + !ne(!cast(item), !cast(pred)))); } class PredicateControl { Index: llvm/lib/Target/AMDGPU/MIMGInstructions.td =================================================================== --- llvm/lib/Target/AMDGPU/MIMGInstructions.td +++ llvm/lib/Target/AMDGPU/MIMGInstructions.td @@ -180,11 +180,12 @@ let Key = ["Opcode"]; } -// This is a separate class so that TableGen memoizes the computations. +// This is a separate class so that TableGen memoizes the computations. +// AddrAsmNames is built with !filter, which proved much faster than !foldl. class MIMGNSAHelper { list AddrAsmNames = - !foldl([], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], lhs, i, - !if(!lt(i, num_addrs), !listconcat(lhs, ["vaddr"#!size(lhs)]), lhs)); + !foreach(i, !filter(i, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + !lt(i, num_addrs)), "vaddr" # i); dag AddrIns = !dag(ins, !foreach(arg, AddrAsmNames, VGPR_32), AddrAsmNames); string AddrAsm = "[$" # !interleave(AddrAsmNames, ", $") # "]"; Index: llvm/lib/Target/AMDGPU/SIRegisterInfo.td =================================================================== --- llvm/lib/Target/AMDGPU/SIRegisterInfo.td +++ llvm/lib/Target/AMDGPU/SIRegisterInfo.td @@ -17,9 +17,7 @@ 24, 25, 26, 27, 28, 29, 30, 31]; // Returns list of indexes [0..N) - list slice = - !foldl([], all, acc, cur, - !listconcat(acc, !if(!lt(cur, N), [cur], []))); + list slice = !filter(i, all, !lt(i, N)); } let Namespace = "AMDGPU" in { Index: llvm/test/TableGen/filter.td =================================================================== --- /dev/null +++ llvm/test/TableGen/filter.td @@ -0,0 +1,75 @@ +// RUN: llvm-tblgen %s | FileCheck %s +// RUN: not llvm-tblgen -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s + +defvar EmptyList = []; +defvar OneList = ["foo"]; +defvar WordList = ["foo", "bar", "zoo", "foo", "snork", "snork",
"quux"]; + +class Predicate; +def pred1 : Predicate; +def pred2 : Predicate; +def pred3 : Predicate; +def pred4 : Predicate; +def pred5 : Predicate; + +class DeduplicatePredList predlist, Predicate pred> { + list ret = + !listconcat([pred], !filter(item, predlist, + !ne(!cast(item), !cast(pred)))); +} + +// CHECK: def rec1 +// CHECK: list1 = []; +// CHECK: list2 = []; +// CHECK: list3 = ["foo"]; +// CHECK: list4 = []; +// CHECK: list5 = ["foo", "bar", "zoo", "foo", "snork", "snork", "quux"]; +// CHECK: list6 = []; + +def rec1 { + list list1 = !filter(item, EmptyList, true); + list list2 = !filter(item, EmptyList, false); + list list3 = !filter(item, OneList, true); + list list4 = !filter(item, OneList, false); + list list5 = !filter(item, WordList, true); + list list6 = !filter(item, WordList, false); +} + +// CHECK: def rec2 +// CHECK: list1 = ["foo", "foo"]; +// CHECK: list2 = ["bar", "zoo", "snork", "snork", "quux"]; +// CHECK: list3 = ["snork", "snork", "quux"]; + +def rec2 { + list list1 = !filter(item, WordList, !eq(item, "foo")); + list list2 = !filter(item, WordList, !ne(item, "foo")); + list list3 = !filter(item, WordList, !ge(!size(item), 4)); +} + +// CHECK: def rec3 +// CHECK: list1 = [4, 5, 6, 7, 8, 9, 10]; +// CHECK: list2 = [4, 5, 6, 7, 8]; + +def rec3 { + list list1 = !filter(num, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], !gt(num, 3)); + list list2 = !filter(num, list1, !lt(num, 9)); +} + +// CHECK: def rec4 +// CHECK: duplist = [pred1, pred2, pred1, pred3, pred4, pred1, pred5]; +// CHECK: deduplist = [pred1, pred2, pred3, pred4, pred5]; + +def rec4 { + list duplist = [pred1, pred2, pred1, pred3, pred4, pred1, pred5]; + list deduplist = DeduplicatePredList.ret; +} + +#ifdef ERROR1 + +// ERROR1: could not be fully resolved + +def rec9 { + list list1 = !filter(item, WordList, !if(true, "oops!", "wrong!")); +} + +#endif