diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -317,7 +317,8 @@
     IK_VarBitInit,
     IK_VarDefInit,
     IK_LastTypedInit,
-    IK_UnsetInit
+    IK_UnsetInit,
+    IK_ArgumentInit,
   };
 
 private:
@@ -480,6 +481,41 @@
   std::string getAsString() const override { return "?"; }
 };
 
+/// Represent an argument passed to a class or multiclass template
+/// instantiation. The wrapped value is uniqued by pointer in the
+/// RecordKeeper, so each distinct value maps to a single ArgumentInit.
+class ArgumentInit : public Init, public FoldingSetNode {
+  Init *Value;
+
+protected:
+  explicit ArgumentInit(Init *Value) : Init(IK_ArgumentInit), Value(Value) {}
+
+public:
+  ArgumentInit(const ArgumentInit &) = delete;
+  ArgumentInit &operator=(const ArgumentInit &) = delete;
+
+  static bool classof(const Init *I) { return I->getKind() == IK_ArgumentInit; }
+
+  RecordKeeper &getRecordKeeper() const { return Value->getRecordKeeper(); }
+
+  static ArgumentInit *get(Init *Value);
+
+  Init *getValue() const { return Value; }
+
+  void Profile(FoldingSetNodeID &ID) const;
+
+  Init *resolveReferences(Resolver &R) const override;
+  std::string getAsString() const override { return Value->getAsString(); }
+
+  bool isComplete() const override { return false; }
+  bool isConcrete() const override { return false; }
+  Init *getBit(unsigned Bit) const override { return Value->getBit(Bit); }
+  Init *getCastTo(RecTy *Ty) const override { return Value->getCastTo(Ty); }
+  Init *convertInitializerTo(RecTy *Ty) const override {
+    return Value->convertInitializerTo(Ty);
+  }
+};
+
 /// 'true'/'false' - Represent a concrete initializer for a bit.
 class BitInit final : public TypedInit {
   friend detail::RecordKeeperImpl;
@@ -1278,8 +1314,9 @@
 
 /// classname<targs...> - Represent an uninstantiated anonymous class
 /// instantiation.
-class VarDefInit final : public TypedInit, public FoldingSetNode,
-                         public TrailingObjects<VarDefInit, Init *> {
+class VarDefInit final : public TypedInit,
+                         public FoldingSetNode,
+                         public TrailingObjects<VarDefInit, ArgumentInit *> {
   Record *Class;
   DefInit *Def = nullptr; // after instantiation
   unsigned NumArgs;
@@ -1298,7 +1335,7 @@
   static bool classof(const Init *I) {
     return I->getKind() == IK_VarDefInit;
   }
-  static VarDefInit *get(Record *Class, ArrayRef<Init *> Args);
+  static VarDefInit *get(Record *Class, ArrayRef<ArgumentInit *> Args);
 
   void Profile(FoldingSetNodeID &ID) const;
 
@@ -1307,20 +1344,24 @@
 
   std::string getAsString() const override;
 
-  Init *getArg(unsigned i) const {
+  ArgumentInit *getArg(unsigned i) const {
     assert(i < NumArgs && "Argument index out of range!");
-    return getTrailingObjects<Init *>()[i];
+    return getTrailingObjects<ArgumentInit *>()[i];
   }
 
-  using const_iterator = Init *const *;
+  using const_iterator = ArgumentInit *const *;
 
-  const_iterator args_begin() const { return getTrailingObjects<Init *>(); }
+  const_iterator args_begin() const {
+    return getTrailingObjects<ArgumentInit *>();
+  }
   const_iterator args_end  () const { return args_begin() + NumArgs; }
 
   size_t         args_size () const { return NumArgs; }
   bool           args_empty() const { return NumArgs == 0; }
 
-  ArrayRef<Init *> args() const { return ArrayRef(args_begin(), NumArgs); }
+  ArrayRef<ArgumentInit *> args() const {
+    return ArrayRef(args_begin(), NumArgs);
+  }
 
   Init *getBit(unsigned Bit) const override {
     llvm_unreachable("Illegal bit reference off anonymous def");
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -70,6 +70,7 @@
   BitInit TrueBitInit;
   BitInit FalseBitInit;
 
+  FoldingSet<ArgumentInit> TheArgumentInitPool;
   FoldingSet<BitsInit> TheBitsInitPool;
   std::map<int64_t, IntInit *> TheIntInitPool;
   StringMap<StringInit *, BumpPtrAllocator &> StringInitStringPool;
@@ -349,6 +350,8 @@
 RecordKeeper &Init::getRecordKeeper() const {
   if (auto *TyInit = dyn_cast<TypedInit>(this))
     return TyInit->getType()->getRecordKeeper();
+  if (auto *ArgInit = dyn_cast<ArgumentInit>(this))
+    return ArgInit->getRecordKeeper();
   return cast<UnsetInit>(this)->getRecordKeeper();
 }
 
@@ -364,6 +367,37 @@
   return const_cast<UnsetInit *>(this);
 }
 
+static void ProfileArgumentInit(FoldingSetNodeID &ID, Init *Value) {
+  ID.AddPointer(Value);
+}
+
+void ArgumentInit::Profile(FoldingSetNodeID &ID) const {
+  ProfileArgumentInit(ID, Value);
+}
+
+ArgumentInit *ArgumentInit::get(Init *Value) {
+  FoldingSetNodeID ID;
+  ProfileArgumentInit(ID, Value);
+
+  RecordKeeper &RK = Value->getRecordKeeper();
+  detail::RecordKeeperImpl &RKImpl = RK.getImpl();
+  void *IP = nullptr;
+  if (ArgumentInit *I = RKImpl.TheArgumentInitPool.FindNodeOrInsertPos(ID, IP))
+    return I;
+
+  ArgumentInit *I = new (RKImpl.Allocator) ArgumentInit(Value);
+  RKImpl.TheArgumentInitPool.InsertNode(I, IP);
+  return I;
+}
+
+Init *ArgumentInit::resolveReferences(Resolver &R) const {
+  Init *NewValue = Value->resolveReferences(R);
+  if (NewValue != Value)
+    return ArgumentInit::get(NewValue);
+
+  return const_cast<ArgumentInit *>(this);
+}
+
 BitInit *BitInit::get(RecordKeeper &RK, bool V) {
   return V ? &RK.getImpl().TrueBitInit : &RK.getImpl().FalseBitInit;
 }
@@ -2131,9 +2165,8 @@
 
 std::string DefInit::getAsString() const { return std::string(Def->getName()); }
 
-static void ProfileVarDefInit(FoldingSetNodeID &ID,
-                              Record *Class,
-                              ArrayRef<Init *> Args) {
+static void ProfileVarDefInit(FoldingSetNodeID &ID, Record *Class,
+                              ArrayRef<ArgumentInit *> Args) {
   ID.AddInteger(Args.size());
   ID.AddPointer(Class);
 
@@ -2145,7 +2178,7 @@
     : TypedInit(IK_VarDefInit, RecordRecTy::get(Class)), Class(Class),
       NumArgs(N) {}
 
-VarDefInit *VarDefInit::get(Record *Class, ArrayRef<Init *> Args) {
+VarDefInit *VarDefInit::get(Record *Class, ArrayRef<ArgumentInit *> Args) {
   FoldingSetNodeID ID;
   ProfileVarDefInit(ID, Class, Args);
 
@@ -2154,11 +2187,11 @@
   if (VarDefInit *I = RK.TheVarDefInitPool.FindNodeOrInsertPos(ID, IP))
     return I;
 
-  void *Mem = RK.Allocator.Allocate(totalSizeToAlloc<Init *>(Args.size()),
-                                    alignof(VarDefInit));
+  void *Mem = RK.Allocator.Allocate(
+      totalSizeToAlloc<ArgumentInit *>(Args.size()), alignof(VarDefInit));
   VarDefInit *I = new (Mem) VarDefInit(Class, Args.size());
   std::uninitialized_copy(Args.begin(), Args.end(),
-                          I->getTrailingObjects<Init *>());
+                          I->getTrailingObjects<ArgumentInit *>());
   RK.TheVarDefInitPool.InsertNode(I, IP);
   return I;
 }
@@ -2188,7 +2221,7 @@
 
     for (unsigned i = 0, e = TArgs.size(); i != e; ++i) {
       if (i < args_size())
-        R.set(TArgs[i], getArg(i));
+        R.set(TArgs[i], getArg(i)->getValue());
       else
         R.set(TArgs[i], NewRec->getValue(TArgs[i])->getValue());
 
@@ -2222,11 +2255,11 @@
 Init *VarDefInit::resolveReferences(Resolver &R) const {
   TrackUnresolvedResolver UR(&R);
   bool Changed = false;
-  SmallVector<Init *, 8> NewArgs;
+  SmallVector<ArgumentInit *, 8> NewArgs;
   NewArgs.reserve(args_size());
 
-  for (Init *Arg : args()) {
-    Init *NewArg = Arg->resolveReferences(UR);
+  for (ArgumentInit *Arg : args()) {
+    auto *NewArg = cast<ArgumentInit>(Arg->resolveReferences(UR));
     NewArgs.push_back(NewArg);
     Changed |= NewArg != Arg;
   }
diff --git a/llvm/lib/TableGen/TGParser.h b/llvm/lib/TableGen/TGParser.h
--- a/llvm/lib/TableGen/TGParser.h
+++ b/llvm/lib/TableGen/TGParser.h
@@ -244,13 +244,13 @@
 
   using ArgValueHandler = std::function<void(Init *, Init *)>;
   bool resolveArguments(
-      Record *Rec, ArrayRef<Init *> ArgValues, SMLoc Loc,
+      Record *Rec, ArrayRef<ArgumentInit *> ArgValues, SMLoc Loc,
       ArgValueHandler ArgValueHandler = [](Init *, Init *) {});
   bool resolveArgumentsOfClass(MapResolver &R, Record *Rec,
-                               ArrayRef<Init *> ArgValues, SMLoc Loc);
+                               ArrayRef<ArgumentInit *> ArgValues, SMLoc Loc);
   bool resolveArgumentsOfMultiClass(SubstStack &Substs, MultiClass *MC,
-                                    ArrayRef<Init *> ArgValues, Init *DefmName,
-                                    SMLoc Loc);
+                                    ArrayRef<ArgumentInit *> ArgValues,
+                                    Init *DefmName, SMLoc Loc);
 
 private: // Parser methods.
   bool consume(tgtok::TokKind K);
@@ -288,7 +288,7 @@
                    IDParseMode Mode = ParseValueMode);
   void ParseValueList(SmallVectorImpl<Init *> &Result, Record *CurRec,
                       RecTy *ItemType = nullptr);
-  bool ParseTemplateArgValueList(SmallVectorImpl<Init *> &Result,
+  bool ParseTemplateArgValueList(SmallVectorImpl<ArgumentInit *> &Result,
                                  Record *CurRec, Record *ArgsRec);
   void ParseDagArgList(
       SmallVectorImpl<std::pair<llvm::Init*, StringInit*>> &Result,
@@ -312,7 +312,7 @@
   MultiClass *ParseMultiClassID();
   bool ApplyLetStack(Record *CurRec);
   bool ApplyLetStack(RecordsEntry &Entry);
-  bool CheckTemplateArgValues(SmallVectorImpl<llvm::Init *> &Values,
+  bool CheckTemplateArgValues(SmallVectorImpl<llvm::ArgumentInit *> &Values,
                               SMLoc Loc, Record *ArgsRec);
 };
 
diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -36,7 +36,7 @@
 struct SubClassReference {
   SMRange RefRange;
   Record *Rec;
-  SmallVector<Init*, 4> TemplateArgs;
+  SmallVector<ArgumentInit *, 4> TemplateArgs;
 
   SubClassReference() : Rec(nullptr) {}
 
@@ -46,7 +46,7 @@
 struct SubMultiClassReference {
   SMRange RefRange;
   MultiClass *MC;
-  SmallVector<Init*, 4> TemplateArgs;
+  SmallVector<ArgumentInit *, 4> TemplateArgs;
 
   SubMultiClassReference() : MC(nullptr) {}
 
@@ -569,7 +569,7 @@
   return false;
 }
 
-bool TGParser::resolveArguments(Record *Rec, ArrayRef<Init *> ArgValues,
+bool TGParser::resolveArguments(Record *Rec, ArrayRef<ArgumentInit *> ArgValues,
                                 SMLoc Loc, ArgValueHandler ArgValueHandler) {
   ArrayRef<Init *> ArgNames = Rec->getTemplateArgs();
   assert(ArgValues.size() <= ArgNames.size() &&
@@ -579,7 +579,7 @@
   // handle the (name, value) pair. If not and there was no default, complain.
   for (unsigned I = 0, E = ArgNames.size(); I != E; ++I) {
     if (I < ArgValues.size())
-      ArgValueHandler(ArgNames[I], ArgValues[I]);
+      ArgValueHandler(ArgNames[I], ArgValues[I]->getValue());
     else {
       Init *Default = Rec->getValue(ArgNames[I])->getValue();
       if (!Default->isComplete())
@@ -597,7 +597,8 @@
 /// Resolve the arguments of class and set them to MapResolver.
 /// Returns true if failed.
 bool TGParser::resolveArgumentsOfClass(MapResolver &R, Record *Rec,
-                                       ArrayRef<Init *> ArgValues, SMLoc Loc) {
+                                       ArrayRef<ArgumentInit *> ArgValues,
+                                       SMLoc Loc) {
   return resolveArguments(Rec, ArgValues, Loc,
                           [&](Init *Name, Init *Value) { R.set(Name, Value); });
 }
@@ -605,7 +606,7 @@
 /// Resolve the arguments of multiclass and store them into SubstStack.
 /// Returns true if failed.
 bool TGParser::resolveArgumentsOfMultiClass(SubstStack &Substs, MultiClass *MC,
-                                            ArrayRef<Init *> ArgValues,
+                                            ArrayRef<ArgumentInit *> ArgValues,
                                             Init *DefmName, SMLoc Loc) {
   // Add an implicit argument NAME.
   Substs.emplace_back(QualifiedNameOfImplicitName(MC), DefmName);
@@ -2596,7 +2597,7 @@
       return nullptr;
     }
 
-    SmallVector<Init *, 8> Args;
+    SmallVector<ArgumentInit *, 8> Args;
     Lex.Lex(); // consume the <
     if (ParseTemplateArgValueList(Args, CurRec, Class))
       return nullptr; // Error parsing value list.
@@ -3121,8 +3122,8 @@
 // error was detected.
 //
 //   TemplateArgList ::= '<' [Value {',' Value}*] '>'
-bool TGParser::ParseTemplateArgValueList(SmallVectorImpl<Init *> &Result,
-                                         Record *CurRec, Record *ArgsRec) {
+bool TGParser::ParseTemplateArgValueList(
+    SmallVectorImpl<ArgumentInit *> &Result, Record *CurRec, Record *ArgsRec) {
   assert(Result.empty() && "Result vector is not empty");
   ArrayRef<Init *> TArgs = ArgsRec->getTemplateArgs();
 
@@ -3144,7 +3145,7 @@
     Init *Value = ParseValue(CurRec, ItemType);
     if (!Value)
       return true;
-    Result.push_back(Value);
+    Result.push_back(ArgumentInit::get(Value));
 
     if (consume(tgtok::greater)) // end of argument list?
       return false;
@@ -4247,9 +4248,8 @@
 // inheritance, multiclass invocation, or anonymous class invocation.
 // If necessary, replace an argument with a cast to the required type.
 // The argument count has already been checked.
-bool TGParser::CheckTemplateArgValues(SmallVectorImpl<llvm::Init *> &Values,
-                                      SMLoc Loc, Record *ArgsRec) {
-
+bool TGParser::CheckTemplateArgValues(
+    SmallVectorImpl<llvm::ArgumentInit *> &Values, SMLoc Loc, Record *ArgsRec) {
   ArrayRef<Init *> TArgs = ArgsRec->getTemplateArgs();
 
   for (unsigned I = 0, E = Values.size(); I < E; ++I) {
@@ -4257,13 +4257,13 @@
     RecTy *ArgType = Arg->getType();
     auto *Value = Values[I];
 
-    if (TypedInit *ArgValue = dyn_cast<TypedInit>(Value)) {
+    if (TypedInit *ArgValue = dyn_cast<TypedInit>(Value->getValue())) {
       auto *CastValue = ArgValue->getCastTo(ArgType);
       if (CastValue) {
         assert((!isa<TypedInit>(CastValue) ||
                 cast<TypedInit>(CastValue)->getType()->typeIsA(ArgType)) &&
                "result of template arg value cast has wrong type");
-        Values[I] = CastValue;
+        Values[I] = ArgumentInit::get(CastValue);
       } else {
         PrintFatalError(Loc, "Value specified for template argument '" +