diff --git a/llvm/include/llvm/IR/IRComparators.td b/llvm/include/llvm/IR/IRComparators.td
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/IR/IRComparators.td
@@ -0,0 +1,215 @@
+// The kinds of comparison available for an operand; TyName selects the cmp<TyName> helper (cmpNumbers, cmpAttrs, etc.)
+class OperandCompareType<string name> {
+  string TyName = name;
+}
+
+// Describes the comparison of a single property. OperandFetchMethod names the accessor called on both
+// instructions; needForLoop/forEndMethod request an indexed loop bounded by the forEndMethod accessor.
+class OperandCompare<OperandCompareType compare_ty, string fetch_method, bit need_loop, string for_end_method> {
+  OperandCompareType compareTy = compare_ty;
+  string OperandFetchMethod = fetch_method;
+  bit needForLoop = need_loop;
+  string forEndMethod = for_end_method;
+}
+
+// Describes the comparison of an instruction and contains the list of operands to compare for this
+// specific instruction
+class InstructionComparator<string inst_cls_name, list<OperandCompare> ops = []> {
+  string InstClassName = inst_cls_name;
+  list<OperandCompare> Operands = ops;
+}
+
+def CmpOperandsAsNumbers : OperandCompareType<"Numbers">;
+
+def CmpOperandsAsTypes : OperandCompareType<"Types">;
+
+def CmpOperandsAsOrderings : OperandCompareType<"Orderings">;
+
+def CmpOperandsAsMetaData : OperandCompareType<"RangeMetadata">;
+
+def CmpOperandsAsAttributes: OperandCompareType<"Attrs">;
+
+def CmpOperandsAsOperandBundle : OperandCompareType<"OperandBundlesSchema">;
+
+def CmpOperandsAsValues : OperandCompareType<"Values">;
+
+def CmpOperandsAsArrayRef : OperandCompareType<"ArrayRef">;
+
+//===----------------------------------------------------------------------===//
+// AllocaInst
+//===----------------------------------------------------------------------===//
+
+def AllocaInstTypes : OperandCompare<CmpOperandsAsTypes, "getAllocatedType", 0, "">;
+
+def AllocaInstAlignment : OperandCompare<CmpOperandsAsNumbers, "getAlignment", 0, "">;
+
+def AllocaInstCompare : InstructionComparator<"AllocaInst", [AllocaInstTypes, AllocaInstAlignment]>;
+
+//===----------------------------------------------------------------------===//
+// LoadInst
+//===----------------------------------------------------------------------===//
+
+def LoadInstVolatile : OperandCompare<CmpOperandsAsNumbers, "isVolatile", 0, "">;
+
+def LoadInstAlignment : OperandCompare<CmpOperandsAsNumbers, "getAlignment", 0, "">;
+
+def LoadInstOrdering : OperandCompare<CmpOperandsAsOrderings, "getOrdering", 0, "">;
+
+def LoadInstScopeID : OperandCompare<CmpOperandsAsNumbers, "getSyncScopeID", 0, "">;
+
+def LoadInstMetaData : OperandCompare<CmpOperandsAsMetaData, "getMetadata", 0, "">;
+
+def LoadInstCompare : InstructionComparator<"LoadInst", [LoadInstVolatile, LoadInstAlignment,
+                                             LoadInstOrdering, LoadInstScopeID, LoadInstMetaData]>;
+
+//===----------------------------------------------------------------------===//
+// StoreInst
+//===----------------------------------------------------------------------===//
+
+def StoreInstVolatile : OperandCompare<CmpOperandsAsNumbers, "isVolatile", 0, "">;
+
+def StoreInstAlignment : OperandCompare<CmpOperandsAsNumbers, "getAlignment", 0, "">;
+
+def StoreInstOrdering : OperandCompare<CmpOperandsAsOrderings, "getOrdering", 0, "">;
+
+def StoreInstScopeID : OperandCompare<CmpOperandsAsNumbers, "getSyncScopeID", 0, "">;
+
+def StoreInstCompare : InstructionComparator<"StoreInst", [StoreInstVolatile, StoreInstAlignment,
+                                              StoreInstOrdering, StoreInstScopeID]>;
+
+//===----------------------------------------------------------------------===//
+// CmpInst
+//===----------------------------------------------------------------------===//
+
+def CmpInstPredicate : OperandCompare<CmpOperandsAsNumbers, "getPredicate", 0, "">;
+
+def CmpInstCompare : InstructionComparator<"CmpInst", [CmpInstPredicate]>;
+
+//===----------------------------------------------------------------------===//
+// CallBase
+//===----------------------------------------------------------------------===//
+
+def CallBaseCallingConv : OperandCompare<CmpOperandsAsNumbers, "getCallingConv", 0, "">;
+
+def CallBaseAttr : OperandCompare<CmpOperandsAsAttributes, "getAttributes", 0, "">;
+
+def CallBaseOperandSchema : OperandCompare<CmpOperandsAsOperandBundle, "", 0, "">;
+
+def CallBaseMetaData : OperandCompare<CmpOperandsAsMetaData, "getMetadata", 0, "">;
+
+
+//===----------------------------------------------------------------------===//
+// CallInst (Subclass of CallBase)
+//===----------------------------------------------------------------------===//
+
+def CallInstTailCall : OperandCompare<CmpOperandsAsNumbers, "getTailCallKind", 0, "">;
+
+def CallInstCompare : InstructionComparator<"CallInst", [CallBaseCallingConv,
+                                                         CallBaseAttr,
+                                                         CallInstTailCall,
+                                                         CallBaseOperandSchema,
+                                                         CallBaseMetaData]>;
+
+//===----------------------------------------------------------------------===//
+// CallBrInst (Subclass of CallBase)
+//===----------------------------------------------------------------------===//
+
+def CallBrInstCompare : InstructionComparator<"CallBrInst", [CallBaseCallingConv,
+                                                             CallBaseAttr,
+                                                             CallBaseOperandSchema,
+                                                             CallBaseMetaData]>;
+
+//===----------------------------------------------------------------------===//
+// GCStatepointInst (Subclass of CallBase)
+//===----------------------------------------------------------------------===//
+
+def GCStatePointInstCompare : InstructionComparator<"GCStatepointInst", [CallBaseCallingConv,
+                                                                         CallBaseAttr,
+                                                                         CallBaseOperandSchema,
+                                                                         CallBaseMetaData]>;
+
+//===----------------------------------------------------------------------===//
+// InvokeInst (Subclass of CallBase)
+//===----------------------------------------------------------------------===//
+
+def InvokeInstCompare : InstructionComparator<"InvokeInst", [CallBaseCallingConv,
+                                                             CallBaseAttr,
+                                                             CallBaseOperandSchema,
+                                                             CallBaseMetaData]>;
+
+//===----------------------------------------------------------------------===//
+// InsertVal
+//===----------------------------------------------------------------------===//
+
+def InsertValIndices : OperandCompare<CmpOperandsAsArrayRef, "getIndices", 0, "">;
+
+def InsertValCompare : InstructionComparator<"InsertValueInst", [InsertValIndices]>;
+
+//===----------------------------------------------------------------------===//
+// ExtractVal
+//===----------------------------------------------------------------------===//
+
+def ExtractValIndices : OperandCompare<CmpOperandsAsArrayRef, "getIndices", 0, "">;
+
+def ExtractValCompare : InstructionComparator<"ExtractValueInst", [ExtractValIndices]>;
+
+//===----------------------------------------------------------------------===//
+// FenceInst
+//===----------------------------------------------------------------------===//
+
+def FenceInstOrdering : OperandCompare<CmpOperandsAsOrderings, "getOrdering", 0, "">;
+
+def FenceInstScopeID : OperandCompare<CmpOperandsAsNumbers, "getSyncScopeID", 0, "">;
+
+def FenceInstCompare : InstructionComparator<"FenceInst", [FenceInstOrdering, FenceInstScopeID]>;
+
+//===----------------------------------------------------------------------===//
+// AtomicCmpXchgInst
+//===----------------------------------------------------------------------===//
+
+def AtomicCmpXchgInstVolatile : OperandCompare<CmpOperandsAsNumbers, "isVolatile", 0, "">;
+
+def AtomicCmpXchgInstWeak : OperandCompare<CmpOperandsAsNumbers, "isWeak", 0, "">;
+
+def AtomicCmpXchgSuccessOrd : OperandCompare<CmpOperandsAsOrderings, "getSuccessOrdering", 0, "">;
+
+def AtomicCmpXchgFailureOrd : OperandCompare<CmpOperandsAsOrderings, "getFailureOrdering", 0, "">;
+
+def AtomicCmpXchgScopeID : OperandCompare<CmpOperandsAsNumbers, "getSyncScopeID", 0, "">;
+
+def AtomicCmpXchgCompare : InstructionComparator<"AtomicCmpXchgInst", [AtomicCmpXchgInstVolatile,
+                                                                       AtomicCmpXchgInstWeak,
+                                                                       AtomicCmpXchgSuccessOrd,
+                                                                       AtomicCmpXchgFailureOrd,
+                                                                       AtomicCmpXchgScopeID]>;
+
+//===----------------------------------------------------------------------===//
+// AtomicRMWInst
+//===----------------------------------------------------------------------===//
+
+def AtomicRMWOperation : OperandCompare<CmpOperandsAsNumbers, "getOperation", 0, "">;
+
+def AtomicRMWVolatile : OperandCompare<CmpOperandsAsNumbers, "isVolatile", 0, "">;
+
+def AtomicRMWOrdering : OperandCompare<CmpOperandsAsOrderings, "getOrdering", 0, "">;
+
+def AtomicRMWScopeID : OperandCompare<CmpOperandsAsNumbers, "getSyncScopeID", 0, "">;
+
+def AtomicRMWInstCompare : InstructionComparator<"AtomicRMWInst", [AtomicRMWOperation, AtomicRMWVolatile,
+                                                  AtomicRMWOrdering, AtomicRMWScopeID]>;
+
+//===----------------------------------------------------------------------===//
+// ShuffleVectorInst
+//===----------------------------------------------------------------------===//
+
+def ShuffleVectMask : OperandCompare<CmpOperandsAsArrayRef, "getShuffleMask", 0, "">;
+
+def ShuffleVectCompare : InstructionComparator<"ShuffleVectorInst", [ShuffleVectMask]>;
+
+//===----------------------------------------------------------------------===//
+// PhiInst
+//===----------------------------------------------------------------------===//
+
+def PhiInstBBCmp : OperandCompare<CmpOperandsAsValues, "getIncomingBlock", 1, "getNumIncomingValues">;
+
+def PhiInstCompare : InstructionComparator<"PHINode", [PhiInstBBCmp]>;
\ No newline at end of file
diff --git a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h
--- a/llvm/include/llvm/Transforms/Utils/FunctionComparator.h
+++
b/llvm/include/llvm/Transforms/Utils/FunctionComparator.h @@ -19,6 +19,7 @@ #include "llvm/IR/Attributes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/Statepoint.h" #include "llvm/IR/ValueMap.h" #include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/Casting.h" @@ -261,15 +262,7 @@ /// just convert it to integers and call cmpNumbers. /// 5. Compare in operation operand types with cmpType in /// most significant operand first order. - /// 6. Last stage. Check operations for some specific attributes. - /// For example, for Load it would be: - /// 6.1.Load: volatile (as boolean flag) - /// 6.2.Load: alignment (as integer numbers) - /// 6.3.Load: ordering (as underlying enum class value) - /// 6.4.Load: synch-scope (as integer numbers) - /// 6.5.Load: range metadata (as integer ranges) - /// On this stage its better to see the code, since its not more than 10-15 - /// strings for particular instruction, and could change sometimes. + /// 6. Uses cmpIROperations to compare instruction-specific properties /// /// Sets \p needToCmpOperands to true if the operands of the instructions /// still must be compared afterwards. In this case it's already guaranteed @@ -277,6 +270,11 @@ int cmpOperations(const Instruction *L, const Instruction *R, bool &needToCmpOperands) const; + /// This function is defined in IRComparator.inc and compares specific + /// attributes of IR instructions; it is auto-generated by tblgen. + /// The backend used is --gen-ir-cmp with the tblgen file IRComparators.td. + int cmpIROperations(const Instruction *L, const Instruction *R) const; + /// cmpType - compares two types, /// defines total ordering among the types set. 
///
@@ -323,6 +321,9 @@
   int cmpAPInts(const APInt &L, const APInt &R) const;
   int cmpAPFloats(const APFloat &L, const APFloat &R) const;
   int cmpMem(StringRef L, StringRef R) const;
+  int cmpArrayRef(ArrayRef<int> LIndices, ArrayRef<int> RIndices) const;
+  int cmpArrayRef(ArrayRef<unsigned> LIndices,
+                  ArrayRef<unsigned> RIndices) const;
 
   // The two functions undergoing comparison.
   const Function *FnL, *FnR;
diff --git a/llvm/include/llvm/Transforms/Utils/IRComparator.inc b/llvm/include/llvm/Transforms/Utils/IRComparator.inc
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Transforms/Utils/IRComparator.inc
@@ -0,0 +1,150 @@
+/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\
+|*                                                                            *|
+|* IRComparator Function Source Fragment                                      *|
+|*                                                                            *|
+|* Automatically generated file, do not edit!                                 *|
+|*                                                                            *|
+\*===----------------------------------------------------------------------===*/
+
+#include "llvm/Transforms/Utils/FunctionComparator.h"
+using namespace llvm;
+// Function to compare IR instructions for equality
+int FunctionComparator::cmpIROperations(const Instruction *L,
+                                        const Instruction *R) const {
+  assert(L->getOpcode() == R->getOpcode() &&
+         "Cannot compare instructions with different Opcode");
+  assert(L->getNumOperands() == R->getNumOperands() &&
+         "Cannot compare instructions with different number of operands");
+  if (const AllocaInst *castL = dyn_cast<AllocaInst>(L)) {
+    const AllocaInst *castR = cast<AllocaInst>(R);
+    if (int Res =
+            cmpTypes(castL->getAllocatedType(), castR->getAllocatedType()))
+      return Res;
+    return cmpNumbers(castL->getAlignment(), castR->getAlignment());
+  }
+  if (const AtomicCmpXchgInst *castL = dyn_cast<AtomicCmpXchgInst>(L)) {
+    const AtomicCmpXchgInst *castR = cast<AtomicCmpXchgInst>(R);
+    if (int Res = cmpNumbers(castL->isVolatile(), castR->isVolatile()))
+      return Res;
+    if (int Res = cmpNumbers(castL->isWeak(), castR->isWeak()))
+      return Res;
+    if (int Res = cmpOrderings(castL->getSuccessOrdering(),
+                               castR->getSuccessOrdering()))
+      return Res;
+    if (int Res = cmpOrderings(castL->getFailureOrdering(),
+                               castR->getFailureOrdering()))
+      return Res;
+    return cmpNumbers(castL->getSyncScopeID(), castR->getSyncScopeID());
+  }
+  if (const AtomicRMWInst *castL = dyn_cast<AtomicRMWInst>(L)) {
+    const AtomicRMWInst *castR = cast<AtomicRMWInst>(R);
+    if (int Res = cmpNumbers(castL->getOperation(), castR->getOperation()))
+      return Res;
+    if (int Res = cmpNumbers(castL->isVolatile(), castR->isVolatile()))
+      return Res;
+    if (int Res = cmpOrderings(castL->getOrdering(), castR->getOrdering()))
+      return Res;
+    return cmpNumbers(castL->getSyncScopeID(), castR->getSyncScopeID());
+  }
+  if (const CallBrInst *castL = dyn_cast<CallBrInst>(L)) {
+    const CallBrInst *castR = cast<CallBrInst>(R);
+    if (int Res = cmpNumbers(castL->getCallingConv(), castR->getCallingConv()))
+      return Res;
+    if (int Res = cmpAttrs(castL->getAttributes(), castR->getAttributes()))
+      return Res;
+    if (int Res = cmpOperandBundlesSchema(*castL, *castR))
+      return Res;
+    return cmpRangeMetadata(castL->getMetadata(LLVMContext::MD_range),
+                            castR->getMetadata(LLVMContext::MD_range));
+  }
+  if (const CallInst *castL = dyn_cast<CallInst>(L)) {
+    const CallInst *castR = cast<CallInst>(R);
+    if (int Res = cmpNumbers(castL->getCallingConv(), castR->getCallingConv()))
+      return Res;
+    if (int Res = cmpAttrs(castL->getAttributes(), castR->getAttributes()))
+      return Res;
+    if (int Res =
+            cmpNumbers(castL->getTailCallKind(), castR->getTailCallKind()))
+      return Res;
+    if (int Res = cmpOperandBundlesSchema(*castL, *castR))
+      return Res;
+    return cmpRangeMetadata(castL->getMetadata(LLVMContext::MD_range),
+                            castR->getMetadata(LLVMContext::MD_range));
+  }
+  if (const CmpInst *castL = dyn_cast<CmpInst>(L)) {
+    const CmpInst *castR = cast<CmpInst>(R);
+    return cmpNumbers(castL->getPredicate(), castR->getPredicate());
+  }
+  if (const ExtractValueInst *castL = dyn_cast<ExtractValueInst>(L)) {
+    const ExtractValueInst *castR = cast<ExtractValueInst>(R);
+    return cmpArrayRef(castL->getIndices(), castR->getIndices());
+  }
+  if (const FenceInst *castL = dyn_cast<FenceInst>(L)) {
+    const FenceInst *castR = cast<FenceInst>(R);
+    if (int Res = cmpOrderings(castL->getOrdering(), castR->getOrdering()))
+      return Res;
+    return cmpNumbers(castL->getSyncScopeID(), castR->getSyncScopeID());
+  }
+  if (const GCStatepointInst *castL = dyn_cast<GCStatepointInst>(L)) {
+    const GCStatepointInst *castR = cast<GCStatepointInst>(R);
+    if (int Res = cmpNumbers(castL->getCallingConv(), castR->getCallingConv()))
+      return Res;
+    if (int Res = cmpAttrs(castL->getAttributes(), castR->getAttributes()))
+      return Res;
+    if (int Res = cmpOperandBundlesSchema(*castL, *castR))
+      return Res;
+    return cmpRangeMetadata(castL->getMetadata(LLVMContext::MD_range),
+                            castR->getMetadata(LLVMContext::MD_range));
+  }
+  if (const InsertValueInst *castL = dyn_cast<InsertValueInst>(L)) {
+    const InsertValueInst *castR = cast<InsertValueInst>(R);
+    return cmpArrayRef(castL->getIndices(), castR->getIndices());
+  }
+  if (const InvokeInst *castL = dyn_cast<InvokeInst>(L)) {
+    const InvokeInst *castR = cast<InvokeInst>(R);
+    if (int Res = cmpNumbers(castL->getCallingConv(), castR->getCallingConv()))
+      return Res;
+    if (int Res = cmpAttrs(castL->getAttributes(), castR->getAttributes()))
+      return Res;
+    if (int Res = cmpOperandBundlesSchema(*castL, *castR))
+      return Res;
+    return cmpRangeMetadata(castL->getMetadata(LLVMContext::MD_range),
+                            castR->getMetadata(LLVMContext::MD_range));
+  }
+  if (const LoadInst *castL = dyn_cast<LoadInst>(L)) {
+    const LoadInst *castR = cast<LoadInst>(R);
+    if (int Res = cmpNumbers(castL->isVolatile(), castR->isVolatile()))
+      return Res;
+    if (int Res = cmpNumbers(castL->getAlignment(), castR->getAlignment()))
+      return Res;
+    if (int Res = cmpOrderings(castL->getOrdering(), castR->getOrdering()))
+      return Res;
+    if (int Res = cmpNumbers(castL->getSyncScopeID(), castR->getSyncScopeID()))
+      return Res;
+    return cmpRangeMetadata(castL->getMetadata(LLVMContext::MD_range),
+                            castR->getMetadata(LLVMContext::MD_range));
+  }
+  if (const PHINode *castL = dyn_cast<PHINode>(L)) {
+    const PHINode *castR = cast<PHINode>(R);
+    for (unsigned i = 0, e = castL->getNumIncomingValues(); i != e; i++) {
+      if (int Res =
+              cmpValues(castL->getIncomingBlock(i), castR->getIncomingBlock(i)))
+        return Res;
+    }
+  }
+  if (const ShuffleVectorInst *castL = dyn_cast<ShuffleVectorInst>(L)) {
+    const ShuffleVectorInst *castR = cast<ShuffleVectorInst>(R);
+    return cmpArrayRef(castL->getShuffleMask(), castR->getShuffleMask());
+  }
+  if (const StoreInst *castL = dyn_cast<StoreInst>(L)) {
+    const StoreInst *castR = cast<StoreInst>(R);
+    if (int Res = cmpNumbers(castL->isVolatile(), castR->isVolatile()))
+      return Res;
+    if (int Res = cmpNumbers(castL->getAlignment(), castR->getAlignment()))
+      return Res;
+    if (int Res = cmpOrderings(castL->getOrdering(), castR->getOrdering()))
+      return Res;
+    return cmpNumbers(castL->getSyncScopeID(), castR->getSyncScopeID());
+  }
+  return 0;
+}
diff --git a/llvm/lib/Transforms/Utils/FunctionComparator.cpp b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
--- a/llvm/lib/Transforms/Utils/FunctionComparator.cpp
+++ b/llvm/lib/Transforms/Utils/FunctionComparator.cpp
@@ -41,6 +41,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/IRComparator.inc"
 #include <cassert>
 #include <cstddef>
 #include <cstdint>
@@ -105,6 +106,28 @@
   return L.compare(R);
 }
 
+int FunctionComparator::cmpArrayRef(ArrayRef<int> LIndices,
+                                    ArrayRef<int> RIndices) const {
+  if (int Res = cmpNumbers(LIndices.size(), RIndices.size()))
+    return Res;
+  for (size_t i = 0, e = LIndices.size(); i != e; ++i) {
+    if (int Res = cmpNumbers(LIndices[i], RIndices[i]))
+      return Res;
+  }
+  return 0;
+}
+
+int FunctionComparator::cmpArrayRef(ArrayRef<unsigned> LIndices,
+                                    ArrayRef<unsigned> RIndices) const {
+  if (int Res = cmpNumbers(LIndices.size(), RIndices.size()))
+    return Res;
+  for (size_t i = 0, e = LIndices.size(); i != e; ++i) {
+    if (int Res = cmpNumbers(LIndices[i], RIndices[i]))
+      return Res;
+  }
+  return 0;
+}
+
 int FunctionComparator::cmpAttrs(const AttributeList L,
                                  const AttributeList R) const {
   if (int Res = cmpNumbers(L.getNumAttrSets(), R.getNumAttrSets()))
@@ -544,121 +567,7 @@
       return Res;
   }
 
-  // Check special state that is a part of some instructions.
- if (const AllocaInst *AI = dyn_cast(L)) { - if (int Res = cmpTypes(AI->getAllocatedType(), - cast(R)->getAllocatedType())) - return Res; - return cmpNumbers(AI->getAlignment(), cast(R)->getAlignment()); - } - if (const LoadInst *LI = dyn_cast(L)) { - if (int Res = cmpNumbers(LI->isVolatile(), cast(R)->isVolatile())) - return Res; - if (int Res = - cmpNumbers(LI->getAlignment(), cast(R)->getAlignment())) - return Res; - if (int Res = - cmpOrderings(LI->getOrdering(), cast(R)->getOrdering())) - return Res; - if (int Res = cmpNumbers(LI->getSyncScopeID(), - cast(R)->getSyncScopeID())) - return Res; - return cmpRangeMetadata( - LI->getMetadata(LLVMContext::MD_range), - cast(R)->getMetadata(LLVMContext::MD_range)); - } - if (const StoreInst *SI = dyn_cast(L)) { - if (int Res = - cmpNumbers(SI->isVolatile(), cast(R)->isVolatile())) - return Res; - if (int Res = - cmpNumbers(SI->getAlignment(), cast(R)->getAlignment())) - return Res; - if (int Res = - cmpOrderings(SI->getOrdering(), cast(R)->getOrdering())) - return Res; - return cmpNumbers(SI->getSyncScopeID(), - cast(R)->getSyncScopeID()); - } - if (const CmpInst *CI = dyn_cast(L)) - return cmpNumbers(CI->getPredicate(), cast(R)->getPredicate()); - if (auto *CBL = dyn_cast(L)) { - auto *CBR = cast(R); - if (int Res = cmpNumbers(CBL->getCallingConv(), CBR->getCallingConv())) - return Res; - if (int Res = cmpAttrs(CBL->getAttributes(), CBR->getAttributes())) - return Res; - if (int Res = cmpOperandBundlesSchema(*CBL, *CBR)) - return Res; - if (const CallInst *CI = dyn_cast(L)) - if (int Res = cmpNumbers(CI->getTailCallKind(), - cast(R)->getTailCallKind())) - return Res; - return cmpRangeMetadata(L->getMetadata(LLVMContext::MD_range), - R->getMetadata(LLVMContext::MD_range)); - } - if (const InsertValueInst *IVI = dyn_cast(L)) { - ArrayRef LIndices = IVI->getIndices(); - ArrayRef RIndices = cast(R)->getIndices(); - if (int Res = cmpNumbers(LIndices.size(), RIndices.size())) - return Res; - for (size_t i = 0, e = 
LIndices.size(); i != e; ++i) { - if (int Res = cmpNumbers(LIndices[i], RIndices[i])) - return Res; - } - return 0; - } - if (const ExtractValueInst *EVI = dyn_cast(L)) { - ArrayRef LIndices = EVI->getIndices(); - ArrayRef RIndices = cast(R)->getIndices(); - if (int Res = cmpNumbers(LIndices.size(), RIndices.size())) - return Res; - for (size_t i = 0, e = LIndices.size(); i != e; ++i) { - if (int Res = cmpNumbers(LIndices[i], RIndices[i])) - return Res; - } - } - if (const FenceInst *FI = dyn_cast(L)) { - return FI->hasSamePropertiesAs(R); - } - if (const AtomicCmpXchgInst *CXI = dyn_cast(L)) { - return CXI->hasSamePropertiesAs(R); - } - if (const AtomicRMWInst *RMWI = dyn_cast(L)) { - if (int Res = cmpNumbers(RMWI->getOperation(), - cast(R)->getOperation())) - return Res; - if (int Res = cmpNumbers(RMWI->isVolatile(), - cast(R)->isVolatile())) - return Res; - if (int Res = cmpOrderings(RMWI->getOrdering(), - cast(R)->getOrdering())) - return Res; - return cmpNumbers(RMWI->getSyncScopeID(), - cast(R)->getSyncScopeID()); - } - if (const ShuffleVectorInst *SVI = dyn_cast(L)) { - ArrayRef LMask = SVI->getShuffleMask(); - ArrayRef RMask = cast(R)->getShuffleMask(); - if (int Res = cmpNumbers(LMask.size(), RMask.size())) - return Res; - for (size_t i = 0, e = LMask.size(); i != e; ++i) { - if (int Res = cmpNumbers(LMask[i], RMask[i])) - return Res; - } - } - if (const PHINode *PNL = dyn_cast(L)) { - const PHINode *PNR = cast(R); - // Ensure that in addition to the incoming values being identical - // (checked by the caller of this function), the incoming blocks - // are also identical. - for (unsigned i = 0, e = PNL->getNumIncomingValues(); i != e; ++i) { - if (int Res = - cmpValues(PNL->getIncomingBlock(i), PNR->getIncomingBlock(i))) - return Res; - } - } - return 0; + return cmpIROperations(L, R); } // Determine whether two GEP operations perform the same underlying arithmetic. 
diff --git a/llvm/utils/TableGen/CMakeLists.txt b/llvm/utils/TableGen/CMakeLists.txt
--- a/llvm/utils/TableGen/CMakeLists.txt
+++ b/llvm/utils/TableGen/CMakeLists.txt
@@ -33,6 +33,7 @@
   InstrInfoEmitter.cpp
   InstrDocsEmitter.cpp
   IntrinsicEmitter.cpp
+  IRComparator.cpp
   OptEmitter.cpp
   OptParserEmitter.cpp
   OptRSTEmitter.cpp
diff --git a/llvm/utils/TableGen/IRComparator.h b/llvm/utils/TableGen/IRComparator.h
new file mode 100644
--- /dev/null
+++ b/llvm/utils/TableGen/IRComparator.h
@@ -0,0 +1,76 @@
+#ifndef LLVM_UTILS_TABLEGEN_IRCOMPARATORS_H
+#define LLVM_UTILS_TABLEGEN_IRCOMPARATORS_H
+
+#include "llvm/TableGen/Record.h"
+#include <string>
+#include <vector>
+
+namespace llvm {
+class Record;
+class RecordKeeper;
+
+/// In-memory form of one InstructionComparator record: the instruction class
+/// name plus the ordered list of properties to compare for that class.
+struct ComparableInstruction {
+  Record *TheDef;
+  std::string InstClassName;
+  /// One OperandCompare record flattened into plain values.
+  struct OperandCompareElement {
+    std::string compareTy;        // TyName of the record's compareTy def.
+    std::string fetchMethod;      // OperandFetchMethod field.
+    bool requiresForLoop;         // needForLoop field.
+    std::string forloopEndMethod; // forEndMethod field.
+  };
+  std::vector<OperandCompareElement> operandsToCompare;
+
+  /// Reads InstClassName and every entry of the Operands list from \p R.
+  explicit ComparableInstruction(Record *R) {
+    TheDef = R;
+    InstClassName = std::string(R->getValueAsString("InstClassName"));
+    ListInit *Operands = R->getValueAsListInit("Operands");
+    for (unsigned i = 0, e = Operands->size(); i != e; ++i) {
+      Record *Op = Operands->getElementAsRecord(i);
+      assert(Op->isSubClassOf("OperandCompare") &&
+             "Expected an element of OperandCompare!");
+      Record *OpTy = Op->getValueAsDef("compareTy");
+      std::string OpTyName = std::string(OpTy->getValueAsString("TyName"));
+      std::string OpFetchMethod =
+          std::string(Op->getValueAsString("OperandFetchMethod"));
+      bool OprequiresForLoop = Op->getValueAsBit("needForLoop");
+      std::string OpforLoopEndMethod =
+          std::string(Op->getValueAsString("forEndMethod"));
+      OperandCompareElement Elm;
+      Elm.compareTy = OpTyName;
+      Elm.fetchMethod = OpFetchMethod;
+      Elm.requiresForLoop = OprequiresForLoop;
+      Elm.forloopEndMethod = OpforLoopEndMethod;
+      operandsToCompare.push_back(Elm);
+    }
+  }
+};
+
+/// Table of every InstructionComparator definition found in a RecordKeeper.
+class InstructionCompTable {
+  std::vector<ComparableInstruction> Instructions;
+
+public:
+  /// Collects all records derived from InstructionComparator in \p RC.
+  explicit InstructionCompTable(const RecordKeeper &RC) {
+    std::vector<Record *> Defs =
+        RC.getAllDerivedDefinitions("InstructionComparator");
+    Instructions.reserve(Defs.size());
+
+    for (unsigned i = 0, e = Defs.size(); i != e; ++i)
+      Instructions.push_back(ComparableInstruction(Defs[i]));
+  }
+
+  InstructionCompTable() = default;
+  bool empty() const { return Instructions.empty(); }
+  size_t size() const { return Instructions.size(); }
+  ComparableInstruction &operator[](size_t Pos) { return Instructions[Pos]; }
+  const ComparableInstruction &operator[](size_t Pos) const {
+    return Instructions[Pos];
+  }
+};
+} // namespace llvm
+#endif
\ No newline at end of file
diff --git a/llvm/utils/TableGen/IRComparator.cpp b/llvm/utils/TableGen/IRComparator.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/utils/TableGen/IRComparator.cpp
@@ -0,0 +1,145 @@
+#include "IRComparator.h"
+#include "TableGenBackends.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/StringMatcher.h"
+#include "llvm/TableGen/StringToOffsetTable.h"
+#include "llvm/TableGen/TableGenBackend.h"
+using namespace llvm;
+
+namespace {
+/// TableGen backend that emits FunctionComparator::cmpIROperations from the
+/// InstructionComparator records.
+class IRComparator {
+  RecordKeeper &Records;
+
+public:
+  IRComparator(RecordKeeper &R) : Records(R) {}
+
+  void run(raw_ostream &OS);
+
+  std::string getIndentation(int tabs);
+
+  void EmitCond(const std::string &compareTy, const std::string &fetchMethod,
+                raw_ostream &OS, bool isLast, bool inLoopCmp);
+
+  void EmitCmpFunction(const InstructionCompTable &Insts, raw_ostream &OS);
+};
+} // End anonymous namespace
+
+//===----------------------------------------------------------------------===//
+// IRComparator Implementation
+//===----------------------------------------------------------------------===//
+
+/// Emits the file header, the include/using prologue and the comparison
+/// function.
+void IRComparator::run(raw_ostream &OS) {
+  emitSourceFileHeader("IRComparator Function Source Fragment", OS);
+
+  InstructionCompTable Insts(Records);
+
+  // Header for helper functions to compare basic types
+  OS << "#include \"llvm/Transforms/Utils/FunctionComparator.h\"\n";
+  OS << "using namespace llvm;\n";
+
+  // Emit the function which checks for similarity between every instruction
+  EmitCmpFunction(Insts, OS);
+}
+
+/// Returns one indentation unit per requested level; non-positive levels
+/// yield an empty string.
+std::string IRComparator::getIndentation(int tabs) {
+  std::string Indent;
+  for (int I = 0; I < tabs; ++I)
+    Indent += " ";
+  return Indent;
+}
+
+/// Emits the argument list (and closing parentheses) of one cmp<Ty>(...) call.
+/// \p isLast selects the "return cmp...(...);" form, otherwise the call is the
+/// condition of an "if(int Res = ...)"; \p inLoopCmp passes the loop index "i"
+/// to the fetch method.
+void IRComparator::EmitCond(const std::string &compareTy,
+                            const std::string &fetchMethod, raw_ostream &OS,
+                            bool isLast, bool inLoopCmp) {
+  if (compareTy == "OperandBundlesSchema") {
+    OS << "*castL, *castR)";
+  } else if (compareTy == "RangeMetadata") {
+    OS << "castL->" << fetchMethod << "(LLVMContext::MD_range), ";
+    OS << "castR->" << fetchMethod << "(LLVMContext::MD_range))";
+  } else {
+    // Inside a generated for-loop the accessor takes the loop index.
+    const char *Args = inLoopCmp ? "(i)" : "()";
+    OS << "castL->" << fetchMethod << Args << ", ";
+    OS << "castR->" << fetchMethod << Args << ")";
+  }
+  OS << (isLast ? ";\n" : ")\n");
+}
+
+// Function to generate structured operand checks for each Instruction
+//
+// For every instruction class this emits
+//   if(const C *castL = dyn_cast<C>(L)) { const C *castR = cast<C>(R); ... }
+// where each property becomes "if(int Res = cmp...) return Res;" (wrapped in
+// a for-loop when requested) and a trailing non-loop property is returned
+// directly. The block is closed exactly once per instruction, so empty
+// operand lists and loop operands in any position produce balanced braces.
+void IRComparator::EmitCmpFunction(const InstructionCompTable &Insts,
+                                   raw_ostream &OS) {
+  OS << "// Function to compare IR instructions for equality\n";
+  OS << "int FunctionComparator::cmpIROperations(const Instruction *L, const "
+        "Instruction *R) const {\n";
+  // Precondition stated by the messages: same opcode, same operand count.
+  OS << " assert(L->getOpcode() == R->getOpcode() && \"Cannot compare "
+        "instructions with different Opcode\");\n";
+  OS << " assert(L->getNumOperands() == R->getNumOperands() && \"Cannot "
+        "compare instructions with different number of operands\");\n";
+  const int BaseIndent = 1;
+  for (size_t I = 0, E = Insts.size(); I != E; ++I) {
+    const ComparableInstruction &Inst = Insts[I];
+    OS << getIndentation(BaseIndent);
+    OS << "if(const " << Inst.InstClassName << " *castL = dyn_cast<"
+       << Inst.InstClassName << ">(L)) {\n";
+    OS << getIndentation(BaseIndent + 1);
+    OS << "const " << Inst.InstClassName << " *castR = cast<"
+       << Inst.InstClassName << ">(R);\n";
+
+    const auto &Ops = Inst.operandsToCompare;
+    for (size_t J = 0, N = Ops.size(); J != N; ++J) {
+      const auto &Op = Ops[J];
+      int Indent = BaseIndent + 1;
+      // A trailing comparison outside a loop is the block's return value.
+      if (J + 1 == N && !Op.requiresForLoop) {
+        OS << getIndentation(Indent);
+        OS << "return cmp" << Op.compareTy << "(";
+        EmitCond(Op.compareTy, Op.fetchMethod, OS, true, false);
+        continue;
+      }
+      if (Op.requiresForLoop) {
+        OS << getIndentation(Indent);
+        OS << "for(unsigned i = 0, e = castL->" << Op.forloopEndMethod
+           << "(); i != e; i++) {\n";
+        ++Indent;
+      }
+      OS << getIndentation(Indent);
+      OS << "if(int Res = cmp" << Op.compareTy << "(";
+      EmitCond(Op.compareTy, Op.fetchMethod, OS, false, Op.requiresForLoop);
+      OS << getIndentation(Indent + 1);
+      OS << "return Res;\n";
+      if (Op.requiresForLoop) {
+        OS << getIndentation(Indent - 1);
+        OS << "}\n";
+      }
+    }
+    // Close the dyn_cast block once, whatever the operand list looked like.
+    OS << getIndentation(BaseIndent);
+    OS << "}\n";
+  }
+  OS << " return 0;\n";
+  OS << "}\n";
+}
+
+/// Entry point of the --gen-ir-cmp backend.
+void llvm::EmitIRCompFunc(RecordKeeper &RK, raw_ostream &OS) {
+  IRComparator(RK).run(OS);
+}
diff --git a/llvm/utils/TableGen/TableGen.cpp b/llvm/utils/TableGen/TableGen.cpp
--- a/llvm/utils/TableGen/TableGen.cpp
+++ b/llvm/utils/TableGen/TableGen.cpp
@@ -28,6 +28,7 @@
   GenRegisterInfo,
   GenInstrInfo,
   GenInstrDocs,
+  GenInstCmp,
   GenAsmWriter,
   GenAsmMatcher,
   GenDisassembler,
@@ -83,6 +84,7 @@
                    "Generate instruction descriptions"),
         clEnumValN(GenInstrDocs, "gen-instr-docs",
                    "Generate instruction documentation"),
+        clEnumValN(GenInstCmp, "gen-ir-cmp", "Generate IR Comparator function"),
         clEnumValN(GenCallingConv, "gen-callingconv",
                    "Generate calling convention descriptions"),
         clEnumValN(GenAsmWriter, "gen-asm-writer", "Generate assembly writer"),
@@ -160,6 +162,9 @@
   case GenInstrDocs:
     EmitInstrDocs(Records, OS);
     break;
+  case GenInstCmp:
+    EmitIRCompFunc(Records, OS);
+    break;
   case GenCallingConv:
     EmitCallingConv(Records, OS);
     break;
diff --git a/llvm/utils/TableGen/TableGenBackends.h b/llvm/utils/TableGen/TableGenBackends.h
--- a/llvm/utils/TableGen/TableGenBackends.h
+++ b/llvm/utils/TableGen/TableGenBackends.h
@@ -90,6 +90,7 @@
 void EmitRegisterBank(RecordKeeper &RK, raw_ostream &OS);
 void EmitExegesis(RecordKeeper &RK, raw_ostream &OS);
 void EmitAutomata(RecordKeeper &RK, raw_ostream &OS);
+void EmitIRCompFunc(RecordKeeper &RK, raw_ostream &OS);
 
 } // End llvm namespace
 