Index: include/llvm/TableGen/Record.h
===================================================================
--- include/llvm/TableGen/Record.h
+++ include/llvm/TableGen/Record.h
@@ -220,26 +220,47 @@
   std::string getAsString() const override;
 };
 
-/// '[classname]' - Represent an instance of a class, such as:
-/// (R32 X = EAX).
-class RecordRecTy : public RecTy {
+/// '[classname]' - Type of record values that have zero or more superclasses.
+///
+/// The list of superclasses is non-redundant, i.e. only contains classes that
+/// are not the superclass of some other listed class.
+class RecordRecTy final : public RecTy, public FoldingSetNode,
+                          public TrailingObjects<RecordRecTy, Record *> {
   friend class Record;
 
-  Record *Rec;
+  unsigned NumClasses;
 
-  explicit RecordRecTy(Record *R) : RecTy(RecordRecTyKind), Rec(R) {}
+  explicit RecordRecTy(unsigned Num)
+      : RecTy(RecordRecTyKind), NumClasses(Num) {}
 
 public:
+  RecordRecTy(const RecordRecTy &) = delete;
+  RecordRecTy &operator=(const RecordRecTy &) = delete;
+
+  // Do not use sized deallocation due to trailing objects.
+  void operator delete(void *p) { ::operator delete(p); }
+
   static bool classof(const RecTy *RT) {
     return RT->getRecTyKind() == RecordRecTyKind;
   }
 
-  static RecordRecTy *get(Record *R);
+  /// Get the record type with the given non-redundant list of superclasses.
+  static RecordRecTy *get(ArrayRef<Record *> Classes);
+
+  void Profile(FoldingSetNodeID &ID) const;
+
+  ArrayRef<Record *> getClasses() const {
+    return makeArrayRef(getTrailingObjects<Record *>(), NumClasses);
+  }
+
+  using const_record_iterator = Record * const *;
 
-  Record *getRecord() const { return Rec; }
+  const_record_iterator classes_begin() const { return getClasses().begin(); }
+  const_record_iterator classes_end() const { return getClasses().end(); }
 
   std::string getAsString() const override;
 
+  bool isSubClassOf(Record *Class) const;
   bool typeIsConvertibleTo(const RecTy *RHS) const override;
 };
 
@@ -981,7 +1002,7 @@
 
   Record *Def;
 
-  DefInit(Record *D, RecordRecTy *T) : TypedInit(IK_DefInit, T), Def(D) {}
+  explicit DefInit(Record *D);
 
 public:
   DefInit(const DefInit &) = delete;
@@ -1183,6 +1204,9 @@
   SmallVector<SMLoc, 4> Locs;
   SmallVector<Init *, 0> TemplateArgs;
   SmallVector<RecordVal, 0> Values;
+
+  // All superclasses in the inheritance forest in reverse preorder (yes, it
+  // must be a forest; diamond-shaped inheritance is not allowed).
   SmallVector<std::pair<Record *, SMRange>, 0> SuperClasses;
 
   // Tracks Record instances. Not owned by Record.
@@ -1263,6 +1287,9 @@
     return SuperClasses;
   }
 
+  /// Append the direct super classes of this record to Classes.
+  void getDirectSuperClasses(SmallVectorImpl<Record *> &Classes) const;
+
   bool isTemplateArg(Init *Name) const {
     for (Init *TA : TemplateArgs)
       if (TA == Name) return true;
@@ -1452,8 +1479,10 @@
 };
 
 class RecordKeeper {
+  friend class RecordRecTy;
   using RecordMap = std::map<std::string, std::unique_ptr<Record>>;
   RecordMap Classes, Defs;
+  FoldingSet<RecordRecTy> RecordTypePool;
 
 public:
   const RecordMap &getClasses() const { return Classes; }
Index: lib/TableGen/Record.cpp
===================================================================
--- lib/TableGen/Record.cpp
+++ lib/TableGen/Record.cpp
@@ -125,54 +125,140 @@
   return "dag";
 }
 
-RecordRecTy *RecordRecTy::get(Record *R) {
-  return dyn_cast<RecordRecTy>(R->getDefInit()->getType());
+static void ProfileRecordRecTy(FoldingSetNodeID &ID,
+                               ArrayRef<Record *> Classes) {
+  ID.AddInteger(Classes.size());
+  for (Record *R : Classes)
+    ID.AddPointer(R);
+}
+
+RecordRecTy *RecordRecTy::get(ArrayRef<Record *> UnsortedClasses) {
+  if (UnsortedClasses.empty()) {
+    static RecordRecTy AnyRecord(0);
+    return &AnyRecord;
+  }
+
+  FoldingSet<RecordRecTy> &ThePool =
+      UnsortedClasses[0]->getRecords().RecordTypePool;
+
+  SmallVector<Record *, 4> Classes(UnsortedClasses.begin(),
+                                   UnsortedClasses.end());
+  std::sort(Classes.begin(), Classes.end(),
+            [](Record *LHS, Record *RHS) {
+              return LHS->getNameInitAsString() < RHS->getNameInitAsString();
+            });
+
+  FoldingSetNodeID ID;
+  ProfileRecordRecTy(ID, Classes);
+
+  void *IP = nullptr;
+  if (RecordRecTy *Ty = ThePool.FindNodeOrInsertPos(ID, IP))
+    return Ty;
+
+#ifndef NDEBUG
+  // Check for redundancy.
+  for (unsigned i = 0; i < Classes.size(); ++i) {
+    for (unsigned j = 0; j < Classes.size(); ++j) {
+      assert(i == j || !Classes[i]->isSubClassOf(Classes[j]));
+    }
+    assert(&Classes[0]->getRecords() == &Classes[i]->getRecords());
+  }
+#endif
+
+  void *Mem = Allocator.Allocate(totalSizeToAlloc<Record *>(Classes.size()),
+                                 alignof(RecordRecTy));
+  RecordRecTy *Ty = new(Mem) RecordRecTy(Classes.size());
+  std::uninitialized_copy(Classes.begin(), Classes.end(),
+                          Ty->getTrailingObjects<Record *>());
+  ThePool.InsertNode(Ty, IP);
+  return Ty;
+}
+
+void RecordRecTy::Profile(FoldingSetNodeID &ID) const {
+  ProfileRecordRecTy(ID, getClasses());
 }
 
 std::string RecordRecTy::getAsString() const {
-  return Rec->getName();
+  if (NumClasses == 1)
+    return getClasses()[0]->getName();
+
+  std::string Str = "{";
+  bool First = true;
+  for (Record *R : getClasses()) {
+    if (!First)
+      Str += ", ";
+    First = false;
+    Str += R->getName();
+  }
+  Str += "}";
+  return Str;
+}
+
+bool RecordRecTy::isSubClassOf(Record *Class) const {
+  return llvm::any_of(getClasses(), [Class](Record *MySuperClass) {
+                                      return MySuperClass == Class ||
+                                             MySuperClass->isSubClassOf(Class);
+                                    });
 }
 
 bool RecordRecTy::typeIsConvertibleTo(const RecTy *RHS) const {
+  if (this == RHS)
+    return true;
+
   const RecordRecTy *RTy = dyn_cast<RecordRecTy>(RHS);
   if (!RTy)
     return false;
 
-  if (RTy->getRecord() == Rec || Rec->isSubClassOf(RTy->getRecord()))
-    return true;
+  return llvm::all_of(RTy->getClasses(), [this](Record *TargetClass) {
+    return isSubClassOf(TargetClass);
+  });
+}
 
-  for (const auto &SCPair : RTy->getRecord()->getSuperClasses())
-    if (Rec->isSubClassOf(SCPair.first))
-      return true;
+static RecordRecTy *resolveRecordTypes(RecordRecTy *T1, RecordRecTy *T2) {
+  SmallVector<Record *, 4> CommonSuperClasses;
+  SmallVector<Record *, 4> Stack;
 
-  return false;
+  Stack.insert(Stack.end(), T1->classes_begin(), T1->classes_end());
+
+  while (!Stack.empty()) {
+    Record *R = Stack.back();
+    Stack.pop_back();
+
+    if (T2->isSubClassOf(R)) {
+      // RecordRecTy::get expects a non-redundant list; never add R twice.
+      if (!llvm::is_contained(CommonSuperClasses, R))
+        CommonSuperClasses.push_back(R);
+    } else {
+      R->getDirectSuperClasses(Stack);
+    }
+  }
+
+  return RecordRecTy::get(CommonSuperClasses);
 }
 
 RecTy *llvm::resolveTypes(RecTy *T1, RecTy *T2) {
+  if (T1 == T2)
+    return T1;
+
+  if (RecordRecTy *RecTy1 = dyn_cast<RecordRecTy>(T1)) {
+    if (RecordRecTy *RecTy2 = dyn_cast<RecordRecTy>(T2))
+      return resolveRecordTypes(RecTy1, RecTy2);
+  }
+
   if (T1->typeIsConvertibleTo(T2))
     return T2;
   if (T2->typeIsConvertibleTo(T1))
     return T1;
 
-  // If one is a Record type, check superclasses
-  if (RecordRecTy *RecTy1 = dyn_cast<RecordRecTy>(T1)) {
-    // See if T2 inherits from a type T1 also inherits from
-    for (const auto &SuperPair1 : RecTy1->getRecord()->getSuperClasses()) {
-      RecordRecTy *SuperRecTy1 = RecordRecTy::get(SuperPair1.first);
-      RecTy *NewType1 = resolveTypes(SuperRecTy1, T2);
-      if (NewType1)
-        return NewType1;
-    }
-  }
-  if (RecordRecTy *RecTy2 = dyn_cast<RecordRecTy>(T2)) {
-    // See if T1 inherits from a type T2 also inherits from
-    for (const auto &SuperPair2 : RecTy2->getRecord()->getSuperClasses()) {
-      RecordRecTy *SuperRecTy2 = RecordRecTy::get(SuperPair2.first);
-      RecTy *NewType2 = resolveTypes(T1, SuperRecTy2);
-      if (NewType2)
-        return NewType2;
+  if (ListRecTy *ListTy1 = dyn_cast<ListRecTy>(T1)) {
+    if (ListRecTy *ListTy2 = dyn_cast<ListRecTy>(T2)) {
+      RecTy *NewType = resolveTypes(ListTy1->getElementType(),
+                                    ListTy2->getElementType());
+      if (NewType)
+        return NewType->getListTy();
     }
   }
+
   return nullptr;
 }
 
@@ -1100,9 +1186,12 @@
 }
 
 RecTy *TypedInit::getFieldType(StringInit *FieldName) const {
-  if (RecordRecTy *RecordType = dyn_cast<RecordRecTy>(getType()))
-    if (RecordVal *Field = RecordType->getRecord()->getValue(FieldName))
-      return Field->getType();
+  if (RecordRecTy *RecordType = dyn_cast<RecordRecTy>(getType())) {
+    for (Record *Rec : RecordType->getClasses()) {
+      if (RecordVal *Field = Rec->getValue(FieldName))
+        return Field->getType();
+    }
+  }
   return nullptr;
 }
 
@@ -1177,11 +1266,8 @@
   }
 
   if (auto *SRRT = dyn_cast<RecordRecTy>(Ty)) {
-    // Ensure that this is compatible with Rec.
-    if (RecordRecTy *DRRT = dyn_cast<RecordRecTy>(getType()))
-      if (DRRT->getRecord()->isSubClassOf(SRRT->getRecord()) ||
-          DRRT->getRecord() == SRRT->getRecord())
-        return const_cast<TypedInit *>(this);
+    if (getType()->typeIsConvertibleTo(SRRT))
+      return const_cast<TypedInit *>(this);
     return nullptr;
   }
 
@@ -1322,13 +1408,22 @@
   return VarBitInit::get(const_cast<VarInit*>(this), Bit);
 }
 
+static RecordRecTy *makeDefInitType(Record *Rec) {
+  SmallVector<Record *, 4> SuperClasses;
+  Rec->getDirectSuperClasses(SuperClasses);
+  return RecordRecTy::get(SuperClasses);
+}
+
+DefInit::DefInit(Record *D)
+    : TypedInit(IK_DefInit, makeDefInitType(D)), Def(D) {}
+
 DefInit *DefInit::get(Record *R) {
   return R->getDefInit();
 }
 
 Init *DefInit::convertInitializerTo(RecTy *Ty) const {
   if (auto *RRT = dyn_cast<RecordRecTy>(Ty))
-    if (getDef()->isSubClassOf(RRT->getRecord()))
+    if (getType()->typeIsConvertibleTo(RRT))
       return const_cast<DefInit *>(this);
   return nullptr;
 }
@@ -1517,7 +1612,7 @@
 
 DefInit *Record::getDefInit() {
   if (!TheInit)
-    TheInit = new(Allocator) DefInit(this, new(Allocator) RecordRecTy(this));
+    TheInit = new(Allocator) DefInit(this);
   return TheInit;
 }
 
@@ -1537,6 +1632,15 @@
   // this. See TGParser::ParseDef and TGParser::ParseDefm.
 }
 
+void Record::getDirectSuperClasses(SmallVectorImpl<Record *> &Classes) const {
+  ArrayRef<std::pair<Record *, SMRange>> SCs = getSuperClasses();
+  while (!SCs.empty()) {
+    Record *SC = SCs.back().first;
+    SCs = SCs.drop_back(1 + SC->getSuperClasses().size());
+    Classes.push_back(SC);
+  }
+}
+
 void Record::resolveReferencesTo(const RecordVal *RV) {
   for (RecordVal &Value : Values) {
     if (RV == &Value) // Skip resolve the same field as the given one
Index: lib/TableGen/TGParser.cpp
===================================================================
--- lib/TableGen/TGParser.cpp
+++ lib/TableGen/TGParser.cpp
@@ -1067,12 +1067,10 @@
         return nullptr;
       }
 
-      if (MHSTy->typeIsConvertibleTo(RHSTy)) {
-        Type = RHSTy;
-      } else if (RHSTy->typeIsConvertibleTo(MHSTy)) {
-        Type = MHSTy;
-      } else {
-        TokError("inconsistent types for !if");
+      Type = resolveTypes(MHSTy, RHSTy);
+      if (!Type) {
+        TokError(Twine("inconsistent types '") + MHSTy->getAsString() +
+                 "' and '" + RHSTy->getAsString() + "' for !if");
         return nullptr;
       }
       break;
@@ -1270,7 +1268,7 @@
                                                        MCNameRV->getType()),
                                           NewRec->getNameInit(),
                                           StringRecTy::get()),
-                           Class->getDefInit()->getType());
+                           NewRec->getDefInit()->getType());
     }
 
     // The result of the expression is a reference to the new record.
Index: test/TableGen/if-type.td
===================================================================
--- /dev/null
+++ test/TableGen/if-type.td
@@ -0,0 +1,11 @@
+// RUN: not llvm-tblgen %s 2>&1 | FileCheck %s
+// XFAIL: vg_leak
+
+class A<int dummy> {}
+class B<int dummy> : A<dummy> {}
+class C<int dummy> : A<dummy> {}
+
+// CHECK: Value 'x' of type 'C' is incompatible with initializer '{{.*}}' of type 'A'
+class X<int cc, B b, C c> {
+  C x = !if(cc, b, c);
+}
Index: test/TableGen/if.td
===================================================================
--- test/TableGen/if.td
+++ test/TableGen/if.td
@@ -73,6 +73,30 @@
 // CHECK-NEXT: int i = 0;
 def DI2: I2<1>;
 
+// Check that !if with operands of different subtypes can initialize a
+// supertype variable.
+//
+// CHECK: def EXd1 {
+// CHECK: E x = E1d;
+// CHECK: }
+//
+// CHECK: def EXd2 {
+// CHECK: E x = E2d;
+// CHECK: }
+class E<int dummy> {}
+class E1<int dummy> : E<dummy> {}
+class E2<int dummy> : E<dummy> {}
+
+class EX<int cc, E1 b, E2 c> {
+  E x = !if(cc, b, c);
+}
+
+def E1d : E1<0>;
+def E2d : E2<0>;
+
+def EXd1 : EX<1, E1d, E2d>;
+def EXd2 : EX<0, E1d, E2d>;
+
 // CHECK: def One
 // CHECK-NEXT: list<int> first = [1, 2, 3];
 // CHECK-NEXT: list<int> rest = [1, 2, 3];