diff --git a/llvm/docs/TableGen/ProgRef.rst b/llvm/docs/TableGen/ProgRef.rst --- a/llvm/docs/TableGen/ProgRef.rst +++ b/llvm/docs/TableGen/ProgRef.rst @@ -153,7 +153,7 @@ The following are the basic punctuation tokens:: - - + [ ] { } ( ) < > : ; . ... = ? # + - + [ ] { } ( ) < > : ; . ... = ? # ' Literals -------- @@ -202,10 +202,11 @@ identifiers:: assert bit bits class code - dag def else false foreach - defm defset defvar field if - in include int let list - multiclass string then true + dag def defm defset defvar + else false field foreach function + if in include int let + list multiclass return string then + true .. warning:: The ``field`` reserved word is deprecated, except when used with the @@ -297,6 +298,11 @@ angle brackets. The element type is arbitrary; it can even be another list type. List elements are indexed from 0. +``function<``\ *rettype*\ ``[,`` *argtype*\ ``]*>`` + This type represents a function that returns a value of type *rettype* and + takes zero or more arguments of types *argtype*. The return type and the + argument types are arbitrary; they can even be other function types. + ``dag`` This type represents a nestable directed acyclic graph (DAG) of nodes. Each node has an *operator* and zero or more *arguments* (or *operands*). @@ -336,12 +342,14 @@ :| `Value` "#" [`Value`] ValueSuffix: "{" `RangeList` "}" :| "[" `RangeList` "]" + :| "'"? "(" `ValueList`? ")" :| "." `TokIdentifier` RangeList: `RangePiece` ("," `RangePiece`)* RangePiece: `TokInteger` :| `TokInteger` "..." `TokInteger` :| `TokInteger` "-" `TokInteger` :| `TokInteger` `TokInteger` + ValueList: `Value` ("," `Value`)* .. warning:: The peculiar last form of :token:`RangePiece` is due to the fact that the @@ -490,6 +498,11 @@ takes a list of pairs of arguments separated by colons. See `Appendix A: Bang Operators`_ for a description of each bang operator. +.. 
productionlist:: + SimpleValue10: "function" `Function` + +This form defines a new anonymous function, and its value is a reference +to that function. See `function --- define a function`_ for more details. Suffixed values --------------- @@ -516,6 +529,11 @@ Elements may be included multiple times and in any order. This is the result only when more than one element is specified. +*value*\ ``(a, ...)`` + The final value is the result of calling *value* with arguments *(a, ...)*. + *value* should be callable, which means its final value should be a reference + to a function. + *value*\ ``.``\ *field* The final value is the value of the specified *field* in the specified record *value*. @@ -554,8 +572,8 @@ .. productionlist:: TableGenFile: (`Statement` | `IncludeDirective` :| `PreprocessorDirective`)* - Statement: `Assert` | `Class` | `Def` | `Defm` | `Defset` | `Defvar` - :| `Foreach` | `If` | `Let` | `MultiClass` + Statement: `Assert` | `Class` | `FunctionDefinition` | `Def` | `Defm` + :| `Defset` | `Defvar` | `Foreach` | `If` | `Let` | `MultiClass` The following sections describe each of these top-level statements. @@ -568,8 +586,8 @@ .. productionlist:: Class: "class" `ClassID` [`TemplateArgList`] `RecordBody` - TemplateArgList: "<" `TemplateArgDecl` ("," `TemplateArgDecl`)* ">" - TemplateArgDecl: `Type` `TokIdentifier` ["=" `Value`] + TemplateArgList: "<" `Declaration` ("," `Declaration`)* ">" + Declaration: `Type` `TokIdentifier` ["=" `Value`] A class can be parameterized by a list of "template arguments," whose values can be used in the class's record body. These template arguments are @@ -663,6 +681,40 @@ expanded with the template arguments before being merged into ``C2``. +``function`` --- define a function +--------------------------------------------- + +A ``function`` statement defines a named function that can be called. + +.. 
productionlist:: + FunctionDefinition: "function" `FunctionID` `Function` + FunctionID: `TokIdentifier` + Function: `FunctionArgList` ":" `ReturnType` `FunctionBody` + FunctionArgList: "(" (`Declaration` ("," `Declaration`)*)? ")" + ReturnType: `Type` + FunctionBody: "{" `FunctionBodyItem`+ "}" + FunctionBodyItem: `Return` + :| `Defvar` + :| `Assert` + Return: "return" `Value` ";" + +A function can be parameterized by a list of "function arguments" whose values +can be used in the function's body. + +If a function argument is not assigned a default value with ``=``, it is +uninitialized (has the "value" ``?``) and must be specified in the function +argument list when the function is called (required argument). If an +argument is assigned a default value, then it need not be specified in the +argument list (optional argument). In the declaration, all required function +arguments must precede any optional arguments. The function argument default +values are evaluated from left to right. + +The name of a function is treated as a function reference to itself. + +A function is called with ``funcname(args)``. However, the grammar is ambiguous +between a DAG and a function call, so to distinguish them, calls inside a DAG +should put a single quote before the argument list, like ``funcname'(args)``. + .. _def: ``def`` --- define a concrete record @@ -1498,6 +1550,26 @@ bit ValidSize = isValidSize.ret; } +.. note:: + Although a class can achieve the same goal, we encourage users to write a + function instead. For example, the class-based example above can be rewritten as: + +.. 
code-block:: text + + function isValidSize(int size): bit { + return !cond(!eq(size, 1): 1, + !eq(size, 2): 1, + !eq(size, 4): 1, + !eq(size, 8): 1, + !eq(size, 16): 1, + true: 0); + } + + def Data1 { + int Size = ...; + bit ValidSize = isValidSize(Size); + } + Preprocessing Facilities ======================== diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h --- a/llvm/include/llvm/TableGen/Record.h +++ b/llvm/include/llvm/TableGen/Record.h @@ -65,7 +65,8 @@ StringRecTyKind, ListRecTyKind, DagRecTyKind, - RecordRecTyKind + RecordRecTyKind, + FunctionTyKind }; private: @@ -269,6 +270,62 @@ bool typeIsA(const RecTy *RHS) const override; }; +/// 'function' - Type of functions that have zero or more +/// arguments. +class FunctionRecTy final : public RecTy, + public FoldingSetNode, + public TrailingObjects { + friend class Record; + friend detail::RecordKeeperImpl; + + RecTy *ReturnTy; + unsigned NumArgs; + + explicit FunctionRecTy(RecordKeeper &RK, RecTy *ReturnTy, unsigned NumArgs) + : RecTy(FunctionTyKind, RK), ReturnTy(ReturnTy), NumArgs(NumArgs) {} + +public: + FunctionRecTy(const FunctionRecTy &) = delete; + FunctionRecTy &operator=(const FunctionRecTy &) = delete; + + // Do not use sized deallocation due to trailing objects. 
+ void operator delete(void *p) { ::operator delete(p); } + + static bool classof(const RecTy *RT) { + return RT->getRecTyKind() == FunctionTyKind; + } + + static FunctionRecTy *get(RecordKeeper &RK, RecTy *ReturnTy, + ArrayRef ArgTypes); + static FunctionRecTy *get(RecordKeeper &RK, RecTy *ReturnTy); + + void Profile(FoldingSetNodeID &ID) const; + + RecTy *getReturnType() const { return ReturnTy; } + + RecTy *getArg(unsigned i) const { + assert(i < NumArgs && "Argument index out of range!"); + return getTrailingObjects()[i]; + } + + ArrayRef args() const { + return ArrayRef(getTrailingObjects(), NumArgs); + } + + using const_ty_iterator = RecTy *const *; + const_ty_iterator args_begin() const { return args().begin(); } + const_ty_iterator args_end() const { return args().end(); } + + size_t args_size() const { return NumArgs; } + bool args_empty() const { return NumArgs == 0; } + + std::string getAsString() const override; + + bool typeIsConvertibleTo(const RecTy *RHS) const override; + + bool typeIsA(const RecTy *RHS) const override; +}; + /// Find a common type that T1 and T2 convert to. /// Return 0 if no such type exists. RecTy *resolveTypes(RecTy *T1, RecTy *T2); @@ -317,6 +374,8 @@ IK_VarListElementInit, IK_VarBitInit, IK_VarDefInit, + IK_FuncCallInit, + IK_FuncRefInit, IK_LastTypedInit, IK_UnsetInit }; @@ -1356,6 +1415,103 @@ } }; +using Capture = std::pair; +/// Represent a reference to function. 
+class FuncRefInit final : public TypedInit, + public FoldingSetNode, + public TrailingObjects { + Record *Func; + unsigned NumCaptures; + + explicit FuncRefInit(Record *Func, unsigned N); + +public: + FuncRefInit(const FuncRefInit &) = delete; + FuncRefInit &operator=(const FuncRefInit &) = delete; + + static bool classof(const Init *I) { return I->getKind() == IK_FuncRefInit; } + static FuncRefInit *get(Record *Func, ArrayRef Captures); + static FuncRefInit *get(Record *Func); + + void Profile(FoldingSetNodeID &ID) const; + + Record *getFunction() const { return Func; } + RecTy *getReturnType() const; + std::string getAsString() const override; + + bool isConcrete() const override { return true; } + + Init *resolveReferences(Resolver &R) const override; + + using const_iterator = Capture const *; + + const_iterator captures_begin() const { + return getTrailingObjects(); + } + const_iterator captures_end() const { return captures_begin() + NumCaptures; } + + size_t captures_size() const { return NumCaptures; } + bool captures_empty() const { return NumCaptures == 0; } + + ArrayRef captures() const { + return ArrayRef(captures_begin(), NumCaptures); + } + + Init *getBit(unsigned Bit) const override { + llvm_unreachable("Illegal bit reference of function"); + } +}; + +/// funcname(fargs...) - Represent function call to `funcname`. 
+class FuncCallInit final : public TypedInit, + public FoldingSetNode, + public TrailingObjects { + Init *FuncRef; + FunctionRecTy *FuncTy; + unsigned NumArgs; + SMLoc CallLoc; + Init *Ret = nullptr; // after evaluation + + explicit FuncCallInit(Init *FuncRef, FunctionRecTy *FuncTy, unsigned N, + SMLoc CallLoc); + + Init *evaluate(); + +public: + FuncCallInit(const FuncCallInit &) = delete; + FuncCallInit &operator=(const FuncCallInit &) = delete; + + static bool classof(const Init *I) { return I->getKind() == IK_FuncCallInit; } + static FuncCallInit *get(Init *FuncRef, FunctionRecTy *FuncTy, + ArrayRef Args, SMLoc CallLoc); + + void Profile(FoldingSetNodeID &ID) const; + + Init *resolveReferences(Resolver &R) const override; + Init *Fold() const; + + std::string getAsString() const override; + + Init *getArg(unsigned i) const { + assert(i < NumArgs && "Argument index out of range!"); + return getTrailingObjects()[i]; + } + + using const_iterator = Init *const *; + + const_iterator args_begin() const { return getTrailingObjects(); } + const_iterator args_end() const { return args_begin() + NumArgs; } + + size_t args_size() const { return NumArgs; } + bool args_empty() const { return NumArgs == 0; } + + ArrayRef args() const { return ArrayRef(args_begin(), NumArgs); } + + Init *getBit(unsigned Bit) const override { + llvm_unreachable("Illegal bit reference of function call"); + } +}; + /// X.Y - Represent a reference to a subfield of a variable class FieldInit : public TypedInit { Init *Rec; // Record we are referring to @@ -1503,12 +1659,14 @@ FK_Normal, // A normal record field. FK_NonconcreteOK, // A field that can be nonconcrete ('field' keyword). FK_TemplateArg, // A template argument. + FK_FunctionArg, // A function argument. + FK_ReturnValue, // A function return value. }; private: Init *Name; SMLoc Loc; // Source location of definition of name. 
- PointerIntPair TyAndKind; + PointerIntPair TyAndKind; /* NOTE(review): the field-kind enum now lists five kinds (FK_Normal..FK_ReturnValue), which needs 3 bits; confirm the PointerIntPair int width was raised accordingly and that RecTy pointers have 3 free low bits (alignof(RecTy) >= 8), which may not hold on 32-bit hosts. */ Init *Value; bool IsUsed = false; @@ -1546,6 +1704,12 @@ return TyAndKind.getInt() == FK_TemplateArg; } + /// Is this a function argument? + bool isFunctionArg() const { return TyAndKind.getInt() == FK_FunctionArg; } + + /// Is this the return value of a function? + bool isReturnValue() const { return TyAndKind.getInt() == FK_ReturnValue; } + /// Get the type of the field value as a RecTy. RecTy *getType() const { return TyAndKind.getPointer(); } @@ -1568,7 +1732,7 @@ ArrayRef getReferenceLocs() const { return ReferenceLocs; } /// Whether this value is used. Useful for reporting warnings, for example - /// when a template argument is unused. + /// when a template/function argument is unused. void setUsed(bool Used) { IsUsed = Used; } bool isUsed() const { return IsUsed; } @@ -1596,6 +1760,10 @@ : Loc(Loc), Condition(Condition), Message(Message) {} }; + enum RecordType { Other, Class, Function }; + + static std::string ReturnValueName; + private: Init *Name; // Location where record was instantiated, followed by the location of @@ -1608,6 +1776,10 @@ SmallVector Values; SmallVector Assertions; + SmallVector FunctionArgs; + FunctionRecTy *FunctionType = nullptr; + Record *Parent = nullptr; /* Enclosing record; getValue() falls back to its template/function arguments. */ + // All superclasses in the inheritance forest in post-order (yes, it // must be a forest; diamond-shaped inheritance is not allowed). SmallVector, 0> SuperClasses; @@ -1622,23 +1794,23 @@ unsigned ID; bool IsAnonymous; - bool IsClass; + RecordType Type; void checkName(); public: // Constructs a record.
explicit Record(Init *N, ArrayRef locs, RecordKeeper &records, - bool Anonymous = false, bool Class = false) + bool Anonymous = false, RecordType Type = Other) : Name(N), Locs(locs.begin(), locs.end()), TrackedRecords(records), ID(getNewUID(N->getRecordKeeper())), IsAnonymous(Anonymous), - IsClass(Class) { + Type(Type) { checkName(); } explicit Record(StringRef N, ArrayRef locs, RecordKeeper &records, - bool Class = false) - : Record(StringInit::get(records, N), locs, records, false, Class) {} + RecordType Type = Other) + : Record(StringInit::get(records, N), locs, records, false, Type) {} // When copy-constructing a Record, we must still guarantee a globally unique // ID number. Don't copy CorrespondingDefInit either, since it's owned by the @@ -1646,9 +1818,10 @@ Record(const Record &O) : Name(O.Name), Locs(O.Locs), TemplateArgs(O.TemplateArgs), Values(O.Values), Assertions(O.Assertions), - SuperClasses(O.SuperClasses), TrackedRecords(O.TrackedRecords), - ID(getNewUID(O.getRecords())), IsAnonymous(O.IsAnonymous), - IsClass(O.IsClass) {} + FunctionArgs(O.FunctionArgs), FunctionType(O.FunctionType), + Parent(O.Parent), SuperClasses(O.SuperClasses), + TrackedRecords(O.TrackedRecords), ID(getNewUID(O.getRecords())), + IsAnonymous(O.IsAnonymous), Type(O.Type) {} static unsigned getNewUID(RecordKeeper &RK); @@ -1688,12 +1861,29 @@ /// get the corresponding DefInit. 
DefInit *getDefInit(); - bool isClass() const { return IsClass; } + bool isClass() const { return Type == Class; } + + bool isFunction() const { return Type == Function; } ArrayRef getTemplateArgs() const { return TemplateArgs; } + ArrayRef getFunctionArgs() const { return FunctionArgs; } + + FunctionRecTy *getFunctionType() const { return FunctionType; } + + RecTy *getReturnType() const { return FunctionType->getReturnType(); } + + const RecordVal *getReturnValue() const { + for (const RecordVal &Val : Values) + if (Val.isReturnValue()) + return &Val; + return nullptr; + } + + Record *getParent() { return Parent; } + ArrayRef getValues() const { return Values; } ArrayRef getAssertions() const { return Assertions; } @@ -1712,9 +1902,17 @@ return llvm::is_contained(TemplateArgs, Name); } + bool isFunctionArg(Init *Name) const { + return llvm::is_contained(FunctionArgs, Name); + } + const RecordVal *getValue(const Init *Name) const { for (const RecordVal &Val : Values) - if (Val.Name == Name) return &Val; + if (Val.Name == Name) + return &Val; + if (Parent) + if (auto *V = Parent->getValue(Name)) + return V->isTemplateArg() || V->isFunctionArg() ? 
V : nullptr; return nullptr; } @@ -1735,6 +1933,24 @@ TemplateArgs.push_back(Name); } + void addFunctionArg(Init *Name) { + assert(!isFunctionArg(Name) && "Function arg already defined!"); + FunctionArgs.push_back(Name); + } + + void setFunctionType(FunctionRecTy *FunctionType) { + assert(isFunction() && "Only for function type!"); + this->FunctionType = FunctionType; + } + + void setParent(Record *Parent) { this->Parent = Parent; } + + bool hasReturnValue() const { return getReturnValue(); } + + StringInit *getReturnValueName() const { + return StringInit::get(getRecords(), ReturnValueName); + } + void addValue(const RecordVal &RV) { assert(getValue(RV.getNameInit()) == nullptr && "Value already added!"); Values.push_back(RV); @@ -1763,6 +1979,7 @@ void checkRecordAssertions(); void checkUnusedTemplateArgs(); + void checkUnusedFunctionArgs(); bool isSubClassOf(const Record *R) const { for (const auto &SCPair : SuperClasses) @@ -1915,6 +2132,9 @@ /// Get the map of classes. const RecordMap &getClasses() const { return Classes; } + /// Get the map of functions. + const RecordMap &getFunctions() const { return Functions; } + /// Get the map of records (defs). const RecordMap &getDefs() const { return Defs; } @@ -1927,6 +2147,12 @@ return I == Classes.end() ? nullptr : I->second.get(); } + /// Get the function with the specified name. + Record *getFunction(StringRef Name) const { + auto I = Functions.find(Name); + return I == Functions.end() ? nullptr : I->second.get(); + } + /// Get the concrete record with the specified name. 
Record *getDef(StringRef Name) const { auto I = Defs.find(Name); @@ -1952,6 +2178,15 @@ assert(Ins && "Class already exists"); } + void addFunction(std::unique_ptr R) { + bool Ins = + Functions + .insert(std::make_pair(std::string(R->getName()), std::move(R))) + .second; + (void)Ins; + assert(Ins && "Function already exists"); + } + void addDef(std::unique_ptr R) { bool Ins = Defs.insert(std::make_pair(std::string(R->getName()), std::move(R))).second; @@ -1966,6 +2201,10 @@ assert(Ins && "Global already exists"); } + void removeExtraGlobal(StringRef Name) { + ExtraGlobals.erase(std::string(Name)); + } + Init *getNewAnonymousName(); /// Start phase timing; called if the --time-phases option is specified. @@ -2018,7 +2257,7 @@ RecordKeeper &operator=(const RecordKeeper &) = delete; std::string InputFilename; - RecordMap Classes, Defs; + RecordMap Classes, Functions, Defs; mutable StringMap> ClassRecordsMap; GlobalMap ExtraGlobals; diff --git a/llvm/lib/TableGen/Main.cpp b/llvm/lib/TableGen/Main.cpp --- a/llvm/lib/TableGen/Main.cpp +++ b/llvm/lib/TableGen/Main.cpp @@ -68,6 +68,10 @@ "no-warn-on-unused-template-args", cl::desc("Disable unused template argument warnings.")); +static cl::opt NoWarnOnUnusedFunctionArgs( + "no-warn-on-unused-function-args", + cl::desc("Disable unused function argument warnings.")); + static int reportError(const char *ProgName, Twine Msg) { errs() << ProgName << ": " << Msg; errs().flush(); @@ -121,7 +125,8 @@ // it later. 
SrcMgr.setIncludeDirs(IncludeDirs); - TGParser Parser(SrcMgr, MacroNames, Records, NoWarnOnUnusedTemplateArgs); + TGParser Parser(SrcMgr, MacroNames, Records, NoWarnOnUnusedTemplateArgs, + NoWarnOnUnusedFunctionArgs); if (Parser.ParseFile()) return 1; diff --git a/llvm/lib/TableGen/Parser.cpp b/llvm/lib/TableGen/Parser.cpp --- a/llvm/lib/TableGen/Parser.cpp +++ b/llvm/lib/TableGen/Parser.cpp @@ -30,6 +30,7 @@ TGParser Parser(SrcMgr, /*Macros=*/std::nullopt, Records, /*NoWarnOnUnusedTemplateArgs=*/false, + /*NoWarnOnUnusedFunctionArgs=*/false, /*TrackReferenceLocs=*/true); bool ParseResult = Parser.ParseFile(); diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp --- a/llvm/lib/TableGen/Record.cpp +++ b/llvm/lib/TableGen/Record.cpp @@ -86,10 +86,13 @@ DenseMap, VarListElementInit *> TheVarListElementInitPool; FoldingSet TheVarDefInitPool; + FoldingSet TheFuncCallInitPool; + FoldingSet TheFuncRefInitPool; DenseMap, FieldInit *> TheFieldInitPool; FoldingSet TheCondOpInitPool; FoldingSet TheDagInitPool; FoldingSet RecordTypePool; + FoldingSet FunctionTypePool; unsigned AnonCounter; unsigned LastRecordID; @@ -309,6 +312,77 @@ return RecordRecTy::get(T1->getRecordKeeper(), CommonSuperClasses); } +static void ProfileFunctionRecTy(FoldingSetNodeID &ID, RecTy *ReturnTy, + ArrayRef ArgTypes) { + ID.AddPointer(ReturnTy); + ID.AddInteger(ArgTypes.size()); + for (RecTy *ArgType : ArgTypes) + ID.AddPointer(ArgType); +} + +FunctionRecTy *FunctionRecTy::get(RecordKeeper &RK, RecTy *ReturnTy, + ArrayRef ArgTypes) { + detail::RecordKeeperImpl &RKImpl = RK.getImpl(); + FoldingSet &ThePool = RKImpl.FunctionTypePool; + + FoldingSetNodeID ID; + ProfileFunctionRecTy(ID, ReturnTy, ArgTypes); + + void *IP = nullptr; + if (FunctionRecTy *Ty = ThePool.FindNodeOrInsertPos(ID, IP)) + return Ty; + + void *Mem = RKImpl.Allocator.Allocate( + totalSizeToAlloc(ArgTypes.size()), alignof(FunctionRecTy)); + FunctionRecTy *Ty = new (Mem) FunctionRecTy(RK, ReturnTy, ArgTypes.size()); 
+ std::uninitialized_copy(ArgTypes.begin(), ArgTypes.end(), + Ty->getTrailingObjects()); + ThePool.InsertNode(Ty, IP); + return Ty; +} + +FunctionRecTy *FunctionRecTy::get(RecordKeeper &RK, RecTy *ReturnTy) { + return get(RK, ReturnTy, {}); +} + +void FunctionRecTy::Profile(FoldingSetNodeID &ID) const { + ProfileFunctionRecTy(ID, ReturnTy, args()); +} + +std::string FunctionRecTy::getAsString() const { + std::string Str = "function<"; + Str += ReturnTy->getAsString(); + for (RecTy *Ty : args()) { + Str += ", "; + Str += Ty->getAsString(); + } + Str += ">"; + return Str; +} + +bool FunctionRecTy::typeIsConvertibleTo(const RecTy *RHS) const { + if (this == RHS) + return true; + + const FunctionRecTy *RTy = dyn_cast(RHS); + if (!RTy) + return false; + if (NumArgs != RTy->NumArgs) + return false; + if (!ReturnTy->typeIsConvertibleTo(RTy->ReturnTy)) + return false; + + for (unsigned I = 0; I < NumArgs; I++) + if (!getArg(I)->typeIsConvertibleTo(RTy->getArg(I))) + return false; + + return true; +} + +bool FunctionRecTy::typeIsA(const RecTy *RHS) const { + return typeIsConvertibleTo(RHS); +} + RecTy *llvm::resolveTypes(RecTy *T1, RecTy *T2) { if (T1 == T2) return T1; @@ -2140,6 +2214,231 @@ return Result + ">"; } +static void ProfileFuncRefInit(FoldingSetNodeID &ID, Record *Func, + ArrayRef Captures) { + ID.AddPointer(Func); + for (auto &Capture : Captures) { + ID.AddPointer(Capture.first); + ID.AddPointer(Capture.second); + } +} + +FuncRefInit::FuncRefInit(Record *Func, unsigned N) + : TypedInit(IK_FuncRefInit, Func->getFunctionType()), Func(Func), + NumCaptures(N) {} + +FuncRefInit *FuncRefInit::get(Record *Func, ArrayRef Captures) { + assert(Func->isFunction() && "Only for function!"); + FoldingSetNodeID ID; + ProfileFuncRefInit(ID, Func, Captures); + + detail::RecordKeeperImpl &RK = Func->getRecords().getImpl(); + void *IP = nullptr; + if (FuncRefInit *I = RK.TheFuncRefInitPool.FindNodeOrInsertPos(ID, IP)) + return I; + + void *Mem = 
+ RK.Allocator.Allocate(totalSizeToAlloc(Captures.size()), + alignof(FuncRefInit)); + FuncRefInit *I = new (Mem) FuncRefInit(Func, Captures.size()); + std::uninitialized_copy(Captures.begin(), Captures.end(), + I->getTrailingObjects()); + RK.TheFuncRefInitPool.InsertNode(I, IP); + return I; +} + +FuncRefInit *FuncRefInit::get(Record *Func) { return get(Func, {}); } + +void FuncRefInit::Profile(FoldingSetNodeID &ID) const { + ProfileFuncRefInit(ID, Func, captures()); +} + +RecTy *FuncRefInit::getReturnType() const { return Func->getReturnType(); } + +std::string FuncRefInit::getAsString() const { + std::string Result = Func->getName().str(); + if (captures_empty()) + return Result; + + Result += "<"; + const char *Sep = ""; + for (auto &Capture : captures()) { + Result += Sep; + Sep = ", "; + Result += Capture.first->getAsUnquotedString(); + Result += " = "; + Result += Capture.second->getAsUnquotedString(); + } + return Result + ">"; +} + +Init *FuncRefInit::resolveReferences(Resolver &R) const { + bool Changed = false; + SmallVector NewCaptures; + for (auto &Capture : captures()) { /* Rebind captures that R resolves; keep the others. */ + Init *VarName = Capture.first; + Init *VarValue = Capture.second; /* current captured value, kept unless R rebinds VarName */ + Init *NewVarValue = R.resolve(VarName); /* nullptr when R has no binding for VarName */ + bool Resolved = NewVarValue && NewVarValue != VarValue; + Changed |= Resolved; + NewCaptures.push_back( + std::make_pair(VarName, Resolved ? NewVarValue : VarValue)); + } + return Changed ?
FuncRefInit::get(Func, NewCaptures) + : const_cast(this); +} + +static void ProfileFuncCallInit(FoldingSetNodeID &ID, Init *FuncRef, + FunctionRecTy *FuncTy, ArrayRef Args, + SMLoc CallLoc) { + ID.AddInteger(Args.size()); + ID.AddPointer(FuncRef); + ID.AddPointer(FuncTy); + ID.AddPointer(CallLoc.getPointer()); + + for (Init *I : Args) + ID.AddPointer(I); +} + +FuncCallInit::FuncCallInit(Init *FuncRef, FunctionRecTy *FuncTy, unsigned N, + SMLoc CallLoc) + : TypedInit(IK_FuncCallInit, FuncTy->getReturnType()), FuncRef(FuncRef), + FuncTy(FuncTy), NumArgs(N), CallLoc(CallLoc) {} + +FuncCallInit *FuncCallInit::get(Init *FuncRef, FunctionRecTy *FuncTy, + ArrayRef Args, SMLoc CallLoc) { + FoldingSetNodeID ID; + ProfileFuncCallInit(ID, FuncRef, FuncTy, Args, CallLoc); + + detail::RecordKeeperImpl &RK = FuncRef->getRecordKeeper().getImpl(); + void *IP = nullptr; + if (FuncCallInit *I = RK.TheFuncCallInitPool.FindNodeOrInsertPos(ID, IP)) + return I; + + void *Mem = RK.Allocator.Allocate(totalSizeToAlloc(Args.size()), + alignof(FuncCallInit)); + FuncCallInit *I = + new (Mem) FuncCallInit(FuncRef, FuncTy, Args.size(), CallLoc); + std::uninitialized_copy(Args.begin(), Args.end(), + I->getTrailingObjects()); + RK.TheFuncCallInitPool.InsertNode(I, IP); + return I; +} + +void FuncCallInit::Profile(FoldingSetNodeID &ID) const { + ProfileFuncCallInit(ID, FuncRef, FuncTy, args(), CallLoc); +} + +Init *FuncCallInit::evaluate() { + if (!Ret) { + assert(isa(FuncRef) && "Should be a FuncRefInit"); + auto *Ref = cast(FuncRef); + Record *Func = Ref->getFunction(); + RecordKeeper &Records = Func->getRecords(); + auto NewRecOwner = std::make_unique(Records.getNewAnonymousName(), + Func->getLoc(), Records, + /*IsAnonymous=*/true); + Record *NewRec = NewRecOwner.get(); + + // Copy assertions from function to instance. + NewRec->appendAssertions(Func); + + // Copy values from function to evaluate. 
+ for (const RecordVal &Val : Func->getValues()) + NewRec->addValue(Val); + + // Loop through the arguments that were not specified and make sure + // they have a complete value. + ArrayRef ActualArgs = args(); + ArrayRef FormalArgs = Func->getFunctionArgs(); + for (unsigned I = ActualArgs.size(), E = FormalArgs.size(); I < E; ++I) { + RecordVal *Arg = Func->getValue(FormalArgs[I]); + if (!Arg->getValue()->isComplete()) { + PrintError(Arg->getLoc(), + "Value not specified for function argument '" + + FormalArgs[I]->getAsUnquotedString() + "' (#" + + Twine(I) + ") of function '" + + Func->getNameInitAsString() + "'"); + PrintFatalNote(CallLoc, "called from here"); + return this; + } + } + + // Substitute and resolve function arguments. + MapResolver R(NewRec); + for (unsigned I = 0, E = FormalArgs.size(); I != E; ++I) { + if (I < args_size()) + R.set(FormalArgs[I], getArg(I)); + else + R.set(FormalArgs[I], NewRec->getValue(FormalArgs[I])->getValue()); + + NewRec->removeValue(FormalArgs[I]); + } + + // Add all the captures. + auto Captures = Ref->captures(); + for (auto &Capture : Captures) + R.set(Capture.first, Capture.second); + + NewRec->resolveReferences(R); + + // Check the assertions. + NewRec->checkRecordAssertions(); + + // Get the return value. 
+ auto *Def = DefInit::get(NewRec); /* NOTE(review): NewRec is owned by NewRecOwner, which is destroyed when this if-block ends; if the Fold() below returns an expression that still refers to Def (e.g. a non-concrete return value), Ret will point into a freed record. Consider handing ownership to the RecordKeeper - confirm. */ + Ret = FieldInit::get(Def, NewRec->getReturnValueName())->Fold(nullptr); + } + + return Ret; +} + +Init *FuncCallInit::resolveReferences(Resolver &R) const { + TrackUnresolvedResolver UR(&R); + bool Changed = false; + SmallVector NewArgs; + NewArgs.reserve(args_size()); + + for (Init *Arg : args()) { + Init *NewArg = Arg->resolveReferences(UR); + NewArgs.push_back(NewArg); + Changed |= NewArg != Arg; + } + + auto *NewFuncRef = FuncRef->resolveReferences(R); + Changed |= NewFuncRef != FuncRef; + if (Changed) { + auto *New = FuncCallInit::get(NewFuncRef, FuncTy, NewArgs, CallLoc); + if (!UR.foundUnresolved() && isa(NewFuncRef)) + return New->evaluate(); + return New; + } + return const_cast(this); +} + +Init *FuncCallInit::Fold() const { + if (Ret) + return Ret; + + TrackUnresolvedResolver R; + for (Init *Arg : args()) + Arg->resolveReferences(R); + + if (!R.foundUnresolved() && isa(FuncRef)) + return const_cast(this)->evaluate(); + return const_cast(this); +} + +std::string FuncCallInit::getAsString() const { + std::string Result = FuncRef->getAsString() + "("; + const char *Sep = ""; + for (Init *Arg : args()) { + Result += Sep; + Sep = ", "; + Result += Arg->getAsString(); + } + return Result + ")"; +} + FieldInit *FieldInit::get(Init *R, StringInit *FN) { detail::RecordKeeperImpl &RK = R->getRecordKeeper().getImpl(); FieldInit *&I = RK.TheFieldInitPool[std::make_pair(R, FN)]; @@ -2522,6 +2821,8 @@ if (PrintSem) OS << ";\n"; } +std::string Record::ReturnValueName = "__return__"; + void Record::updateClassLoc(SMLoc Loc) { assert(Locs.size() == 1); ForwardDeclarationLocs.push_back(Locs.front()); @@ -2653,19 +2954,26 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, const Record &R) { OS << R.getNameInitAsString(); - ArrayRef TArgs = R.getTemplateArgs(); - if (!TArgs.empty()) { - OS << "<"; + bool IsFunction = R.isFunction(); + ArrayRef Args = + IsFunction ?
R.getFunctionArgs() : R.getTemplateArgs(); + if (!Args.empty()) { + OS << (IsFunction ? "(" : "<"); bool NeedComma = false; - for (const Init *TA : TArgs) { - if (NeedComma) OS << ", "; + for (const Init *A : Args) { + if (NeedComma) + OS << ", "; NeedComma = true; - const RecordVal *RV = R.getValue(TA); - assert(RV && "Template argument record not found??"); + const RecordVal *RV = R.getValue(A); + assert(RV && "Template or function argument record not found??"); RV->print(OS, false); } - OS << ">"; - } + OS << (IsFunction ? ")" : ">"); + } else if (IsFunction) + OS << "()"; + + if (IsFunction) + OS << ": " << *R.getReturnType(); OS << " {"; ArrayRef> SC = R.getSuperClasses(); @@ -2676,12 +2984,16 @@ } OS << "\n"; - for (const RecordVal &Val : R.getValues()) - if (Val.isNonconcreteOK() && !R.isTemplateArg(Val.getNameInit())) - OS << Val; - for (const RecordVal &Val : R.getValues()) - if (!Val.isNonconcreteOK() && !R.isTemplateArg(Val.getNameInit())) - OS << Val; + if (IsFunction) + OS << " return " << *R.getReturnValue()->getValue() << ";\n"; + else { + for (const RecordVal &Val : R.getValues()) + if (Val.isNonconcreteOK() && !R.isTemplateArg(Val.getNameInit())) + OS << Val; + for (const RecordVal &Val : R.getValues()) + if (!Val.isNonconcreteOK() && !R.isTemplateArg(Val.getNameInit())) + OS << Val; + } return OS << "}\n"; } @@ -2903,6 +3215,16 @@ } } +// Report a warning if the record has unused function arguments. 
+void Record::checkUnusedFunctionArgs() { + for (const Init *FA : getFunctionArgs()) { + const RecordVal *Arg = getValue(FA); + if (!Arg->isUsed()) + PrintWarning(Arg->getLoc(), + "unused function argument: " + Twine(Arg->getName())); + } +} + RecordKeeper::RecordKeeper() : Impl(std::make_unique(*this)) {} RecordKeeper::~RecordKeeper() = default; @@ -2912,6 +3234,10 @@ #endif raw_ostream &llvm::operator<<(raw_ostream &OS, const RecordKeeper &RK) { + OS << "------------- Functions -----------------\n"; + for (const auto &C : RK.getFunctions()) + OS << "function " << *C.second; + OS << "------------- Classes -----------------\n"; for (const auto &C : RK.getClasses()) OS << "class " << *C.second; diff --git a/llvm/lib/TableGen/TGLexer.h b/llvm/lib/TableGen/TGLexer.h --- a/llvm/lib/TableGen/TGLexer.h +++ b/llvm/lib/TableGen/TGLexer.h @@ -28,38 +28,40 @@ class SourceMgr; class Twine; +// clang-format off namespace tgtok { enum TokKind { // Markers - Eof, Error, + Eof, Error, // Tokens with no info. - minus, plus, // - + - l_square, r_square, // [ ] - l_brace, r_brace, // { } - l_paren, r_paren, // ( ) - less, greater, // < > - colon, semi, // : ; - comma, dot, // , . - equal, question, // = ? - paste, // # - dotdotdot, // ... + minus, plus, // - + + l_square, r_square, // [ ] + l_brace, r_brace, // { } + l_paren, r_paren, // ( ) + less, greater, // < > + colon, semi, // : ; + comma, dot, // , . + equal, question, // = ? + paste, // # + dotdotdot, // ... + single_quote, // ' // Reserved keywords. ('ElseKW' is named to distinguish it from the // existing 'Else' that means the preprocessor #else.) - Assert, Bit, Bits, Class, Code, Dag, Def, Defm, Defset, Defvar, ElseKW, - FalseKW, Field, Foreach, If, In, Include, Int, Let, List, MultiClass, - String, Then, TrueKW, + Assert, Bit, Bits, Class, Code, Dag, Def, Defm, Defset, Defvar, ElseKW, + FalseKW, Field, Foreach, If, In, Include, Int, Let, List, MultiClass, + String, Then, TrueKW, Function, Return, // Bang operators. 
- XConcat, XADD, XSUB, XMUL, XDIV, XNOT, XLOG2, XAND, XOR, XXOR, XSRA, XSRL, - XSHL, XListConcat, XListSplat, XStrConcat, XInterleave, XSubstr, XFind, - XCast, XSubst, XForEach, XFilter, XFoldl, XHead, XTail, XSize, XEmpty, XIf, - XCond, XEq, XIsA, XDag, XNe, XLe, XLt, XGe, XGt, XSetDagOp, XGetDagOp, - XExists, XListRemove, XToLower, XToUpper, + XConcat, XADD, XSUB, XMUL, XDIV, XNOT, XLOG2, XAND, XOR, XXOR, XSRA, XSRL, + XSHL, XListConcat, XListSplat, XStrConcat, XInterleave, XSubstr, XFind, + XCast, XSubst, XForEach, XFilter, XFoldl, XHead, XTail, XSize, XEmpty, XIf, + XCond, XEq, XIsA, XDag, XNe, XLe, XLt, XGe, XGt, XSetDagOp, XGetDagOp, + XExists, XListRemove, XToLower, XToUpper, // Boolean literals. - TrueVal, FalseVal, + TrueVal, FalseVal, // Integer value. IntVal, @@ -69,13 +71,14 @@ BinaryIntVal, // String valued tokens. - Id, StrVal, VarName, CodeFragment, + Id, StrVal, VarName, CodeFragment, // Preprocessing tokens for internal usage by the lexer. // They are never returned as a result of Lex(). - Ifdef, Ifndef, Else, Endif, Define + Ifdef, Ifndef, Else, Endif, Define }; } +// clang-format on /// TGLexer - TableGen Lexer class. class TGLexer { diff --git a/llvm/lib/TableGen/TGLexer.cpp b/llvm/lib/TableGen/TGLexer.cpp --- a/llvm/lib/TableGen/TGLexer.cpp +++ b/llvm/lib/TableGen/TGLexer.cpp @@ -164,6 +164,8 @@ // Return EOF denoting the end of lexing. 
return tgtok::Eof; + // clang-format off + case '\'': return tgtok::single_quote; case ':': return tgtok::colon; case ';': return tgtok::semi; case ',': return tgtok::comma; @@ -176,6 +178,7 @@ case ')': return tgtok::r_paren; case '=': return tgtok::equal; case '?': return tgtok::question; + // clang-format on case '#': if (FileOrLineStart) { tgtok::TokKind Kind = prepIsDirective(); @@ -346,31 +349,33 @@ StringRef Str(IdentStart, CurPtr-IdentStart); tgtok::TokKind Kind = StringSwitch(Str) - .Case("int", tgtok::Int) - .Case("bit", tgtok::Bit) - .Case("bits", tgtok::Bits) - .Case("string", tgtok::String) - .Case("list", tgtok::List) - .Case("code", tgtok::Code) - .Case("dag", tgtok::Dag) - .Case("class", tgtok::Class) - .Case("def", tgtok::Def) - .Case("true", tgtok::TrueVal) - .Case("false", tgtok::FalseVal) - .Case("foreach", tgtok::Foreach) - .Case("defm", tgtok::Defm) - .Case("defset", tgtok::Defset) - .Case("multiclass", tgtok::MultiClass) - .Case("field", tgtok::Field) - .Case("let", tgtok::Let) - .Case("in", tgtok::In) - .Case("defvar", tgtok::Defvar) - .Case("include", tgtok::Include) - .Case("if", tgtok::If) - .Case("then", tgtok::Then) - .Case("else", tgtok::ElseKW) - .Case("assert", tgtok::Assert) - .Default(tgtok::Id); + .Case("int", tgtok::Int) + .Case("bit", tgtok::Bit) + .Case("bits", tgtok::Bits) + .Case("string", tgtok::String) + .Case("list", tgtok::List) + .Case("function", tgtok::Function) + .Case("code", tgtok::Code) + .Case("dag", tgtok::Dag) + .Case("class", tgtok::Class) + .Case("def", tgtok::Def) + .Case("true", tgtok::TrueVal) + .Case("false", tgtok::FalseVal) + .Case("foreach", tgtok::Foreach) + .Case("defm", tgtok::Defm) + .Case("defset", tgtok::Defset) + .Case("multiclass", tgtok::MultiClass) + .Case("field", tgtok::Field) + .Case("let", tgtok::Let) + .Case("in", tgtok::In) + .Case("defvar", tgtok::Defvar) + .Case("include", tgtok::Include) + .Case("if", tgtok::If) + .Case("then", tgtok::Then) + .Case("else", tgtok::ElseKW) + 
.Case("assert", tgtok::Assert) + .Case("return", tgtok::Return) + .Default(tgtok::Id); // A couple of tokens require special processing. switch (Kind) { diff --git a/llvm/lib/TableGen/TGParser.h b/llvm/lib/TableGen/TGParser.h --- a/llvm/lib/TableGen/TGParser.h +++ b/llvm/lib/TableGen/TGParser.h @@ -161,14 +161,17 @@ }; bool NoWarnOnUnusedTemplateArgs = false; + bool NoWarnOnUnusedFunctionArgs = false; bool TrackReferenceLocs = false; public: TGParser(SourceMgr &SM, ArrayRef Macros, RecordKeeper &records, const bool NoWarnOnUnusedTemplateArgs = false, - const bool TrackReferenceLocs = false) + const bool TrackReferenceLocs = false, + const bool NoWarnOnUnusedFunctionArgs = false) : Lex(SM, Macros), CurMultiClass(nullptr), Records(records), NoWarnOnUnusedTemplateArgs(NoWarnOnUnusedTemplateArgs), + NoWarnOnUnusedFunctionArgs(NoWarnOnUnusedFunctionArgs), TrackReferenceLocs(TrackReferenceLocs) {} /// ParseFile - Main entrypoint for parsing a tblgen file. These parser @@ -243,8 +246,16 @@ bool ParseBody(Record *CurRec); bool ParseBodyItem(Record *CurRec); + bool ParseFunctionDefinition(); + bool ParseFunction(Record *CurRec); + bool ParseFunctionArgList(Record *CurRec, SmallVectorImpl &ArgTypes); + bool ParseFunctionBody(Record *CurRec); + bool ParseFunctionBodyItem(Record *CurRec); + bool ParseReturn(Record *CurRec); + Init *ParseCall(Record *CurRec, Init *FuncRef, FunctionRecTy *FuncTy); + bool ParseTemplateArgList(Record *CurRec); - Init *ParseDeclaration(Record *CurRec, bool ParsingTemplateArgs); + Init *ParseDeclaration(Record *CurRec, bool ParsingTemplateOrFunctionArgs); VarInit *ParseForeachDeclaration(Init *&ForeachListValue); SubClassReference ParseSubClassReference(Record *CurRec, bool isDefm); @@ -255,11 +266,14 @@ Init *ParseSimpleValue(Record *CurRec, RecTy *ItemType = nullptr, IDParseMode Mode = ParseValueMode); Init *ParseValue(Record *CurRec, RecTy *ItemType = nullptr, - IDParseMode Mode = ParseValueMode); + IDParseMode Mode = ParseValueMode, bool ParsingDag = false); void
ParseValueList(SmallVectorImpl &Result, Record *CurRec, RecTy *ItemType = nullptr); bool ParseTemplateArgValueList(SmallVectorImpl &Result, Record *CurRec, Record *ArgsRec); + bool ParseFunctionArgValueList(FunctionRecTy *FuncTy, + SmallVectorImpl &Result, + Record *CurRec); void ParseDagArgList( SmallVectorImpl> &Result, Record *CurRec); diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp --- a/llvm/lib/TableGen/TGParser.cpp +++ b/llvm/lib/TableGen/TGParser.cpp @@ -214,13 +214,16 @@ std::string InitType; if (BitsInit *BI = dyn_cast(V)) InitType = (Twine("' of type bit initializer with length ") + - Twine(BI->getNumBits())).str(); + Twine(BI->getNumBits())) + .str(); else if (TypedInit *TI = dyn_cast(V)) InitType = (Twine("' of type '") + TI->getType()->getAsString()).str(); - return Error(Loc, "Field '" + ValName->getAsUnquotedString() + - "' of type '" + RV->getType()->getAsString() + - "' is incompatible with value '" + - V->getAsString() + InitType + "'"); + return RV->isReturnValue() + ? true + : Error(Loc, "Field '" + ValName->getAsUnquotedString() + + "' of type '" + RV->getType()->getAsString() + + "' is incompatible with value '" + + V->getAsString() + InitType + "'"); } return false; } @@ -540,8 +543,8 @@ static bool isObjectStart(tgtok::TokKind K) { return K == tgtok::Assert || K == tgtok::Class || K == tgtok::Def || K == tgtok::Defm || K == tgtok::Defset || K == tgtok::Defvar || - K == tgtok::Foreach || K == tgtok::If || K == tgtok::Let || - K == tgtok::MultiClass; + K == tgtok::Function || K == tgtok::Foreach || K == tgtok::If || + K == tgtok::Let || K == tgtok::MultiClass; } bool TGParser::consume(tgtok::TokKind K) { @@ -827,14 +830,15 @@ /// ParseType - Parse and return a tblgen type. This returns null on error. 
/// -/// Type ::= STRING // string type -/// Type ::= CODE // code type -/// Type ::= BIT // bit type -/// Type ::= BITS '<' INTVAL '>' // bits type -/// Type ::= INT // int type -/// Type ::= LIST '<' Type '>' // list type -/// Type ::= DAG // dag type -/// Type ::= ClassID // Record Type +/// Type ::= STRING // string type +/// Type ::= CODE // code type +/// Type ::= BIT // bit type +/// Type ::= BITS '<' INTVAL '>' // bits type +/// Type ::= INT // int type +/// Type ::= LIST '<' Type '>' // list type +/// Type ::= FUNCTION '<' ReturnType [ ',' ArgType ]* '>' // function type +/// Type ::= DAG // dag type +/// Type ::= ClassID // Record Type /// RecTy *TGParser::ParseType() { switch (Lex.getCode()) { @@ -889,6 +893,30 @@ } return ListRecTy::get(SubType); } + case tgtok::Function: { + if (Lex.Lex() != tgtok::less) { // Eat 'function' + TokError("expected '<' after function"); + return nullptr; + } + Lex.Lex(); // Eat '<' + RecTy *ReturnType = ParseType(); + if (!ReturnType) + return nullptr; + + SmallVector ArgTypes; + while (consume(tgtok::comma)) { + RecTy *ArgType = ParseType(); + if (!ArgType) + return nullptr; + ArgTypes.push_back(ArgType); + } + + if (!consume(tgtok::greater)) { + TokError("expected '>' at end of function type"); + return nullptr; + } + return FunctionRecTy::get(Records, ReturnType, ArgTypes); + } } } @@ -909,27 +937,32 @@ CurLocalScope->getVar(Name->getValue(), /* FindInParent*/ false)) return I; - // The ID is a class/multiclass template argument? - if ((CurRec && CurRec->isClass()) || CurMultiClass) { - Init *TemplateArgName; - if (CurMultiClass) { - TemplateArgName = - QualifyName(CurMultiClass->Rec, CurMultiClass, Name, "::"); - } else - TemplateArgName = QualifyName(*CurRec, CurMultiClass, Name, ":"); - - Record *TemplateRec = CurMultiClass ? 
&CurMultiClass->Rec : CurRec; - if (TemplateRec->isTemplateArg(TemplateArgName)) { - RecordVal *RV = TemplateRec->getValue(TemplateArgName); - assert(RV && "Template arg doesn't exist??"); + std::function FindValueInArgs = + [&](Record *Rec, StringInit *Name, StringRef Scoper) -> Init * { + if (!Rec) + return nullptr; + Init *ArgName = QualifyName(*Rec, CurMultiClass, Name, Scoper); + if (Rec->isTemplateArg(ArgName) || Rec->isFunctionArg(ArgName)) { + RecordVal *RV = Rec->getValue(ArgName); + assert(RV && "Template or function arg doesn't exist??"); RV->setUsed(true); if (TrackReferenceLocs) RV->addReferenceLoc(NameLoc); - return VarInit::get(TemplateArgName, RV->getType()); - } else if (Name->getValue() == "NAME") { - return VarInit::get(TemplateArgName, StringRecTy::get(Records)); + return VarInit::get(ArgName, RV->getType()); } - } + return Name->getValue() == "NAME" + ? VarInit::get(ArgName, StringRecTy::get(Records)) + : FindValueInArgs(Rec->getParent(), Name, Scoper); + }; + + // The ID is a class/multiclass template argument? + if (CurMultiClass) + if (auto *V = FindValueInArgs(&CurMultiClass->Rec, Name, "::")) + return V; + + if (CurRec && (CurRec->isClass() || CurRec->isFunction())) + if (auto *V = FindValueInArgs(CurRec, Name, ":")) + return V; // Then, we try to find the ID defined in parent local scope. if (CurLocalScope) @@ -2205,6 +2238,7 @@ /// SimpleValue ::= '?' 
/// SimpleValue ::= '{' ValueList '}' /// SimpleValue ::= ID '<' ValueListNE '>' +/// SimpleValue ::= FUNCTION Function /// SimpleValue ::= '[' ValueList ']' /// SimpleValue ::= '(' IDValue DagArgList ')' /// SimpleValue ::= CONCATTOK '(' Value ',' Value ')' @@ -2268,6 +2302,48 @@ R = UnsetInit::get(Records); Lex.Lex(); break; + case tgtok::Function: { + Lex.Lex(); + std::unique_ptr NewRec; + NewRec = std::make_unique( + cast(Records.getNewAnonymousName())->getNameInit(), + Lex.getLoc(), Records, + /*Anonymous=*/true, Record::RecordType::Function); + Record *Parent = CurRec; + NewRec->setParent(CurRec); + CurRec = NewRec.get(); + if (ParseFunction(CurRec)) { + TokError("can't parse function"); + return nullptr; + } + + // Add this anonymous function. + Records.addFunction(std::move(NewRec)); + + // Collect all captures from the enclosing scopes. + // 1. Add class/function arguments of every enclosing record. + SmallVector Captures; + while (Parent) { + for (auto &V : Parent->getValues()) + if (V.isTemplateArg() || V.isFunctionArg()) + Captures.push_back(std::make_pair(V.getNameInit(), V.getValue())); + Parent = Parent->getParent(); + } + + // 2. Add multiclass arguments. + if (CurMultiClass) + for (auto &V : CurMultiClass->Rec.getValues()) + if (V.isTemplateArg()) + Captures.push_back(std::make_pair(V.getNameInit(), V.getValue())); + + // 3. Add loop iterators. + for (const auto &L : Loops) + if (L->IterVar) + if (VarInit *IterVar = dyn_cast(L->IterVar)) + Captures.push_back(std::make_pair(IterVar->getNameInit(), IterVar)); + + return FuncRefInit::get(CurRec, Captures); + } case tgtok::Id: { SMRange NameLoc = Lex.getLocRange(); StringInit *Name = StringInit::get(Records, Lex.getCurStrVal()); @@ -2458,7 +2534,8 @@ return nullptr; } - Init *Operator = ParseValue(CurRec); + Init *Operator = + ParseValue(CurRec, nullptr, ParseValueMode, /* ParsingDag */ true); if (!Operator) return nullptr; // If the operator name is present, parse it.
@@ -2541,17 +2618,49 @@ /// /// Value ::= SimpleValue ValueSuffix* /// ValueSuffix ::= '{' BitList '}' +/// ValueSuffix ::= ('\'')? FunctionArgList /// ValueSuffix ::= '[' BitList ']' /// ValueSuffix ::= '.' ID /// -Init *TGParser::ParseValue(Record *CurRec, RecTy *ItemType, IDParseMode Mode) { +Init *TGParser::ParseValue(Record *CurRec, RecTy *ItemType, IDParseMode Mode, + bool ParsingDag) { Init *Result = ParseSimpleValue(CurRec, ItemType, Mode); if (!Result) return nullptr; // Parse the suffixes now if present. while (true) { switch (Lex.getCode()) { - default: return Result; + default: + return Result; + case tgtok::l_paren: + case tgtok::single_quote: { + // The call inside a DAG should add single quote before arguments. If not + // then this shouldn't be a function call and just returns the result. + // This is a compromise to resolve the grammar ambiguity. + if (ParsingDag && Lex.getCode() != tgtok::single_quote) + return Result; + + if (Lex.getCode() == tgtok::single_quote) { + if (!ParsingDag) + PrintWarning(Lex.getLoc(), + "Call with single quote is only used inside DAG"); + Lex.Lex(); // Just eat the single quote. + } + + // The value to be called should be with function type. + auto *Typed = dyn_cast(Result); + RecTy *Ty = Typed ? Typed->getType() : nullptr; + if (!Ty || !isa(Ty)) { + TokError(Result->getAsString() + " is not callable."); + return nullptr; + } + + // Value ::= Value '(' ValueList ')'; more suffixes may follow the call. + Result = ParseCall(CurRec, Result, cast(Ty)); + if (!Result) + return nullptr; + break; + } case tgtok::l_brace: { if (Mode == ParseNameMode) // This is the beginning of the object body. @@ -2738,7 +2847,8 @@ Lex.Lex(); } else { // DagArg ::= Value (':' VARNAME)? - Init *Val = ParseValue(CurRec); + Init *Val = + ParseValue(CurRec, nullptr, ParseValueMode, /* ParsingDag */ true); if (!Val) { Result.clear(); return; @@ -2829,6 +2939,58 @@ } } +// ParseFunctionArgValueList - Parse a function argument list with the syntax +// shown, filling in the Result vector. An empty argument list is allowed.
The +// count and types of arguments will be checked. Return false if okay, true if +// an error was detected. +// +// FunctionArgList ::= '(' [Value {',' Value}*] ')' +bool TGParser::ParseFunctionArgValueList(FunctionRecTy *FuncTy, + SmallVectorImpl &Result, + Record *CurRec) { + assert(Result.empty() && "Result vector is not empty"); + unsigned ArgIndex = 0; + + if (consume(tgtok::r_paren)) // empty argument list + return false; + + while (true) { + if (ArgIndex >= FuncTy->args_size()) + return TokError("Too many function arguments, we need " + + Twine(ArgIndex) + " arguments here"); + + SMLoc Loc = Lex.getLoc(); + RecTy *ArgType = FuncTy->getArg(ArgIndex); + Init *Value = ParseValue(CurRec, ArgType); + if (!Value) + return true; + + TypedInit *ArgValue = dyn_cast(Value); + auto *CastValue = ArgValue ? ArgValue->getCastTo(ArgType) : nullptr; + + // Report error if argument type is not compatible. + if (!CastValue) + return Error( + Loc, "Value specified for the #" + Twine(ArgIndex) + + " function argument is " + + (ArgValue ? "of type " + ArgValue->getType()->getAsString() + : "uninitialized") + + "; expected type " + ArgType->getAsString() + ": " + + Value->getAsString()); + + assert((!isa(CastValue) || + cast(CastValue)->getType()->typeIsA(ArgType)) && + "result of function arg value cast has wrong type"); + Result.push_back(CastValue); + + if (consume(tgtok::r_paren)) // end of argument list? + return false; + if (!consume(tgtok::comma)) + return TokError("Expected comma before next argument"); + ++ArgIndex; + } +} + /// ParseDeclaration - Read a declaration, returning the name of field ID, or an /// empty string on error. This can happen in a number of different contexts, /// including within a def or in the template args for a class (in which case @@ -2840,7 +2999,7 @@ /// Declaration ::= FIELD? Type ID ('=' Value)? 
/// Init *TGParser::ParseDeclaration(Record *CurRec, - bool ParsingTemplateArgs) { + bool ParsingTemplateOrFunctionArgs) { // Read the field prefix if present. bool HasField = consume(tgtok::Field); @@ -2863,17 +3022,17 @@ Lex.Lex(); bool BadField; - if (!ParsingTemplateArgs) { // def, possibly in a multiclass + if (!ParsingTemplateOrFunctionArgs) { // def, possibly in a multiclass BadField = AddValue(CurRec, IdLoc, RecordVal(DeclName, IdLoc, Type, HasField ? RecordVal::FK_NonconcreteOK : RecordVal::FK_Normal)); - - } else if (CurRec) { // class template argument + } else if (CurRec) { // class template argument or function argument + RecordVal::FieldKind Kind = CurRec->isFunction() + ? RecordVal::FK_FunctionArg + : RecordVal::FK_TemplateArg; DeclName = QualifyName(*CurRec, CurMultiClass, DeclName, ":"); - BadField = AddValue(CurRec, IdLoc, RecordVal(DeclName, IdLoc, Type, - RecordVal::FK_TemplateArg)); - + BadField = AddValue(CurRec, IdLoc, RecordVal(DeclName, IdLoc, Type, Kind)); } else { // multiclass template argument assert(CurMultiClass && "invalid context for template argument"); DeclName = QualifyName(CurMultiClass->Rec, CurMultiClass, DeclName, "::"); @@ -3297,6 +3456,229 @@ return false; } +/// ParseFunctionDefinition - Parse a function definition. +/// +/// FunctionDefinition ::= FUNCTION Id Function +/// +bool TGParser::ParseFunctionDefinition() { + assert(Lex.getCode() == tgtok::Function && "Unexpected token!"); + Lex.Lex(); + + if (Lex.getCode() != tgtok::Id) + return TokError("expected function name after 'function' keyword"); + + Record *CurRec = Records.getFunction(Lex.getCurStrVal()); + if (CurRec) { + // If the function was previously defined, this is an error. + return TokError("Function '" + CurRec->getNameInitAsString() + + "' already defined"); + } + // Create the function; Records takes ownership once parsing succeeds.
+ auto NewRec = + std::make_unique(Lex.getCurStrVal(), Lex.getLoc(), Records, + /*Type=*/Record::RecordType::Function); + CurRec = NewRec.get(); + + Lex.Lex(); // eat the name. + + if (ParseFunction(CurRec)) + return true; + + Records.addFunction(std::move(NewRec)); + return false; +} + +/// ParseFunction - Parse a function. +/// +/// Function ::= FunctionArgList ':' ReturnType FunctionBody +/// +bool TGParser::ParseFunction(Record *CurRec) { + SmallVector ArgTypes; + if (ParseFunctionArgList(CurRec, ArgTypes)) + return true; + + if (!consume(tgtok::colon)) + return TokError("expected ':' before return type"); + + RecTy *ReturnType = ParseType(); + if (!ReturnType) + return true; + CurRec->setFunctionType(FunctionRecTy::get(Records, ReturnType, ArgTypes)); + + // Each function name is a function reference to itself. + // We add the function reference before parsing function body so that we can + // refer to the function name in its own body to make recursive calls. + if (!CurRec->isAnonymous()) + Records.addExtraGlobal(CurRec->getName(), FuncRefInit::get(CurRec)); + if (ParseFunctionBody(CurRec)) { + // Remove the function reference if we failed to parse the function body. + if (!CurRec->isAnonymous()) + Records.removeExtraGlobal(CurRec->getName()); + return true; + } + + if (!NoWarnOnUnusedFunctionArgs) + CurRec->checkUnusedFunctionArgs(); + return false; +} + +/// ParseFunctionArgList - Parse a function argument list, which can be an empty +/// sequence of declarations in ()'s. Argument types will be stored in ArgTypes +/// in declaration order. +/// +/// +/// FunctionArgList ::= '(' (Declaration (',' Declaration)*)? ')' +/// +bool TGParser::ParseFunctionArgList(Record *CurRec, + SmallVectorImpl &ArgTypes) { + if (!consume(tgtok::l_paren)) + return TokError("expected '(' before arguments"); + + // If no arguments. + if (consume(tgtok::r_paren)) + return false; + + // The function has arguments. + // Read the first declaration.
+ Init *FuncArg = ParseDeclaration(CurRec, true /*functionargs*/); + if (!FuncArg) + return true; + + CurRec->addFunctionArg(FuncArg); + ArgTypes.push_back(CurRec->getValue(FuncArg)->getType()); + + while (consume(tgtok::comma)) { + // Read the following declarations. + SMLoc Loc = Lex.getLoc(); + FuncArg = ParseDeclaration(CurRec, true /*functionargs*/); + if (!FuncArg) + return true; + + if (CurRec->isFunctionArg(FuncArg)) + return Error(Loc, "function argument with the same name has already been " + "defined"); + + CurRec->addFunctionArg(FuncArg); + ArgTypes.push_back(CurRec->getValue(FuncArg)->getType()); + } + + if (!consume(tgtok::r_paren)) + return TokError("expected ')' at end of function argument list"); + return false; +} + +/// ParseFunctionBody - Parse a function body, which must contain exactly one +/// return statement. +/// +/// FunctionBody ::= '{' FunctionBodyItem+ '}' +/// +bool TGParser::ParseFunctionBody(Record *CurRec) { + if (!consume(tgtok::l_brace)) + return TokError("Expected '{' to start function body"); + + if (Lex.getCode() == tgtok::r_brace) + return TokError("Expected statements in function body"); + + // A function body introduces a new scope for local variables. + TGLocalVarScope *FunctionBodyScope = PushLocalScope(); + + while (Lex.getCode() != tgtok::r_brace) + if (ParseFunctionBodyItem(CurRec)) + return true; + + PopLocalScope(FunctionBodyScope); + + if (!CurRec->hasReturnValue()) + return TokError("Expected at least one return statement in function body"); + + // Eat the '}'. + Lex.Lex(); + + return false; +} + +/// ParseFunctionBodyItem - Parse a single item within the body of a function.
+/// +/// FunctionBodyItem ::= Return +/// FunctionBodyItem ::= Defvar +/// FunctionBodyItem ::= Assert +/// +bool TGParser::ParseFunctionBodyItem(Record *CurRec) { + if (CurRec->hasReturnValue()) { + if (Lex.getCode() == tgtok::Assert || Lex.getCode() == tgtok::Defvar) + PrintWarning(Lex.getLoc(), + "statements after return statement will be ignored"); + else if (Lex.getCode() == tgtok::Return) + // Report error if there are multiple return statements. + return TokError( + "There should be only one return statement in function body"); + } + + if (Lex.getCode() == tgtok::Assert) + return ParseAssert(nullptr, CurRec); + + if (Lex.getCode() == tgtok::Defvar) + return ParseDefvar(CurRec); + + if (Lex.getCode() == tgtok::Return) + return ParseReturn(CurRec); + + return TokError("unexpected expression in function body"); +} + +/// ParseReturn - Parse a return statement. +/// +/// Return ::= 'return' Value ';' +/// +bool TGParser::ParseReturn(Record *CurRec) { + Lex.Lex(); // Eat the 'return'. + + SMLoc Loc = Lex.getLoc(); + RecTy *ReturnType = CurRec->getReturnType(); + Init *Val = ParseValue(CurRec, ReturnType); + if (!Val) + return true; + + if (auto *TypedVal = dyn_cast(Val)) { + // Store the result in the record's return-value field (FK_ReturnValue). + auto *ReturnValueName = CurRec->getReturnValueName(); + AddValue( + CurRec, Loc, + RecordVal(ReturnValueName, Loc, ReturnType, RecordVal::FK_ReturnValue)); + if (SetValue(CurRec, Loc, ReturnValueName, std::nullopt, Val, + /*AllowSelfAssignment=*/false, /*OverrideDefLoc=*/false)) + return Error(Loc, "return value '" + TypedVal->getAsString() + + "' of type '" + TypedVal->getType()->getAsString() + + "' is incompatible with return type '" + + ReturnType->getAsString() + "'"); + } else + return Error(Loc, "return value shouldn't be uninitialized"); + + if (!consume(tgtok::semi)) + return TokError("expected ';' after return statement"); + + return false; +} + +/// ParseCall - Parse a function call. Value should be callable.
+/// +/// Call ::= Value "(" FunctionArgList ")" +/// +Init *TGParser::ParseCall(Record *CurRec, Init *FuncRef, + FunctionRecTy *FuncTy) { + SMLoc CallLoc = Lex.getLoc(); + if (!consume(tgtok::l_paren)) { + TokError("expected '(' before function arguments"); + return nullptr; + } + + SmallVector Args; + if (ParseFunctionArgValueList(FuncTy, Args, CurRec)) + return nullptr; // Error parsing value list. + + return FuncCallInit::get(FuncRef, FuncTy, Args, CallLoc)->Fold(); +} + /// ParseForeach - Parse a for statement. Return the record corresponding /// to it. This returns true on error. /// @@ -3514,7 +3896,7 @@ // If this is the first reference to this class, create and add it. auto NewRec = std::make_unique(Lex.getCurStrVal(), Lex.getLoc(), Records, - /*Class=*/true); + /*Type=*/Record::RecordType::Class); CurRec = NewRec.get(); Records.addClass(std::move(NewRec)); } @@ -3891,6 +4273,8 @@ if (!Loops.empty()) return TokError("multiclass is not allowed inside foreach loop"); return ParseMultiClass(); + case tgtok::Function: + return ParseFunctionDefinition(); } } diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td --- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td +++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td @@ -293,8 +293,8 @@ def OpFunctionEnd: Op<56, (outs), (ins), "OpFunctionEnd"> { let isTerminator=1; } -def OpFunctionCall: Op<57, (outs ID:$res), (ins TYPE:$resType, ID:$function, variable_ops), - "$res = OpFunctionCall $resType $function">; +def OpFunctionCall: Op<57, (outs ID:$res), (ins TYPE:$resType, ID:$func, variable_ops), + "$res = OpFunctionCall $resType $func">; // 3.42.10 Image Instructions diff --git a/llvm/lib/Target/SystemZ/SystemZPatterns.td b/llvm/lib/Target/SystemZ/SystemZPatterns.td --- a/llvm/lib/Target/SystemZ/SystemZPatterns.td +++ b/llvm/lib/Target/SystemZ/SystemZPatterns.td @@ -170,6 +170,6 @@ // Use INSN to perform minimum/maximum operation OPERATOR on type TR. 
// FUNCTION is the type of minimum/maximum function to perform. class FPMinMax function> + bits<4> func> : Pat<(tr.vt (operator (tr.vt tr.op:$vec1), (tr.vt tr.op:$vec2))), - (insn tr.op:$vec1, tr.op:$vec2, function)>; + (insn tr.op:$vec1, tr.op:$vec2, func)>; diff --git a/llvm/test/TableGen/function.td b/llvm/test/TableGen/function.td new file mode 100644 --- /dev/null +++ b/llvm/test/TableGen/function.td @@ -0,0 +1,534 @@ +// RUN: llvm-tblgen %s | FileCheck %s + +// ---- Test basic function items and default values ---- // + +class Value { + int value = v; +} +def value0: Value<0>; +def value1: Value<1>; +def value2: Value<2>; + +function sub_func(int a, int b): int { + return !sub(a, b); +} + +function add_func(int a, int b): int { + return !add(a, b); +} + +function default_arg(int a, int b = 3): int { + assert !gt(b, 2), "b > 2"; + return !add(a, b); +} + +function no_arg(): int { + defvar ret = !mul(2, 3); + return ret; +} + +function return_imm(): int { + return 2333; +} + +function return_record(): Value { + return value0; +} + +function return_record_field(): int { + return value1.value; +} + +function local_var_priority_global(): int { + defvar value0 = 233; + return value0; +} + +function local_var_priority_arg(int value0): int { + defvar value0 = 233; + return value0; +} + +function call_inside_dag(int i): Value { + return !cast("value"#i); +} + +// CHECK: def test1 { +// CHECK-NEXT: int add_value = 4; +// CHECK-NEXT: int default_arg_value = 8; +// CHECK-NEXT: int no_arg_value = 6; +// CHECK-NEXT: Value return_record_value = value0; +// CHECK-NEXT: int return_record_field_value = 1; +// CHECK-NEXT: int local_var_priority_global_value = 233; +// CHECK-NEXT: int local_var_priority_arg_value = 233; +// CHECK-NEXT: dag call_inside_dag_value = (value0 value1, value2); +// CHECK-NEXT: } +class BasicTest { + int add_value = !add(add_func(2, 3), sub_func(2, 3)); + int default_arg_value = default_arg(add_func(2, 3)); + int no_arg_value = no_arg(); + Value 
return_record_value = return_record(); + int return_record_field_value = return_record_field(); + int local_var_priority_global_value = local_var_priority_global(); + int local_var_priority_arg_value = local_var_priority_arg(2); + dag call_inside_dag_value = (call_inside_dag'(0) call_inside_dag'(1), call_inside_dag'(2)); +} + +def test1: BasicTest; + +// ---- Test all types ---- // +class A; +def a: A; +function class_type(A a): A { + return a; +} + +function bit_type(bit a): bit { + return a; +} + +function bits_type(bits<8> a): bits<8> { + return a; +} + +function int_type(int a): int { + return a; +} + +function string_type(string a): string { + return a; +} + +function code_type(code a): code { + return a; +} + +function list_type(list a) : list { + return a; +} + +function func_type(function func) : function { + return func; +} + +function dag_type(dag a): dag { + return a; +} + +// CHECK: def test2 { +// CHECK-NEXT: A class_value = a; +// CHECK-NEXT: bit bit_value = 1; +// CHECK-NEXT: bits<8> bits_value = { 1, 1, 1, 0, 1, 0, 0, 1 }; +// CHECK-NEXT: int int_value = 2333; +// CHECK-NEXT: string string_value = "string"; +// CHECK-NEXT: code code_value = [{code}]; +// CHECK-NEXT: list list_value = [1, 2, 3]; +// CHECK-NEXT: function func_value = add_func; +// CHECK-NEXT: dag dag_value = (a 1, 2, 3); +// CHECK-NEXT: } +class TypeTest { + A class_value = class_type(a); + bit bit_value = bit_type(1); + bits<8> bits_value = bits_type(233); + int int_value = int_type(2333); + string string_value = string_type("string"); + code code_value = code_type([{code}]); + list list_value = list_type([1, 2, 3]); + function func_value = func_type(add_func); + dag dag_value = dag_type((a 1, 2, 3)); +} + +def test2: TypeTest; + +// ---- Test all bang operators ---- // + +class B; +def b0: B; +def b1: B; + +function test_add(int a, int b): int { + return !add(a, b); +} + +function test_and(int a, int b): int { + return !and(a, b); +} + +function test_cast(string n): B { + return 
!cast(n); +} + +function test_con(dag a, dag b): dag { + return !con(a, b); +} + +function test_dag(): dag { + return !dag(b0, [1, 2, 3], ["a", "b", "c"]); +} + +function test_div(int a, int b): int { + return !div(a, b); +} + +function test_empty(list l): int { + return !empty(l); +} + +function test_eq(int a, int b): int { + return !eq(a, b); +} + +function test_exists(string n): int { + return !exists(n); +} + +function test_filter(list l): list { + return !filter(a, l, !gt(a, 0)); +} + +function test_find(string a, string b): int { + return !find(a, b); +} + +function test_foldl(list l): int { + return !foldl(0, l, acc, a, !add(acc, a)); +} + +function test_foreach(list l): list { + return !foreach(a, l, !add(a, 1)); +} + +function test_ge(int a, int b): int { + return !ge(a, b); +} + +function test_getdagop(dag a): B { + return !getdagop(a); +} + +function test_gt(int a, int b): int { + return !gt(a, b); +} + +function test_head(list l): int { + return !head(l); +} + +function test_if(int a): int { + return !if(!gt(a, 1), 0, 1); +} + +function test_interleave(list l): string { + return !interleave(l, ""); +} + +function test_isa(B b): int { + return !isa(b); +} + +function test_le(int a, int b): int { + return !le(a, b); +} + +function test_listconcat(list a, list b): list { + return !listconcat(a, b); +} + +function test_listremove(list a, list b): list { + return !listremove(a, b); +} + +function test_listsplat(int a, int b): list { + return !listsplat(a, b); +} + +function test_logtwo(int a): int { + return !logtwo(a); +} + +function test_lt(int a, int b): int { + return !lt(a, b); +} + +function test_mul(int a, int b): int { + return !mul(a, b); +} + +function test_ne(int a, int b): int { + return !ne(a, b); +} + +function test_not(int a): int { + return !not(a); +} + +function test_or(int a, int b): int { + return !or(a, b); +} + +function test_setdagop(dag a, B b): dag { + return !setdagop(a, b); +} + +function test_shl(int a, int b): int { + return 
!shl(a, b);
+}
+
+function test_size(list<int> l): int {
+  return !size(l);
+}
+
+function test_sra(int a, int b): int {
+  return !sra(a, b);
+}
+
+function test_srl(int a, int b): int {
+  return !srl(a, b);
+}
+
+function test_strconcat(string a, string b): string {
+  return !strconcat(a, b);
+}
+
+function test_sub(int a, int b): int {
+  return !sub(a, b);
+}
+
+function test_subst(string target, string repl, string value): string { // replaces every `target` in `value` with `repl`
+  return !subst(target, repl, value);
+}
+
+function test_substr(string s, int a, int b): string {
+  return !substr(s, a, b);
+}
+
+function test_tail(list<int> l): list<int> {
+  return !tail(l);
+}
+
+function test_tolower(string s): string {
+  return !tolower(s);
+}
+
+function test_toupper(string s): string {
+  return !toupper(s);
+}
+
+function test_xor(int a, int b): int {
+  return !xor(a, b);
+}
+
+function test_cond(int a): int {
+  return !cond(
+    !eq(a, 0): 0,
+    !lt(a, 0): -1,
+    !gt(a, 0): 1
+  );
+}
+
+// CHECK: def test3 {
+// CHECK-NEXT:   int add_value = 3;
+// CHECK-NEXT:   int and_value = 0;
+// CHECK-NEXT:   B cast_value = b0;
+// CHECK-NEXT:   dag con_value = (b0 1, 2, 3);
+// CHECK-NEXT:   dag dag_value = (b0 1:$a, 2:$b, 3:$c);
+// CHECK-NEXT:   int div_value = 2;
+// CHECK-NEXT:   int empty_value = 0;
+// CHECK-NEXT:   int eq_value = 0;
+// CHECK-NEXT:   int exists_value = 1;
+// CHECK-NEXT:   list<int> filter_value = [1, 2, 3];
+// CHECK-NEXT:   int find_value = 1;
+// CHECK-NEXT:   int foldl_value = 6;
+// CHECK-NEXT:   list<int> foreach_value = [2, 3, 4];
+// CHECK-NEXT:   int ge_value = 0;
+// CHECK-NEXT:   B getdagop_value = b0;
+// CHECK-NEXT:   int gt_value = 0;
+// CHECK-NEXT:   int head_value = 1;
+// CHECK-NEXT:   int if_value = 1;
+// CHECK-NEXT:   string interleave_value = "123";
+// CHECK-NEXT:   int isa_value = 1;
+// CHECK-NEXT:   int le_value = 1;
+// CHECK-NEXT:   list<int> listconcat_value = [1, 2, 3, 4];
+// CHECK-NEXT:   list<int> listremove_value = [3, 4];
+// CHECK-NEXT:   list<int> listsplat_value = [5, 5, 5];
+// CHECK-NEXT:   int logtwo_value = 3;
+// CHECK-NEXT:   int lt_value = 1;
+// CHECK-NEXT:   int mul_value = 2;
+// CHECK-NEXT:   int ne_value = 1;
+// CHECK-NEXT:   int not_value = 0;
+// CHECK-NEXT:   int or_value = 3;
+// CHECK-NEXT:   dag setdagop_value = (b1 1, 2, 3);
+// CHECK-NEXT:   int shl_value = 4;
+// CHECK-NEXT:   int size_value = 3;
+// CHECK-NEXT:   int sra_value = 1;
+// CHECK-NEXT:   int srl_value = 1;
+// CHECK-NEXT:   string strconcat_value = "ab";
+// CHECK-NEXT:   int sub_value = -1;
+// CHECK-NEXT:   string subst_value = "aa";
+// CHECK-NEXT:   string substr_value = "bc";
+// CHECK-NEXT:   list<int> tail_value = [2, 3];
+// CHECK-NEXT:   string tolower_value = "abc";
+// CHECK-NEXT:   string toupper_value = "ABC";
+// CHECK-NEXT:   int xor_value = 3;
+// CHECK-NEXT:   int cond_value = 1;
+// CHECK-NEXT: }
+class BangOperatorTest { // one field per bang operator, each routed through its test_* wrapper
+  int add_value = test_add(1, 2);
+  int and_value = test_and(1, 2);
+  B cast_value = test_cast("b0");
+  dag con_value = test_con((b0 1, 2), (b0 3));
+  dag dag_value = test_dag();
+  int div_value = test_div(4, 2);
+  int empty_value = test_empty([1, 2, 3]);
+  int eq_value = test_eq(1, 2);
+  int exists_value = test_exists("b1");
+  list<int> filter_value = test_filter([0, 1, 2, 3]);
+  int find_value = test_find("abc", "b");
+  int foldl_value = test_foldl([1, 2, 3]);
+  list<int> foreach_value = test_foreach([1, 2, 3]);
+  int ge_value = test_ge(1, 2);
+  B getdagop_value = test_getdagop((b0 1, 2, 3));
+  int gt_value = test_gt(1, 2);
+  int head_value = test_head([1, 2, 3]);
+  int if_value = test_if(1);
+  string interleave_value = test_interleave([1, 2, 3]);
+  int isa_value = test_isa(b1);
+  int le_value = test_le(1, 2);
+  list<int> listconcat_value = test_listconcat([1, 2, 3], [4]);
+  list<int> listremove_value = test_listremove([1, 2, 3, 4], [1, 2]);
+  list<int> listsplat_value = test_listsplat(5, 3);
+  int logtwo_value = test_logtwo(9);
+  int lt_value = test_lt(1, 2);
+  int mul_value = test_mul(1, 2);
+  int ne_value = test_ne(1, 2);
+  int not_value = test_not(1);
+  int or_value = test_or(1, 2);
+  dag setdagop_value = test_setdagop((b0 1, 2, 3), b1);
+  int shl_value = test_shl(1, 2);
+  int size_value = test_size([1, 2, 3]);
+  int sra_value = test_sra(2, 1);
+  int srl_value = test_srl(2, 1);
+  string strconcat_value = test_strconcat("a", "b");
+  int sub_value = test_sub(1, 2);
+  string subst_value = test_subst("abc", "", "aaabc"); // "abc" removed from "aaabc" leaves "aa"
+  string substr_value = test_substr("abc", 1, 2);
+  list<int> tail_value = test_tail([1, 2, 3]);
+  string tolower_value = test_tolower("ABC");
+  string toupper_value = test_toupper("abc");
+  int xor_value = test_xor(1, 2);
+  int cond_value = test_cond(2);
+}
+
+def test3: BangOperatorTest;
+
+// ---- Test recursive function ---- //
+
+function fib(int n):int{
+  return !if(!lt(n, 2),
+    1,
+    !add(
+      fib(!sub(n, 1)),
+      fib(!sub(n, 2))
+    )
+  );
+}
+
+// CHECK: def test4_fib_0 {
+// CHECK-NEXT:   int value = 1;
+// CHECK: def test4_fib_1 {
+// CHECK-NEXT:   int value = 1;
+// CHECK: def test4_fib_2 {
+// CHECK-NEXT:   int value = 2;
+// CHECK: def test4_fib_3 {
+// CHECK-NEXT:   int value = 3;
+// CHECK: def test4_fib_4 {
+// CHECK-NEXT:   int value = 5;
+// CHECK: def test4_fib_5 {
+// CHECK-NEXT:   int value = 8;
+// CHECK: def test4_fib_6 {
+// CHECK-NEXT:   int value = 13;
+// CHECK: def test4_fib_7 {
+// CHECK-NEXT:   int value = 21;
+// CHECK: def test4_fib_8 {
+// CHECK-NEXT:   int value = 34;
+// CHECK: def test4_fib_9 {
+// CHECK-NEXT:   int value = 55;
+foreach n = 0...9 in {
+  def "test4_fib_" # n {
+    int value = fib(n);
+  }
+}
+
+// ---- Test function as a value in class arguments ---- //
+class FuncInClassArg<function<int, int, int> func> { // func: (int, int) -> int
+  int value = func(2333, 2333);
+}
+
+// CHECK: def test5 {
+// CHECK-NEXT:   int value = 4666;
+// CHECK-NEXT: }
+def test5: FuncInClassArg<add_func>;
+
+// ---- Test function as a value in multiclass arguments ---- //
+multiclass FuncInMulticlassArg<function<int, int, int> func> { // func: (int, int) -> int
+  def "" {
+    int value = func(2333, 2333);
+  }
+}
+
+// CHECK: def test6 {
+// CHECK-NEXT:   int value = 4666;
+// CHECK-NEXT: }
+defm test6: FuncInMulticlassArg<add_func>;
+
+// ---- Test function as a value in foreach iterator ---- //
+
+// CHECK: def test7 {
+// CHECK-NEXT:   int value = 4666;
+// CHECK-NEXT: }
+foreach f = [add_func] in {
+  def test7 {
+    int value = f(2333, 2333);
+  }
+}
+
+// ---- Test function as a field in class ---- //
+
+class FuncAsClassField {
+  function<int, int, int> func;
+}
+let func = add_func in {
+  def func_as_class_field: FuncAsClassField;
+}
+
+// CHECK: def test8 {
+// CHECK-NEXT:   int value = 4666;
+// CHECK-NEXT: }
+def test8 {
+  int value = func_as_class_field.func(2333, 2333);
+}
+
+// ---- Test that function accepts values with sub-type ---- //
+
+class Base<int a>{
+  int value = a;
+}
+class Sub<int a>: Base<a>;
+
+def base: Base<233>;
+def sub : Sub<2333>;
+
+function func_accept_Base(Base base): int { // a Sub record must be accepted where Base is declared
+  return add_func(base.value, base.value);
+}
+
+class TestTypeCompatible {
+  int base_result = func_accept_Base(base);
+  int sub_result = func_accept_Base(sub);
+}
+
+// CHECK: def test9 {
+// CHECK-NEXT:   int base_result = 466;
+// CHECK-NEXT:   int sub_result = 4666;
+// CHECK-NEXT: }
+def test9: TestTypeCompatible;
diff --git a/llvm/test/TableGen/lambda.td b/llvm/test/TableGen/lambda.td
new file mode 100644
--- /dev/null
+++ b/llvm/test/TableGen/lambda.td
@@ -0,0 +1,179 @@
+// RUN: llvm-tblgen %s | FileCheck %s
+
+// ---- Test lambda as a value ---- //
+
+defvar global_lambda = function(int a, int b): int {
+  return !add(a, b);
+};
+
+class TestLambdaAsValue<int a, int b> {
+  defvar local_lambda = function(int a, int b):int { return !add(a, b); };
+  int global_lambda_value = global_lambda(a, b);
+  int local_lambda_value = local_lambda(a, b);
+  int anonymous_lambda_value = function(int a, int b): int {
+    return !add(a, b);
+  }(a, b);
+}
+
+// CHECK: def test1 {
+// CHECK-NEXT:   int global_lambda_value = 4666;
+// CHECK-NEXT:   int local_lambda_value = 4666;
+// CHECK-NEXT:   int anonymous_lambda_value = 4666;
+// CHECK-NEXT: }
+def test1: TestLambdaAsValue<2333, 2333>;
+
+// ---- Test lambda captures variables correctly ---- //
+
+class CaptureClassArg<int a>{
+  function<int, int> f = function(int b):int { // captures template argument `a`
+    return !add(a, b);
+  };
+}
+def captureClassArg: CaptureClassArg<2333>;
+
+// CHECK: def test2 {
+// CHECK-NEXT:   int value = 4666;
+// CHECK-NEXT: }
+def test2 {
+  int value = captureClassArg.f(2333);
+}
+
+multiclass CaptureMulticlassArg<int a>{
+  def "" {
+    function<int, int> f = function(int b):int { // captures multiclass argument `a`
+      return !add(a, b);
+    };
+  }
+}
+defm captureMulticlassArg: CaptureMulticlassArg<2333>;
+
+// CHECK: def test3 {
+// CHECK-NEXT:   int value = 4666;
+// CHECK-NEXT: }
+def test3 {
+  int value = captureMulticlassArg.f(2333);
+}
+
+foreach a = 0...2 in {
+  def "captureForeachVar" # a {
+    function<int, int> f = function(int b):int { // captures the foreach iterator `a`
+      return !add(a, b);
+    };
+  }
+}
+
+// CHECK: def test4 {
+// CHECK-NEXT:   int value0 = 2333;
+// CHECK-NEXT:   int value1 = 2334;
+// CHECK-NEXT:   int value2 = 2335;
+// CHECK-NEXT: }
+def test4 {
+  int value0 = captureForeachVar0.f(2333);
+  int value1 = captureForeachVar1.f(2333);
+  int value2 = captureForeachVar2.f(2333);
+}
+
+// ---- Test function return lambda ---- //
+
+function func_return_lambda(int func_arg): function<int, int> { // returns an (int) -> int closure over func_arg
+  return function(int lambda_arg): int {
+    return !add(func_arg, lambda_arg);
+  };
+}
+
+// CHECK: def test5 {
+// CHECK-NEXT:   int value0 = 466;
+// CHECK-NEXT:   int value1 = 4666;
+// CHECK-NEXT: }
+def test5 {
+  defvar lambda0 = func_return_lambda(233);
+  defvar lambda1 = func_return_lambda(2333);
+  int value0 = lambda0(233);
+  int value1 = lambda1(2333);
+}
+
+function func_return_lambda_return_lambda(int func_arg): function<function<int, int>, int> {
+  return function(int outer_lambda_arg): function<int, int> {
+    return function(int inner_lambda_arg): int {
+      return !add(func_arg, !add(outer_lambda_arg, inner_lambda_arg));
+    };
+  };
+}
+
+// CHECK: def test6 {
+// CHECK-NEXT:   int value0 = 699;
+// CHECK-NEXT:   int value1 = 6999;
+// CHECK-NEXT: }
+def test6 {
+  defvar lambda0 = func_return_lambda_return_lambda(233);
+  defvar lambda1 = func_return_lambda_return_lambda(2333);
+  defvar lambda0_result = lambda0(233);
+  defvar lambda1_result = lambda1(2333);
+  int value0 = lambda0_result(233);
+  int value1 = lambda1_result(2333);
+}
+
+// ---- Test lambda as a class argument ---- //
+class TestLambdaAsClassArg<function<int, int, int> func, int a, int b> {
+  int value = func(a, b);
+}
+
+// CHECK: def test7 {
+// CHECK-NEXT:   int value = 4666;
+// CHECK-NEXT: }
+def test7: TestLambdaAsClassArg<global_lambda, 2333, 2333>;
+
+// ---- Test lambdas in some bang operators ---- //
+
+defvar func_list = [global_lambda, function(int a, int b): int {
+  return !sub(a, b);
+}];
+
+class TestBangOperators<list<function<int, int, int>> funcs, int a, int b> {
+  list<int> foreach_value = !foreach(f, funcs, f(a, b));
+  int foldl_value = !foldl(0, !foreach(f, funcs, f(a, b)), total, v, !add(total, v));
+}
+
+// CHECK: def test8 {
+// CHECK-NEXT:   list<int> foreach_value = [4666, 0];
+// CHECK-NEXT:   int foldl_value = 4666;
+// CHECK-NEXT: }
+def test8: TestBangOperators<func_list, 2333, 2333>;
+
+// ---- Object-Oriented Programming ---- //
+class Shape {
+  function<int> getArea; // () -> int, bound by each subclass via `let`
+}
+
+class Square<int l>: Shape {
+  int length = l;
+  let getArea = function():int{
+    return !mul(l, l);
+  };
+}
+
+class Triangle<int b, int h>: Shape {
+  int base = b;
+  int height = h;
+  let getArea = function():int{
+    return !div(!mul(b, h), 2);
+  };
+}
+
+def square:Square<4>;
+def triangle:Triangle<4, 4>;
+
+class TestGetArea<Shape shape> {
+  int value = shape.getArea();
+}
+
+// CHECK: def test9_area_of_square {
+// CHECK-NEXT:   int value = 16;
+// CHECK-NEXT: }
+// CHECK: def test9_area_of_triangle {
+// CHECK-NEXT:   int value = 8;
+// CHECK-NEXT: }
+def test9_area_of_square: TestGetArea<square>;
+def test9_area_of_triangle: TestGetArea<triangle>;