diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h --- a/llvm/include/llvm/TableGen/Record.h +++ b/llvm/include/llvm/TableGen/Record.h @@ -313,6 +313,7 @@ IK_FoldOpInit, IK_IsAOpInit, IK_ExistsOpInit, + IK_ApplyOpInit, IK_AnonymousNameInit, IK_StringInit, IK_VarInit, @@ -1491,6 +1492,41 @@ } }; +/// !apply(dag, params) - Apply params to function represented in dag. +class ApplyOpInit : public TypedInit, public FoldingSetNode { +private: + RecTy *ReturnType; + Init *Func; + SmallVector Params; + + ApplyOpInit(RecTy *ReturnType, Init *Func, SmallVector Params) + : TypedInit(IK_ApplyOpInit, ReturnType), ReturnType(ReturnType), + Func(Func), Params(Params) {} + + Init *parseDAG(Record *CurRec, DagInit *Dag) const; + +public: + ApplyOpInit(const ApplyOpInit &) = delete; + ApplyOpInit &operator=(const ApplyOpInit &) = delete; + + static bool classof(const Init *I) { return I->getKind() == IK_ApplyOpInit; } + + static ApplyOpInit *get(RecTy *ReturnType, Init *Func, + SmallVector Params); + + void Profile(FoldingSetNodeID &ID) const; + + // Fold - If possible, fold this to a simpler init. Return this if not + // possible to fold. + Init *Fold(Record *CurRec, bool IsFinal = false) const; + + Init *resolveReferences(Resolver &R) const override; + + Init *getBit(unsigned Bit) const override; + + std::string getAsString() const override; +}; + //===----------------------------------------------------------------------===// // High-Level Classes //===----------------------------------------------------------------------===// diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp --- a/llvm/lib/TableGen/Record.cpp +++ b/llvm/lib/TableGen/Record.cpp @@ -30,6 +30,7 @@ #include "llvm/TableGen/Error.h" #include #include +#include #include #include #include @@ -81,6 +82,7 @@ FoldingSet TheFoldOpInitPool; FoldingSet TheIsAOpInitPool; FoldingSet TheExistsOpInitPool; + FoldingSet TheApplyOpInitPool; DenseMap, VarInit *> TheVarInitPool; DenseMap, VarBitInit *> TheVarBitInitPool; DenseMap, VarListElementInit *> @@ -1838,6 +1840,126 @@ .str(); } +static void ProfileApplyOpInit(FoldingSetNodeID &ID, RecTy *ReturnType, + Init *Dag, SmallVector Params) { + ID.AddPointer(ReturnType); + ID.AddPointer(Dag); + for (auto *Param : Params) + ID.AddPointer(Param); +} + +ApplyOpInit *ApplyOpInit::get(RecTy *ReturnType, Init *Dag, + SmallVector Params) { + FoldingSetNodeID ID; + ProfileApplyOpInit(ID, ReturnType, Dag, Params); + + detail::RecordKeeperImpl &RK = ReturnType->getRecordKeeper().getImpl(); + void *IP = nullptr; + if (ApplyOpInit *I = RK.TheApplyOpInitPool.FindNodeOrInsertPos(ID, IP)) + return I; + + ApplyOpInit *I = new (RK.Allocator) ApplyOpInit(ReturnType, Dag, Params); + RK.TheApplyOpInitPool.InsertNode(I, IP); + return I; +} + +void ApplyOpInit::Profile(FoldingSetNodeID &ID) const { + ProfileApplyOpInit(ID, ReturnType, Func, Params); +} + +Init *ApplyOpInit::parseDAG(Record *CurRec, DagInit *Dag) const { + SmallVector Args; + for (unsigned I = 0; I < Dag->getNumArgs(); I++) { + Init *Value = Dag->getArg(I); + StringRef Name = Dag->getArgNameStr(I); + UnsetInit *ValueUnset = dyn_cast(Value); + if (!ValueUnset && !Name.empty()) + PrintError(CurRec->getLoc(), Twine("The value and name of DAG in !apply " + "can't exist at the same time: ") + + Value->getAsUnquotedString() + ": " + + Name); + Init *Arg = Value; + // If the argument is refered to inputs. We replace it with input argument. + if (!Name.empty()) { + if (!Name.starts_with("a")) + PrintError(CurRec->getLoc(), + Twine("The argument should be prefixed with 'a': ") + Name); + int ArgIndex = atoi(Name.drop_front().str().c_str()); + Arg = Params[ArgIndex]; + } + // If the value is a DAG, we parse it recursively. + if (DagInit *ValueDag = dyn_cast(Value)) + Arg = parseDAG(CurRec, ValueDag); + Args.push_back(Arg); + } + + Init *Op = Dag->getOperator(); + std::string OpStr = Op->getAsUnquotedString(); + if (OpStr == "add") + return BinOpInit::get(BinOpInit::ADD, Args[0], Args[1], ReturnType) + ->Fold(CurRec); + if (OpStr == "sub") + return BinOpInit::get(BinOpInit::SUB, Args[0], Args[1], ReturnType) + ->Fold(CurRec); + if (OpStr == "and") + return BinOpInit::get(BinOpInit::AND, Args[0], Args[1], ReturnType) + ->Fold(CurRec); + if (OpStr == "or") + return BinOpInit::get(BinOpInit::OR, Args[0], Args[1], ReturnType) + ->Fold(CurRec); + if (OpStr == "xor") + return BinOpInit::get(BinOpInit::XOR, Args[0], Args[1], ReturnType) + ->Fold(CurRec); + if (OpStr == "srl") + return BinOpInit::get(BinOpInit::SRL, Args[0], Args[1], ReturnType) + ->Fold(CurRec); + if (OpStr == "sra") + return BinOpInit::get(BinOpInit::SRA, Args[0], Args[1], ReturnType) + ->Fold(CurRec); + if (OpStr == "shl") + return BinOpInit::get(BinOpInit::SHL, Args[0], Args[1], ReturnType) + ->Fold(CurRec); + + PrintError(CurRec->getLoc(), + Twine("Unsupported operator in !apply: ") + OpStr); + return UnsetInit::get(CurRec->getRecords()); +} + +Init *ApplyOpInit::Fold(Record *CurRec, bool IsFinal) const { + if (DagInit *Dag = dyn_cast(Func)) + return parseDAG(CurRec, Dag); + return const_cast(this); +} + +Init *ApplyOpInit::resolveReferences(Resolver &R) const { + Init *NewDag = Func->resolveReferences(R); + SmallVector NewParams; + for (auto *Param : Params) { + Init *NewParam = Param->resolveReferences(R); + NewParams.push_back(NewParam); + } + + ApplyOpInit *Ret = get(ReturnType, NewDag, NewParams); + if (isa(NewDag)) + return Ret->Fold(R.getCurrentRecord()); + + return const_cast(Ret); +} + +Init *ApplyOpInit::getBit(unsigned Bit) const { + return VarBitInit::get(const_cast(this), Bit); +} + +std::string ApplyOpInit::getAsString() const { + std::string Result = + "!apply<" + ReturnType->getAsString() + ">(" + Func->getAsString(); + for (auto *Param : Params) { + Result += ", "; + Result += Param->getAsString(); + } + return Result + ")"; +} + RecTy *TypedInit::getFieldType(StringInit *FieldName) const { if (RecordRecTy *RecordType = dyn_cast(getType())) { for (Record *Rec : RecordType->getClasses()) { diff --git a/llvm/lib/TableGen/TGLexer.h b/llvm/lib/TableGen/TGLexer.h --- a/llvm/lib/TableGen/TGLexer.h +++ b/llvm/lib/TableGen/TGLexer.h @@ -56,7 +56,7 @@ XSHL, XListConcat, XListSplat, XStrConcat, XInterleave, XSubstr, XFind, XCast, XSubst, XForEach, XFilter, XFoldl, XHead, XTail, XSize, XEmpty, XIf, XCond, XEq, XIsA, XDag, XNe, XLe, XLt, XGe, XGt, XSetDagOp, XGetDagOp, - XExists, XListRemove, XToLower, XToUpper, + XExists, XListRemove, XToLower, XToUpper, XApply, // Boolean literals. TrueVal, FalseVal, diff --git a/llvm/lib/TableGen/TGLexer.cpp b/llvm/lib/TableGen/TGLexer.cpp --- a/llvm/lib/TableGen/TGLexer.cpp +++ b/llvm/lib/TableGen/TGLexer.cpp @@ -594,6 +594,7 @@ .Case("exists", tgtok::XExists) .Case("tolower", tgtok::XToLower) .Case("toupper", tgtok::XToUpper) + .Case("apply", tgtok::XApply) .Default(tgtok::Error); return Kind != tgtok::Error ? Kind : ReturnError(Start-1, "Unknown operator"); diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp --- a/llvm/lib/TableGen/TGParser.cpp +++ b/llvm/lib/TableGen/TGParser.cpp @@ -12,6 +12,7 @@ #include "TGParser.h" #include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/Twine.h" @@ -20,6 +21,7 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/Record.h" #include #include #include @@ -955,6 +957,13 @@ CurRec->getNameInit() == Name) return UnOpInit::get(UnOpInit::CAST, Name, CurRec->getType()); + static std::string ApplySubroutines[] = {"add", "sub", "and", "or", + "xor", "srl", "sra", "shl"}; + if (find_if(ApplySubroutines, + [&](std::string &a) { return a == Name->getValue(); })) { + return Name; + } + Error(NameLoc.Start, "Variable not defined: '" + Name->getValue() + "'"); return nullptr; } @@ -1197,6 +1206,52 @@ return (ExistsOpInit::get(Type, Expr))->Fold(CurRec); } + case tgtok::XApply: { + // Value ::= !apply '<' Type '>' '(' dag [',' Value] ')' + Lex.Lex(); // eat the operation + + RecTy *Type = ParseOperatorType(); + if (!Type) + return nullptr; + + if (!consume(tgtok::l_paren)) { + TokError("expected '(' after type of !apply"); + return nullptr; + } + SMLoc DagLoc = Lex.getLoc(); + Init *Dag = ParseValue(CurRec); + if (!Dag) + return nullptr; + + TypedInit *TypedValue = dyn_cast(Dag); + if (!TypedValue) { + Error(DagLoc, "expected dag type argument in !apply operator"); + return nullptr; + } + + DagRecTy *DagType = dyn_cast(TypedValue->getType()); + if (!DagType) { + Error(DagLoc, "expected dag type argument in !apply operator"); + return nullptr; + } + + SmallVector Params; + if (consume(tgtok::comma)) { + for (;;) { + Params.push_back(ParseValue(CurRec)); + if (!consume(tgtok::comma)) + break; + } + } + + if (!consume(tgtok::r_paren)) { + TokError("expected ')' in operator"); + return nullptr; + } + + return (ApplyOpInit::get(Type, Dag, Params))->Fold(CurRec); + } + case tgtok::XConcat: case tgtok::XADD: case tgtok::XSUB: @@ -2520,6 +2575,7 @@ case tgtok::XFilter: case tgtok::XSubst: case tgtok::XSubstr: + case tgtok::XApply: case tgtok::XFind: { // Value ::= !ternop '(' Value ',' Value ',' Value ')' return ParseOperation(CurRec, ItemType); } diff --git a/llvm/test/TableGen/apply.td b/llvm/test/TableGen/apply.td new file mode 100644 --- /dev/null +++ b/llvm/test/TableGen/apply.td @@ -0,0 +1,46 @@ +// RUN: llvm-tblgen %s | FileCheck %s +// XFAIL: vg_leak + +class Dag{ + dag value = v; +} + +// CHECK: --- Defs --- +// CHECK: def A0 { +// CHECK: int add_result = 24; +// CHECK: int sub_result = 18; +// CHECK: int and_result = 1; +// CHECK: int or_result = 23; +// CHECK: int xor_result = 22; +// CHECK: int srl_result = 2; +// CHECK: int sra_result = 2; +// CHECK: int shl_result = 168; + +defvar add_dag = Dag<(add $a0, $a1)>.value; +defvar sub_dag = Dag<(sub $a0, $a1)>.value; +defvar xor_dag = !dag(xor, [?, ?], ["a0", "a1"]); +defvar srl_dag = !dag(srl, [?, ?], ["a0", "a1"]); + +class A { + int add_result = !apply(add_dag, a, b); + int sub_result = !apply(sub_dag, a, b); + int and_result = !apply(Dag<(and $a0, $a1)>.value, a, b); + int or_result = !apply(Dag<(or $a0, $a1)>.value, a, b); + int xor_result = !apply(xor_dag, a, b); + int srl_result = !apply(srl_dag, a, b); + int sra_result = !apply((sra $a0, $a1), a, b); + int shl_result = !apply((shl $a0, $a1), a, b); +} + +def A0 : A<21, 3>; + +// CHECK: def B0 { +// CHECK: int result = 254; + +class B { + int result = !apply(op.value, a, b); +} + +defvar c = 233; +// Capture variables in lexical scope. +def B0 : B, 21, 3>;