Index: include/llvm/TableGen/Record.h =================================================================== --- include/llvm/TableGen/Record.h +++ include/llvm/TableGen/Record.h @@ -16,6 +16,7 @@ #define LLVM_TABLEGEN_RECORD_H #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/FoldingSet.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/SmallVector.h" @@ -42,6 +43,7 @@ class Record; class RecordKeeper; class RecordVal; +class Resolver; class StringInit; //===----------------------------------------------------------------------===// @@ -360,7 +362,7 @@ /// variables which may not be defined at the time the expression is formed. /// If a value is set for the variable later, this method will be called on /// users of the value to allow the value to propagate out. - virtual Init *resolveReferences(Record &R, const RecordVal *RV) const { + virtual Init *resolveReferences(Resolver &R) const { return const_cast(this); } @@ -508,7 +510,7 @@ std::string getAsString() const override; - Init *resolveReferences(Record &R, const RecordVal *RV) const override; + Init *resolveReferences(Resolver &R) const override; Init *getBit(unsigned Bit) const override { assert(Bit < NumBits && "Bit index out of range!"); @@ -653,7 +655,7 @@ /// If a value is set for the variable later, this method will be called on /// users of the value to allow the value to propagate out. /// - Init *resolveReferences(Record &R, const RecordVal *RV) const override; + Init *resolveReferences(Resolver &R) const override; std::string getAsString() const override; @@ -746,7 +748,7 @@ // possible to fold. Init *Fold(Record *CurRec, MultiClass *CurMultiClass) const override; - Init *resolveReferences(Record &R, const RecordVal *RV) const override; + Init *resolveReferences(Resolver &R) const override; std::string getAsString() const override; }; @@ -800,7 +802,7 @@ // possible to fold. Init *Fold(Record *CurRec, MultiClass *CurMultiClass) const override; - Init *resolveReferences(Record &R, const RecordVal *RV) const override; + Init *resolveReferences(Resolver &R) const override; std::string getAsString() const override; }; @@ -862,7 +864,7 @@ return LHS->isComplete() && MHS->isComplete() && RHS->isComplete(); } - Init *resolveReferences(Record &R, const RecordVal *RV) const override; + Init *resolveReferences(Resolver &R) const override; std::string getAsString() const override; }; @@ -899,7 +901,7 @@ /// If a value is set for the variable later, this method will be called on /// users of the value to allow the value to propagate out. /// - Init *resolveReferences(Record &R, const RecordVal *RV) const override; + Init *resolveReferences(Resolver &R) const override; Init *getBit(unsigned Bit) const override; @@ -936,7 +938,7 @@ unsigned getBitNum() const override { return Bit; } std::string getAsString() const override; - Init *resolveReferences(Record &R, const RecordVal *RV) const override; + Init *resolveReferences(Resolver &R) const override; Init *getBit(unsigned B) const override { assert(B < 1 && "Bit index out of range!"); @@ -972,7 +974,7 @@ unsigned getElementNum() const { return Element; } std::string getAsString() const override; - Init *resolveReferences(Record &R, const RecordVal *RV) const override; + Init *resolveReferences(Resolver &R) const override; Init *getBit(unsigned Bit) const override; }; @@ -1032,7 +1034,7 @@ Init *getBit(unsigned Bit) const override; - Init *resolveReferences(Record &R, const RecordVal *RV) const override; + Init *resolveReferences(Resolver &R) const override; std::string getAsString() const override { return Rec->getAsString() + "." + FieldName->getValue().str(); @@ -1107,7 +1109,7 @@ return makeArrayRef(getTrailingObjects(), NumArgNames); } - Init *resolveReferences(Record &R, const RecordVal *RV) const override; + Init *resolveReferences(Resolver &R) const override; std::string getAsString() const override; @@ -1607,6 +1609,63 @@ Init *QualifyName(Record &CurRec, MultiClass *CurMultiClass, Init *Name, StringRef Scoper); +//===----------------------------------------------------------------------===// +// Resolvers +//===----------------------------------------------------------------------===// + +/// Interface for looking up the initializer for a variable name, used by +/// Init::resolveReferences. +class Resolver { + Record *CurRec; + +public: + explicit Resolver(Record *CurRec) : CurRec(CurRec) {} + virtual ~Resolver() {} + + Record *getCurrentRecord() const { return CurRec; } + + /// Return the initializer for the given variable name (should normally be a + /// StringInit), or nullptr if the name could not be resolved. + virtual Init *resolve(Init *VarName) = 0; + + // Whether bits in a BitsInit should stay unresolved if resolving them would + // result in a ? (UnsetInit). This behavior is used to represent instruction + // encodings by keeping references to unset variables within a record. + virtual bool keepUnsetBits() const { return false; } +}; + +/// Resolve all variables from a record except for unset variables. +class RecordResolver final : public Resolver { + DenseMap Cache; + SmallVector Stack; + +public: + explicit RecordResolver(Record &R) : Resolver(&R) {} + + Init *resolve(Init *VarName) override; + + bool keepUnsetBits() const override { return true; } +}; + +/// Resolve all references to a specific RecordVal. +// +// TODO: This is used for resolving references to template arguments, in a +// rather inefficient way. Change those uses to resolve all template +// arguments simultaneously and get rid of this class. +class RecordValResolver final : public Resolver { + const RecordVal *RV; + +public: + explicit RecordValResolver(Record &R, const RecordVal *RV) + : Resolver(&R), RV(RV) {} + + Init *resolve(Init *VarName) override { + if (VarName == RV->getNameInit()) + return RV->getValue(); + return nullptr; + } +}; + } // end namespace llvm #endif // LLVM_TABLEGEN_RECORD_H Index: lib/TableGen/Record.cpp =================================================================== --- lib/TableGen/Record.cpp +++ lib/TableGen/Record.cpp @@ -318,15 +318,15 @@ // Fix bit initializer to preserve the behavior that bit reference from a unset // bits initializer will resolve into VarBitInit to keep the field name and bit // number used in targets with fixed insn length. -static Init *fixBitInit(const RecordVal *RV, Init *Before, Init *After) { - if (RV || !isa(After)) +static Init *fixBitInit(const Resolver &R, Init *Before, Init *After) { + if (!isa(After) || !R.keepUnsetBits()) return After; return Before; } // resolveReferences - If there are any field references that refer to fields // that have been filled in, we can propagate the values now. -Init *BitsInit::resolveReferences(Record &R, const RecordVal *RV) const { +Init *BitsInit::resolveReferences(Resolver &R) const { bool Changed = false; SmallVector NewBits(getNumBits()); @@ -343,7 +343,7 @@ if (CurBitVar == CachedBitVar) { if (CachedBitVarChanged) { Init *Bit = CachedInit->getBit(CurBit->getBitNum()); - NewBits[i] = fixBitInit(RV, CurBit, Bit); + NewBits[i] = fixBitInit(R, CurBit, Bit); } continue; } @@ -353,7 +353,7 @@ Init *B; do { B = CurBitVar; - CurBitVar = CurBitVar->resolveReferences(R, RV); + CurBitVar = CurBitVar->resolveReferences(R); CachedBitVarChanged |= B != CurBitVar; Changed |= B != CurBitVar; } while (B != CurBitVar); @@ -361,7 +361,7 @@ if (CachedBitVarChanged) { Init *Bit = CurBitVar->getBit(CurBit->getBitNum()); - NewBits[i] = fixBitInit(RV, CurBit, Bit); + NewBits[i] = fixBitInit(R, CurBit, Bit); } } @@ -549,7 +549,7 @@ return DI->getDef(); } -Init *ListInit::resolveReferences(Record &R, const RecordVal *RV) const { +Init *ListInit::resolveReferences(Resolver &R) const { SmallVector Resolved; Resolved.reserve(size()); bool Changed = false; @@ -559,7 +559,7 @@ do { E = CurElt; - CurElt = CurElt->resolveReferences(R, RV); + CurElt = CurElt->resolveReferences(R); Changed |= E != CurElt; } while (E != CurElt); Resolved.push_back(E); @@ -712,12 +712,13 @@ return const_cast(this); } -Init *UnOpInit::resolveReferences(Record &R, const RecordVal *RV) const { - Init *lhs = LHS->resolveReferences(R, RV); +Init *UnOpInit::resolveReferences(Resolver &R) const { + Init *lhs = LHS->resolveReferences(R); if (LHS != lhs) - return (UnOpInit::get(getOpcode(), lhs, getType()))->Fold(&R, nullptr); - return Fold(&R, nullptr); + return (UnOpInit::get(getOpcode(), lhs, getType())) + ->Fold(R.getCurrentRecord(), nullptr); + return Fold(R.getCurrentRecord(), nullptr); } std::string UnOpInit::getAsString() const { @@ -860,13 +861,14 @@ return const_cast(this); } -Init *BinOpInit::resolveReferences(Record &R, const RecordVal *RV) const { - Init *lhs = LHS->resolveReferences(R, RV); - Init *rhs = RHS->resolveReferences(R, RV); +Init *BinOpInit::resolveReferences(Resolver &R) const { + Init *lhs = LHS->resolveReferences(R); + Init *rhs = RHS->resolveReferences(R); if (LHS != lhs || RHS != rhs) - return (BinOpInit::get(getOpcode(), lhs, rhs, getType()))->Fold(&R,nullptr); - return Fold(&R, nullptr); + return (BinOpInit::get(getOpcode(), lhs, rhs, getType())) + ->Fold(R.getCurrentRecord(), nullptr); + return Fold(R.getCurrentRecord(), nullptr); } std::string BinOpInit::getAsString() const { @@ -1064,9 +1066,8 @@ return const_cast(this); } -Init *TernOpInit::resolveReferences(Record &R, - const RecordVal *RV) const { - Init *lhs = LHS->resolveReferences(R, RV); +Init *TernOpInit::resolveReferences(Resolver &R) const { + Init *lhs = LHS->resolveReferences(R); if (getOpcode() == IF && lhs != LHS) { IntInit *Value = dyn_cast(lhs); @@ -1075,23 +1076,23 @@ if (Value) { // Short-circuit if (Value->getValue()) { - Init *mhs = MHS->resolveReferences(R, RV); - return (TernOpInit::get(getOpcode(), lhs, mhs, - RHS, getType()))->Fold(&R, nullptr); + Init *mhs = MHS->resolveReferences(R); + return (TernOpInit::get(getOpcode(), lhs, mhs, RHS, getType())) + ->Fold(R.getCurrentRecord(), nullptr); } - Init *rhs = RHS->resolveReferences(R, RV); - return (TernOpInit::get(getOpcode(), lhs, MHS, - rhs, getType()))->Fold(&R, nullptr); + Init *rhs = RHS->resolveReferences(R); + return (TernOpInit::get(getOpcode(), lhs, MHS, rhs, getType())) + ->Fold(R.getCurrentRecord(), nullptr); } } - Init *mhs = MHS->resolveReferences(R, RV); - Init *rhs = RHS->resolveReferences(R, RV); + Init *mhs = MHS->resolveReferences(R); + Init *rhs = RHS->resolveReferences(R); if (LHS != lhs || MHS != mhs || RHS != rhs) - return (TernOpInit::get(getOpcode(), lhs, mhs, rhs, - getType()))->Fold(&R, nullptr); - return Fold(&R, nullptr); + return (TernOpInit::get(getOpcode(), lhs, mhs, rhs, getType())) + ->Fold(R.getCurrentRecord(), nullptr); + return Fold(R.getCurrentRecord(), nullptr); } std::string TernOpInit::getAsString() const { @@ -1261,10 +1262,9 @@ return nullptr; } -Init *VarInit::resolveReferences(Record &R, const RecordVal *RV) const { - if (RecordVal *Val = R.getValue(VarName)) - if (RV == Val || (!RV && !isa(Val->getValue()))) - return Val->getValue(); +Init *VarInit::resolveReferences(Resolver &R) const { + if (Init *Val = R.resolve(VarName)) + return Val; return const_cast(this); } @@ -1291,8 +1291,8 @@ return TI->getAsString() + "{" + utostr(Bit) + "}"; } -Init *VarBitInit::resolveReferences(Record &R, const RecordVal *RV) const { - Init *I = TI->resolveReferences(R, RV); +Init *VarBitInit::resolveReferences(Resolver &R) const { + Init *I = TI->resolveReferences(R); if (TI != I) return I->getBit(getBitNum()); @@ -1315,9 +1315,8 @@ return TI->getAsString() + "[" + utostr(Element) + "]"; } -Init * -VarListElementInit::resolveReferences(Record &R, const RecordVal *RV) const { - Init *NewTI = TI->resolveReferences(R, RV); +Init *VarListElementInit::resolveReferences(Resolver &R) const { + Init *NewTI = TI->resolveReferences(R); if (ListInit *List = dyn_cast(NewTI)) { // Leave out-of-bounds array references as-is. This can happen without // being an error, e.g. in the untaken "branch" of an !if expression. @@ -1373,12 +1372,12 @@ return VarBitInit::get(const_cast(this), Bit); } -Init *FieldInit::resolveReferences(Record &R, const RecordVal *RV) const { - Init *NewRec = Rec->resolveReferences(R, RV); +Init *FieldInit::resolveReferences(Resolver &R) const { + Init *NewRec = Rec->resolveReferences(R); if (DefInit *DI = dyn_cast(NewRec)) { Init *FieldVal = DI->getDef()->getValue(FieldName)->getValue(); - Init *BVR = FieldVal->resolveReferences(R, RV); + Init *BVR = FieldVal->resolveReferences(R); if (BVR->isComplete()) return BVR; } @@ -1451,17 +1450,17 @@ return nullptr; } -Init *DagInit::resolveReferences(Record &R, const RecordVal *RV) const { +Init *DagInit::resolveReferences(Resolver &R) const { SmallVector NewArgs; NewArgs.reserve(arg_size()); bool ArgsChanged = false; for (const Init *Arg : getArgs()) { - Init *NewArg = Arg->resolveReferences(R, RV); + Init *NewArg = Arg->resolveReferences(R); NewArgs.push_back(NewArg); ArgsChanged |= NewArg != Arg; } - Init *Op = Val->resolveReferences(R, RV); + Init *Op = Val->resolveReferences(R); if (Op != Val || ArgsChanged) return DagInit::get(Op, ValName, NewArgs, getArgNames()); @@ -1551,11 +1550,19 @@ } void Record::resolveReferencesTo(const RecordVal *RV) { + RecordResolver RecResolver(*this); + RecordValResolver RecValResolver(*this, RV); + Resolver *R; + if (RV) + R = &RecValResolver; + else + R = &RecResolver; + for (RecordVal &Value : Values) { if (RV == &Value) // Skip resolve the same field as the given one continue; if (Init *V = Value.getValue()) - if (Value.setValue(V->resolveReferences(*this, RV))) + if (Value.setValue(V->resolveReferences(*R))) PrintFatalError(getLoc(), "Invalid value is found when setting '" + Value.getNameInitAsString() + "' after resolving references" + @@ -1565,7 +1572,7 @@ : "") + "\n"); } Init *OldName = getNameInit(); - Init *NewName = Name->resolveReferences(*this, RV); + Init *NewName = Name->resolveReferences(*R); if (NewName != OldName) { // Re-register with RecordKeeper. setName(NewName); @@ -1826,3 +1833,26 @@ NewName = BinOp->Fold(&CurRec, CurMultiClass); return NewName; } + +Init *RecordResolver::resolve(Init *VarName) { + Init *Val = Cache.lookup(VarName); + if (Val) + return Val; + + for (Init *S : Stack) { + if (S == VarName) + return nullptr; // prevent infinite recursion + } + + if (RecordVal *RV = getCurrentRecord()->getValue(VarName)) { + if (!isa(RV->getValue())) { + Val = RV->getValue(); + Stack.push_back(VarName); + Val = Val->resolveReferences(*this); + Stack.pop_back(); + } + } + + Cache[VarName] = Val; + return Val; +} Index: lib/TableGen/TGParser.cpp =================================================================== --- lib/TableGen/TGParser.cpp +++ lib/TableGen/TGParser.cpp @@ -317,7 +317,8 @@ // Process each value. for (unsigned i = 0; i < List->size(); ++i) { - Init *ItemVal = List->getElement(i)->resolveReferences(*CurRec, nullptr); + RecordResolver R(*CurRec); + Init *ItemVal = List->getElement(i)->resolveReferences(R); IterVals.push_back(IterRecord(CurLoop.IterVar, ItemVal)); if (ProcessForeachDefs(CurRec, Loc, IterVals)) return true;