Index: include/llvm/CodeGen/GlobalISel/InstructionSelector.h
===================================================================
--- include/llvm/CodeGen/GlobalISel/InstructionSelector.h
+++ include/llvm/CodeGen/GlobalISel/InstructionSelector.h
@@ -16,8 +16,12 @@
 #ifndef LLVM_CODEGEN_GLOBALISEL_INSTRUCTIONSELECTOR_H
 #define LLVM_CODEGEN_GLOBALISEL_INSTRUCTIONSELECTOR_H
 
+#include <cstdint>
+
 namespace llvm {
 class MachineInstr;
+class MachineOperand;
+class MachineRegisterInfo;
 class RegisterBankInfo;
 class TargetInstrInfo;
 class TargetRegisterInfo;
@@ -56,6 +60,9 @@
                                          const TargetInstrInfo &TII,
                                          const TargetRegisterInfo &TRI,
                                          const RegisterBankInfo &RBI) const;
+
+  bool isOperandImmEqual(const MachineOperand &MO, int64_t Value,
+                         const MachineRegisterInfo &MRI) const;
 };
 
 } // End namespace llvm.
Index: include/llvm/CodeGen/LowLevelType.h
===================================================================
--- include/llvm/CodeGen/LowLevelType.h
+++ include/llvm/CodeGen/LowLevelType.h
@@ -1,4 +1,4 @@
-//== llvm/CodeGen/GlobalISel/LowLevelType.h -------------------- -*- C++ -*-==//
+//== llvm/CodeGen/LowLevelType.h ------------------------------- -*- C++ -*-==//
 //
 //                     The LLVM Compiler Infrastructure
 //
@@ -10,197 +10,23 @@
 /// Implement a low-level type suitable for MachineInstr level instruction
 /// selection.
 ///
-/// For a type attached to a MachineInstr, we only care about 2 details: total
-/// size and the number of vector lanes (if any). Accordingly, there are 4
-/// possible valid type-kinds:
-///
-/// * `sN` for scalars and aggregates
-/// * `<N x sN>` for vectors, which must have at least 2 elements.
-/// * `pN` for pointers
-///
-/// Other information required for correct selection is expected to be carried
-/// by the opcode, or non-type flags. For example the distinction between G_ADD
-/// and G_FADD for int/float or fast-math flags.
+/// This provides the CodeGen aspects of LowLevelType, such as Type conversion.
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_CODEGEN_GLOBALISEL_LOWLEVELTYPE_H
-#define LLVM_CODEGEN_GLOBALISEL_LOWLEVELTYPE_H
+#ifndef LLVM_CODEGEN_LOWLEVELTYPE_H
+#define LLVM_CODEGEN_LOWLEVELTYPE_H
 
-#include <cassert>
-#include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/Support/LowLevelTypeImpl.h"
 
 namespace llvm {
 
 class DataLayout;
-class LLVMContext;
 class Type;
-class raw_ostream;
-
-class LLT {
-public:
-  enum TypeKind : uint16_t {
-    Invalid,
-    Scalar,
-    Pointer,
-    Vector,
-  };
-
-  /// Get a low-level scalar or aggregate "bag of bits".
-  static LLT scalar(unsigned SizeInBits) {
-    assert(SizeInBits > 0 && "invalid scalar size");
-    return LLT{Scalar, 1, SizeInBits};
-  }
-
-  /// Get a low-level pointer in the given address space (defaulting to 0).
-  static LLT pointer(uint16_t AddressSpace, unsigned SizeInBits) {
-    return LLT{Pointer, AddressSpace, SizeInBits};
-  }
-
-  /// Get a low-level vector of some number of elements and element width.
-  /// \p NumElements must be at least 2.
-  static LLT vector(uint16_t NumElements, unsigned ScalarSizeInBits) {
-    assert(NumElements > 1 && "invalid number of vector elements");
-    return LLT{Vector, NumElements, ScalarSizeInBits};
-  }
-
-  /// Get a low-level vector of some number of elements and element type.
-  static LLT vector(uint16_t NumElements, LLT ScalarTy) {
-    assert(NumElements > 1 && "invalid number of vector elements");
-    assert(ScalarTy.isScalar() && "invalid vector element type");
-    return LLT{Vector, NumElements, ScalarTy.getSizeInBits()};
-  }
-
-  explicit LLT(TypeKind Kind, uint16_t NumElements, unsigned SizeInBits)
-      : SizeInBits(SizeInBits), ElementsOrAddrSpace(NumElements), Kind(Kind) {
-    assert((Kind != Vector || ElementsOrAddrSpace > 1) &&
-           "invalid number of vector elements");
-  }
-
-  explicit LLT() : SizeInBits(0), ElementsOrAddrSpace(0), Kind(Invalid) {}
-
-  /// Construct a low-level type based on an LLVM type.
-  explicit LLT(Type &Ty, const DataLayout &DL);
-
-  explicit LLT(MVT VT);
-
-  bool isValid() const { return Kind != Invalid; }
-
-  bool isScalar() const { return Kind == Scalar; }
-
-  bool isPointer() const { return Kind == Pointer; }
-
-  bool isVector() const { return Kind == Vector; }
-
-  /// Returns the number of elements in a vector LLT. Must only be called on
-  /// vector types.
-  uint16_t getNumElements() const {
-    assert(isVector() && "cannot get number of elements on scalar/aggregate");
-    return ElementsOrAddrSpace;
-  }
-
-  /// Returns the total size of the type. Must only be called on sized types.
-  unsigned getSizeInBits() const {
-    if (isPointer() || isScalar())
-      return SizeInBits;
-    return SizeInBits * ElementsOrAddrSpace;
-  }
-
-  unsigned getScalarSizeInBits() const {
-    return SizeInBits;
-  }
-
-  unsigned getAddressSpace() const {
-    assert(isPointer() && "cannot get address space of non-pointer type");
-    return ElementsOrAddrSpace;
-  }
-
-  /// Returns the vector's element type. Only valid for vector types.
-  LLT getElementType() const {
-    assert(isVector() && "cannot get element type of scalar/aggregate");
-    return scalar(SizeInBits);
-  }
-
-  /// Get a low-level type with half the size of the original, by halving the
-  /// size of the scalar type involved. For example `s32` will become `s16`,
-  /// `<2 x s32>` will become `<2 x s16>`.
-  LLT halfScalarSize() const {
-    assert(!isPointer() && getScalarSizeInBits() > 1 &&
-           getScalarSizeInBits() % 2 == 0 && "cannot half size of this type");
-    return LLT{Kind, ElementsOrAddrSpace, SizeInBits / 2};
-  }
-
-  /// Get a low-level type with twice the size of the original, by doubling the
-  /// size of the scalar type involved. For example `s32` will become `s64`,
-  /// `<2 x s32>` will become `<2 x s64>`.
-  LLT doubleScalarSize() const {
-    assert(!isPointer() && "cannot change size of this type");
-    return LLT{Kind, ElementsOrAddrSpace, SizeInBits * 2};
-  }
-
-  /// Get a low-level type with half the size of the original, by halving the
-  /// number of vector elements of the scalar type involved. The source must be
-  /// a vector type with an even number of elements. For example `<4 x s32>`
-  /// will become `<2 x s32>`, `<2 x s32>` will become `s32`.
-  LLT halfElements() const {
-    assert(isVector() && ElementsOrAddrSpace % 2 == 0 &&
-           "cannot half odd vector");
-    if (ElementsOrAddrSpace == 2)
-      return scalar(SizeInBits);
-
-    return LLT{Vector, static_cast<uint16_t>(ElementsOrAddrSpace / 2),
-               SizeInBits};
-  }
-
-  /// Get a low-level type with twice the size of the original, by doubling the
-  /// number of vector elements of the scalar type involved. The source must be
-  /// a vector type. For example `<2 x s32>` will become `<4 x s32>`. Doubling
-  /// the number of elements in sN produces <2 x sN>.
-  LLT doubleElements() const {
-    assert(!isPointer() && "cannot double elements in pointer");
-    return LLT{Vector, static_cast<uint16_t>(ElementsOrAddrSpace * 2),
-               SizeInBits};
-  }
-
-  void print(raw_ostream &OS) const;
-
-  bool operator==(const LLT &RHS) const {
-    return Kind == RHS.Kind && SizeInBits == RHS.SizeInBits &&
-           ElementsOrAddrSpace == RHS.ElementsOrAddrSpace;
-  }
-
-  bool operator!=(const LLT &RHS) const { return !(*this == RHS); }
-
-  friend struct DenseMapInfo<LLT>;
-private:
-  unsigned SizeInBits;
-  uint16_t ElementsOrAddrSpace;
-  TypeKind Kind;
-};
-
-inline raw_ostream& operator<<(raw_ostream &OS, const LLT &Ty) {
-  Ty.print(OS);
-  return OS;
-}
-
-template<> struct DenseMapInfo<LLT> {
-  static inline LLT getEmptyKey() {
-    return LLT{LLT::Invalid, 0, -1u};
-  }
-  static inline LLT getTombstoneKey() {
-    return LLT{LLT::Invalid, 0, -2u};
-  }
-  static inline unsigned getHashValue(const LLT &Ty) {
-    uint64_t Val = ((uint64_t)Ty.SizeInBits << 32) |
-                   ((uint64_t)Ty.ElementsOrAddrSpace << 16) | (uint64_t)Ty.Kind;
-    return DenseMapInfo<uint64_t>::getHashValue(Val);
-  }
-  static bool isEqual(const LLT &LHS, const LLT &RHS) {
-    return LHS == RHS;
-  }
-};
+/// Construct a low-level type based on an LLVM type.
+LLT getLLTForType(Type &Ty, const DataLayout &DL);
 
 }
 
-#endif
+#endif // LLVM_CODEGEN_LOWLEVELTYPE_H
Index: include/llvm/Support/LowLevelTypeImpl.h
===================================================================
--- include/llvm/Support/LowLevelTypeImpl.h
+++ include/llvm/Support/LowLevelTypeImpl.h
@@ -1,4 +1,4 @@
-//== llvm/CodeGen/GlobalISel/LowLevelType.h -------------------- -*- C++ -*-==//
+//== llvm/Support/LowLevelTypeImpl.h --------------------------- -*- C++ -*-==//
 //
 //                     The LLVM Compiler Infrastructure
 //
@@ -24,17 +24,16 @@
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef LLVM_CODEGEN_GLOBALISEL_LOWLEVELTYPE_H
-#define LLVM_CODEGEN_GLOBALISEL_LOWLEVELTYPE_H
+#ifndef LLVM_SUPPORT_LOWLEVELTYPEIMPL_H
+#define LLVM_SUPPORT_LOWLEVELTYPEIMPL_H
 
 #include <cassert>
 #include "llvm/ADT/DenseMapInfo.h"
-#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/CodeGen/MachineValueType.h"
 
 namespace llvm {
 
 class DataLayout;
-class LLVMContext;
 class Type;
 class raw_ostream;
 
@@ -80,9 +79,6 @@
 
   explicit LLT() : SizeInBits(0), ElementsOrAddrSpace(0), Kind(Invalid) {}
 
-  /// Construct a low-level type based on an LLVM type.
-  explicit LLT(Type &Ty, const DataLayout &DL);
-
   explicit LLT(MVT VT);
 
   bool isValid() const { return Kind != Invalid; }
@@ -203,4 +199,4 @@
 
 }
 
-#endif
+#endif // LLVM_SUPPORT_LOWLEVELTYPEIMPL_H
Index: lib/CodeGen/GlobalISel/IRTranslator.cpp
===================================================================
--- lib/CodeGen/GlobalISel/IRTranslator.cpp
+++ lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -70,7 +70,8 @@
   // we need to concat together to produce the value.
   assert(Val.getType()->isSized() &&
          "Don't know how to create an empty vreg");
-  unsigned VReg = MRI->createGenericVirtualRegister(LLT{*Val.getType(), *DL});
+  unsigned VReg =
+      MRI->createGenericVirtualRegister(getLLTForType(*Val.getType(), *DL));
   ValReg = VReg;
 
   if (auto CV = dyn_cast<Constant>(&Val)) {
@@ -221,7 +222,7 @@
   const unsigned SwCondValue = getOrCreateVReg(*SwInst.getCondition());
   const BasicBlock *OrigBB = SwInst.getParent();
 
-  LLT LLTi1 = LLT(*Type::getInt1Ty(U.getContext()), *DL);
+  LLT LLTi1 = getLLTForType(*Type::getInt1Ty(U.getContext()), *DL);
   for (auto &CaseIt : SwInst.cases()) {
     const unsigned CaseValueReg = getOrCreateVReg(*CaseIt.getCaseValue());
     const unsigned Tst = MRI->createGenericVirtualRegister(LLTi1);
@@ -277,7 +278,7 @@
 
   unsigned Res = getOrCreateVReg(LI);
   unsigned Addr = getOrCreateVReg(*LI.getPointerOperand());
-  LLT VTy{*LI.getType(), *DL}, PTy{*LI.getPointerOperand()->getType(), *DL};
+
   MIRBuilder.buildLoad(
       Res, Addr,
       *MF->getMachineMemOperand(MachinePointerInfo(LI.getPointerOperand()),
@@ -295,8 +296,6 @@
 
   unsigned Val = getOrCreateVReg(*SI.getValueOperand());
   unsigned Addr = getOrCreateVReg(*SI.getPointerOperand());
-  LLT VTy{*SI.getValueOperand()->getType(), *DL},
-      PTy{*SI.getPointerOperand()->getType(), *DL};
 
   MIRBuilder.buildStore(
       Val, Addr,
@@ -372,7 +371,8 @@
 
 bool IRTranslator::translateBitCast(const User &U,
                                     MachineIRBuilder &MIRBuilder) {
-  if (LLT{*U.getOperand(0)->getType(), *DL} == LLT{*U.getType(), *DL}) {
+  if (getLLTForType(*U.getOperand(0)->getType(), *DL) ==
+      getLLTForType(*U.getType(), *DL)) {
     unsigned &Reg = ValToVReg[&U];
     if (Reg)
       MIRBuilder.buildCopy(Reg, getOrCreateVReg(*U.getOperand(0)));
@@ -399,7 +399,7 @@
 
   Value &Op0 = *U.getOperand(0);
   unsigned BaseReg = getOrCreateVReg(Op0);
-  LLT PtrTy{*Op0.getType(), *DL};
+  LLT PtrTy = getLLTForType(*Op0.getType(), *DL);
   unsigned PtrSize = DL->getPointerSizeInBits(PtrTy.getAddressSpace());
   LLT OffsetTy = LLT::scalar(PtrSize);
 
@@ -465,7 +465,7 @@
 bool IRTranslator::translateMemfunc(const CallInst &CI,
                                     MachineIRBuilder &MIRBuilder,
                                     unsigned ID) {
-  LLT SizeTy{*CI.getArgOperand(2)->getType(), *DL};
+  LLT SizeTy = getLLTForType(*CI.getArgOperand(2)->getType(), *DL);
   Type *DstTy = CI.getArgOperand(0)->getType();
   if (cast<PointerType>(DstTy)->getAddressSpace() != 0 ||
       SizeTy.getSizeInBits() != DL->getPointerSizeInBits(0))
@@ -522,7 +522,7 @@
 
 bool IRTranslator::translateOverflowIntrinsic(const CallInst &CI, unsigned Op,
                                               MachineIRBuilder &MIRBuilder) {
-  LLT Ty{*CI.getOperand(0)->getType(), *DL};
+  LLT Ty = getLLTForType(*CI.getOperand(0)->getType(), *DL);
   LLT s1 = LLT::scalar(1);
   unsigned Width = Ty.getSizeInBits();
   unsigned Res = MRI->createGenericVirtualRegister(Ty);
@@ -665,7 +665,7 @@
     getStackGuard(getOrCreateVReg(CI), MIRBuilder);
     return true;
   case Intrinsic::stackprotector: {
-    LLT PtrTy{*CI.getArgOperand(0)->getType(), *DL};
+    LLT PtrTy = getLLTForType(*CI.getArgOperand(0)->getType(), *DL);
     unsigned GuardVal = MRI->createGenericVirtualRegister(PtrTy);
     getStackGuard(GuardVal, MIRBuilder);
 
@@ -808,7 +808,7 @@
 
   SmallVector<LLT, 2> Tys;
   for (Type *Ty : cast<StructType>(LP.getType())->elements())
-    Tys.push_back(LLT{*Ty, *DL});
+    Tys.push_back(getLLTForType(*Ty, *DL));
   assert(Tys.size() == 2 && "Only two-valued landingpads are supported");
 
   // Mark exception register as live in.
@@ -873,7 +873,7 @@
     MIRBuilder.buildConstant(TySize, -DL->getTypeAllocSize(Ty));
     MIRBuilder.buildMul(AllocSize, NumElts, TySize);
 
-  LLT PtrTy = LLT{*AI.getType(), *DL};
+  LLT PtrTy = getLLTForType(*AI.getType(), *DL);
   auto &TLI = *MF->getSubtarget().getTargetLowering();
   unsigned SPReg = TLI.getStackPointerRegisterToSaveRestore();
 
Index: lib/CodeGen/GlobalISel/InstructionSelect.cpp
===================================================================
--- lib/CodeGen/GlobalISel/InstructionSelect.cpp
+++ lib/CodeGen/GlobalISel/InstructionSelect.cpp
@@ -176,3 +176,20 @@
   // FIXME: Should we accurately track changes?
   return true;
 }
+
+bool InstructionSelector::isOperandImmEqual(
+    const MachineOperand &MO, int64_t Value,
+    const MachineRegisterInfo &MRI) const {
+  // TODO: We should also test isImm() and isCImm() but this isn't required
+  //       until a DAGCombine equivalent is implemented.
+
+  if (MO.isReg()) {
+    MachineInstr *Def = MRI.getVRegDef(MO.getReg());
+    if (Def->getOpcode() != TargetOpcode::G_CONSTANT)
+      return false;
+    assert(Def->getOperand(1).isImm() && "G_CONSTANT values must be constants");
+    return Def->getOperand(1).getImm() == Value;
+  }
+
+  return false;
+}
Index: lib/CodeGen/LowLevelType.cpp
===================================================================
--- lib/CodeGen/LowLevelType.cpp
+++ lib/CodeGen/LowLevelType.cpp
@@ -1,4 +1,4 @@
-//===-- llvm/CodeGen/GlobalISel/LowLevelType.cpp --------------------------===//
+//===-- llvm/CodeGen/LowLevelType.cpp -------------------------------------===//
 //
 //                     The LLVM Compiler Infrastructure
 //
@@ -18,54 +18,21 @@
 #include "llvm/Support/raw_ostream.h"
 using namespace llvm;
 
-LLT::LLT(Type &Ty, const DataLayout &DL) {
+LLT llvm::getLLTForType(Type &Ty, const DataLayout &DL) {
   if (auto VTy = dyn_cast<VectorType>(&Ty)) {
-    SizeInBits = VTy->getElementType()->getPrimitiveSizeInBits();
-    ElementsOrAddrSpace = VTy->getNumElements();
-    Kind = ElementsOrAddrSpace == 1 ? Scalar : Vector;
+    auto NumElements = VTy->getNumElements();
+    auto ScalarSizeInBits = VTy->getElementType()->getPrimitiveSizeInBits();
+    if (NumElements == 1)
+      return LLT::scalar(ScalarSizeInBits);
+    return LLT::vector(NumElements, ScalarSizeInBits);
   } else if (auto PTy = dyn_cast<PointerType>(&Ty)) {
-    Kind = Pointer;
-    SizeInBits = DL.getTypeSizeInBits(&Ty);
-    ElementsOrAddrSpace = PTy->getAddressSpace();
+    return LLT::pointer(PTy->getAddressSpace(), DL.getTypeSizeInBits(&Ty));
   } else if (Ty.isSized()) {
     // Aggregates are no different from real scalars as far as GlobalISel is
     // concerned.
-    Kind = Scalar;
-    SizeInBits = DL.getTypeSizeInBits(&Ty);
-    ElementsOrAddrSpace = 1;
+    auto SizeInBits = DL.getTypeSizeInBits(&Ty);
     assert(SizeInBits != 0 && "invalid zero-sized type");
-  } else {
-    Kind = Invalid;
-    SizeInBits = ElementsOrAddrSpace = 0;
+    return LLT::scalar(SizeInBits);
   }
-}
-
-LLT::LLT(MVT VT) {
-  if (VT.isVector()) {
-    SizeInBits = VT.getVectorElementType().getSizeInBits();
-    ElementsOrAddrSpace = VT.getVectorNumElements();
-    Kind = ElementsOrAddrSpace == 1 ? Scalar : Vector;
-  } else if (VT.isValid()) {
-    // Aggregates are no different from real scalars as far as GlobalISel is
-    // concerned.
-    Kind = Scalar;
-    SizeInBits = VT.getSizeInBits();
-    ElementsOrAddrSpace = 1;
-    assert(SizeInBits != 0 && "invalid zero-sized type");
-  } else {
-    Kind = Invalid;
-    SizeInBits = ElementsOrAddrSpace = 0;
-  }
-}
-
-void LLT::print(raw_ostream &OS) const {
-  if (isVector())
-    OS << "<" << ElementsOrAddrSpace << " x s" << SizeInBits << ">";
-  else if (isPointer())
-    OS << "p" << getAddressSpace();
-  else if (isValid()) {
-    assert(isScalar() && "unexpected type");
-    OS << "s" << getScalarSizeInBits();
-  } else
-    llvm_unreachable("trying to print an invalid type");
+  return LLT();
 }
Index: lib/Support/CMakeLists.txt
===================================================================
--- lib/Support/CMakeLists.txt
+++ lib/Support/CMakeLists.txt
@@ -68,6 +68,7 @@
   LineIterator.cpp
   Locale.cpp
   LockFileManager.cpp
+  LowLevelType.cpp
   ManagedStatic.cpp
   MathExtras.cpp
   MemoryBuffer.cpp
Index: lib/Support/LowLevelType.cpp
===================================================================
--- lib/Support/LowLevelType.cpp
+++ lib/Support/LowLevelType.cpp
@@ -1,4 +1,4 @@
-//===-- llvm/CodeGen/GlobalISel/LowLevelType.cpp --------------------------===//
+//===-- llvm/Support/LowLevelType.cpp -------------------------------------===//
 //
 //                     The LLVM Compiler Infrastructure
 //
@@ -12,34 +12,10 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/CodeGen/LowLevelType.h"
-#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DerivedTypes.h"
+#include "llvm/Support/LowLevelTypeImpl.h"
 #include "llvm/Support/raw_ostream.h"
 using namespace llvm;
 
-LLT::LLT(Type &Ty, const DataLayout &DL) {
-  if (auto VTy = dyn_cast<VectorType>(&Ty)) {
-    SizeInBits = VTy->getElementType()->getPrimitiveSizeInBits();
-    ElementsOrAddrSpace = VTy->getNumElements();
-    Kind = ElementsOrAddrSpace == 1 ? Scalar : Vector;
-  } else if (auto PTy = dyn_cast<PointerType>(&Ty)) {
-    Kind = Pointer;
-    SizeInBits = DL.getTypeSizeInBits(&Ty);
-    ElementsOrAddrSpace = PTy->getAddressSpace();
-  } else if (Ty.isSized()) {
-    // Aggregates are no different from real scalars as far as GlobalISel is
-    // concerned.
-    Kind = Scalar;
-    SizeInBits = DL.getTypeSizeInBits(&Ty);
-    ElementsOrAddrSpace = 1;
-    assert(SizeInBits != 0 && "invalid zero-sized type");
-  } else {
-    Kind = Invalid;
-    SizeInBits = ElementsOrAddrSpace = 0;
-  }
-}
-
 LLT::LLT(MVT VT) {
   if (VT.isVector()) {
     SizeInBits = VT.getVectorElementType().getSizeInBits();
Index: lib/Target/AArch64/AArch64CallLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64CallLowering.cpp
+++ lib/Target/AArch64/AArch64CallLowering.cpp
@@ -192,8 +192,8 @@
     // FIXME: set split flags if they're actually used (e.g. i128 on AAPCS).
     Type *SplitTy = SplitVT.getTypeForEVT(Ctx);
     SplitArgs.push_back(
-        ArgInfo{MRI.createGenericVirtualRegister(LLT{*SplitTy, DL}), SplitTy,
-                OrigArg.Flags, OrigArg.IsFixed});
+        ArgInfo{MRI.createGenericVirtualRegister(getLLTForType(*SplitTy, DL)),
+                SplitTy, OrigArg.Flags, OrigArg.IsFixed});
   }
 
   SmallVector<uint64_t, 4> BitOffsets;
Index: lib/Target/AArch64/AArch64InstructionSelector.cpp
===================================================================
--- lib/Target/AArch64/AArch64InstructionSelector.cpp
+++ lib/Target/AArch64/AArch64InstructionSelector.cpp
@@ -634,9 +634,12 @@
       // FIXME: Is going through int64_t always correct?
       ImmOp.ChangeToImmediate(
           ImmOp.getFPImm()->getValueAPF().bitcastToAPInt().getZExtValue());
-    } else {
+    } else if (I.getOperand(1).isCImm()) {
       uint64_t Val = I.getOperand(1).getCImm()->getZExtValue();
       I.getOperand(1).ChangeToImmediate(Val);
+    } else if (I.getOperand(1).isImm()) {
+      uint64_t Val = I.getOperand(1).getImm();
+      I.getOperand(1).ChangeToImmediate(Val);
     }
 
     constrainSelectedInstRegOperands(I, TII, TRI, RBI);
Index: lib/Target/AMDGPU/AMDGPUCallLowering.cpp
===================================================================
--- lib/Target/AMDGPU/AMDGPUCallLowering.cpp
+++ lib/Target/AMDGPU/AMDGPUCallLowering.cpp
@@ -50,7 +50,7 @@
   const Function &F = *MF.getFunction();
   const DataLayout &DL = F.getParent()->getDataLayout();
   PointerType *PtrTy = PointerType::get(ParamTy, AMDGPUAS::CONSTANT_ADDRESS);
-  LLT PtrType(*PtrTy, DL);
+  LLT PtrType = getLLTForType(*PtrTy, DL);
   unsigned DstReg = MRI.createGenericVirtualRegister(PtrType);
   unsigned KernArgSegmentPtr =
       TRI->getPreloadedValue(MF, SIRegisterInfo::KERNARG_SEGMENT_PTR);
Index: lib/Target/X86/X86CallLowering.cpp
===================================================================
--- lib/Target/X86/X86CallLowering.cpp
+++ lib/Target/X86/X86CallLowering.cpp
@@ -58,8 +58,9 @@
   Type *PartTy = PartVT.getTypeForEVT(Context);
 
   for (unsigned i = 0; i < NumParts; ++i) {
-    ArgInfo Info = ArgInfo{MRI.createGenericVirtualRegister(LLT{*PartTy, DL}),
-                           PartTy, OrigArg.Flags};
+    ArgInfo Info =
+        ArgInfo{MRI.createGenericVirtualRegister(getLLTForType(*PartTy, DL)),
+                PartTy, OrigArg.Flags};
     SplitArgs.push_back(Info);
     BitOffsets.push_back(PartVT.getSizeInBits() * i);
     SplitRegs.push_back(Info.Reg);
Index: test/CodeGen/AArch64/GlobalISel/arm64-instructionselect-xor.mir
===================================================================
--- /dev/null
+++ test/CodeGen/AArch64/GlobalISel/arm64-instructionselect-xor.mir
@@ -0,0 +1,166 @@
+# RUN: llc -O0 -mtriple=aarch64-apple-ios -run-pass=instruction-select -verify-machineinstrs -global-isel %s -o - | FileCheck %s -check-prefix=CHECK -check-prefix=IOS
+# RUN: llc -O0 -mtriple=aarch64-linux-gnu -run-pass=instruction-select -verify-machineinstrs -global-isel %s -o - | FileCheck %s -check-prefix=CHECK -check-prefix=LINUX-DEFAULT
+# RUN: llc -O0 -mtriple=aarch64-linux-gnu -relocation-model=pic -run-pass=instruction-select -verify-machineinstrs -global-isel %s -o - | FileCheck %s -check-prefix=CHECK -check-prefix=LINUX-PIC
+
+# Test the instruction selector.
+# As we support more instructions, we need to split this up.
+
+--- |
+  target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+  define void @xor_s32_gpr() { ret void }
+  define void @xor_s64_gpr() { ret void }
+  define void @xor_constant_n1_s32_gpr() { ret void }
+  define void @xor_constant_n1_s64_gpr() { ret void }
+  define void @xor_constant_n1_s32_gpr_2bb() { ret void }
+
+...
+
+---
+# Check that we select a 32-bit GPR G_XOR into EORWrr on GPR32.
+# Also check that we constrain the register class of the COPY to GPR32.
+# CHECK-LABEL: name: xor_s32_gpr
+name: xor_s32_gpr
+legalized: true
+regBankSelected: true
+
+# CHECK: registers:
+# CHECK-NEXT: - { id: 0, class: gpr32 }
+# CHECK-NEXT: - { id: 1, class: gpr32 }
+# CHECK-NEXT: - { id: 2, class: gpr32 }
+registers:
+  - { id: 0, class: gpr }
+  - { id: 1, class: gpr }
+  - { id: 2, class: gpr }
+
+# CHECK: body:
+# CHECK: %0 = COPY %w0
+# CHECK: %1 = COPY %w1
+# CHECK: %2 = EORWrr %0, %1
+body: |
+  bb.0:
+    liveins: %w0, %w1
+
+    %0(s32) = COPY %w0
+    %1(s32) = COPY %w1
+    %2(s32) = G_XOR %0, %1
+...
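[Editorial note, not part of the patch: the xor_constant_n1 tests below exercise the new InstructionSelector::isOperandImmEqual() hook added in InstructionSelect.cpp above. The -1 operand arrives as a vreg defined by G_CONSTANT, possibly in another basic block. A minimal hand-written sketch of the same fold, assuming the usual AArch64 selector members (TII, TRI, RBI) are in scope; the function name selectXorMinusOne is hypothetical, and the patch actually generates this code from the ORN pattern:

    // Sketch: fold (G_XOR x, -1) into ORNWrr using isOperandImmEqual().
    bool selectXorMinusOne(MachineInstr &I, MachineRegisterInfo &MRI) const {
      if (I.getOpcode() != TargetOpcode::G_XOR)
        return false;
      // isOperandImmEqual() looks through the vreg to its G_CONSTANT
      // definition, so the -1 may live in a different basic block
      // (see xor_constant_n1_s32_gpr_2bb below).
      if (!isOperandImmEqual(I.getOperand(2), -1, MRI))
        return false;
      MachineInstr &MI = *BuildMI(*I.getParent(), I, I.getDebugLoc(),
                                  TII.get(AArch64::ORNWrr))
                              .add(I.getOperand(0))  // dst
                              .addReg(AArch64::WZR)  // src1 = wzr
                              .add(I.getOperand(1)); // src2 = x
      I.eraseFromParent();
      return constrainSelectedInstRegOperands(MI, TII, TRI, RBI);
    }
]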
+
+---
+# Same as xor_s32_gpr, for 64-bit operations.
+# CHECK-LABEL: name: xor_s64_gpr
+name: xor_s64_gpr
+legalized: true
+regBankSelected: true
+
+# CHECK: registers:
+# CHECK-NEXT: - { id: 0, class: gpr64 }
+# CHECK-NEXT: - { id: 1, class: gpr64 }
+# CHECK-NEXT: - { id: 2, class: gpr64 }
+registers:
+  - { id: 0, class: gpr }
+  - { id: 1, class: gpr }
+  - { id: 2, class: gpr }
+
+# CHECK: body:
+# CHECK: %0 = COPY %x0
+# CHECK: %1 = COPY %x1
+# CHECK: %2 = EORXrr %0, %1
body: |
+  bb.0:
+    liveins: %x0, %x1
+
+    %0(s64) = COPY %x0
+    %1(s64) = COPY %x1
+    %2(s64) = G_XOR %0, %1
+...
+
+---
+# Check that we select a 32-bit G_XOR of a G_CONSTANT -1 into ORNWrr.
+# Also check that we constrain the register class of the COPY to GPR32.
+# CHECK-LABEL: name: xor_constant_n1_s32_gpr
+name: xor_constant_n1_s32_gpr
+legalized: true
+regBankSelected: true
+
+# CHECK: registers:
+# CHECK-NEXT: - { id: 0, class: gpr32 }
+# CHECK-NEXT: - { id: 1, class: gpr32 }
+# CHECK-NEXT: - { id: 2, class: gpr32 }
+registers:
+  - { id: 0, class: gpr }
+  - { id: 1, class: gpr }
+  - { id: 2, class: gpr }
+
+# CHECK: body:
+# CHECK: %0 = COPY %w0
+# CHECK: %2 = ORNWrr %wzr, %0
+body: |
+  bb.0:
+    liveins: %w0
+
+    %0(s32) = COPY %w0
+    %1(s32) = G_CONSTANT -1
+    %2(s32) = G_XOR %0, %1
+...
+
+---
+# Same as xor_constant_n1_s32_gpr, for 64-bit operations.
+# CHECK-LABEL: name: xor_constant_n1_s64_gpr
+name: xor_constant_n1_s64_gpr
+legalized: true
+regBankSelected: true
+
+# CHECK: registers:
+# CHECK-NEXT: - { id: 0, class: gpr64 }
+# CHECK-NEXT: - { id: 1, class: gpr64 }
+# CHECK-NEXT: - { id: 2, class: gpr64 }
+registers:
+  - { id: 0, class: gpr }
+  - { id: 1, class: gpr }
+  - { id: 2, class: gpr }
+
+# CHECK: body:
+# CHECK: %0 = COPY %x0
+# CHECK: %2 = ORNXrr %xzr, %0
+body: |
+  bb.0:
+    liveins: %x0
+
+    %0(s64) = COPY %x0
+    %1(s64) = G_CONSTANT -1
+    %2(s64) = G_XOR %0, %1
+...
+
+---
+# Check that we can obtain constants from other basic blocks.
+# CHECK-LABEL: name: xor_constant_n1_s32_gpr_2bb
+name: xor_constant_n1_s32_gpr_2bb
+legalized: true
+regBankSelected: true
+
+# CHECK: registers:
+# CHECK-NEXT: - { id: 0, class: gpr32 }
+# CHECK-NEXT: - { id: 1, class: gpr32 }
+# CHECK-NEXT: - { id: 2, class: gpr32 }
+registers:
+  - { id: 0, class: gpr }
+  - { id: 1, class: gpr }
+  - { id: 2, class: gpr }
+
+# CHECK: body:
+# CHECK: B %bb.1
+# CHECK: %0 = COPY %w0
+# CHECK: %2 = ORNWrr %wzr, %0
+
+body: |
+  bb.0:
+    liveins: %w0, %w1
+    successors: %bb.1
+    %1(s32) = G_CONSTANT -1
+    G_BR %bb.1
+  bb.1:
+    %0(s32) = COPY %w0
+    %2(s32) = G_XOR %0, %1
+...
+
Index: test/CodeGen/AArch64/GlobalISel/arm64-instructionselect.mir
===================================================================
--- test/CodeGen/AArch64/GlobalISel/arm64-instructionselect.mir
+++ test/CodeGen/AArch64/GlobalISel/arm64-instructionselect.mir
@@ -18,9 +18,6 @@
     define void @or_s64_gpr() { ret void }
     define void @or_v2s32_fpr() { ret void }
 
-    define void @xor_s32_gpr() { ret void }
-    define void @xor_s64_gpr() { ret void }
-
     define void @and_s32_gpr() { ret void }
     define void @and_s64_gpr() { ret void }
 
@@ -354,64 +351,6 @@
 ...
 
 ---
-# Same as add_s32_gpr, for G_XOR operations.
-# CHECK-LABEL: name: xor_s32_gpr
-name: xor_s32_gpr
-legalized: true
-regBankSelected: true
-
-# CHECK: registers:
-# CHECK-NEXT: - { id: 0, class: gpr32 }
-# CHECK-NEXT: - { id: 1, class: gpr32 }
-# CHECK-NEXT: - { id: 2, class: gpr32 }
-registers:
-  - { id: 0, class: gpr }
-  - { id: 1, class: gpr }
-  - { id: 2, class: gpr }
-
-# CHECK: body:
-# CHECK: %0 = COPY %w0
-# CHECK: %1 = COPY %w1
-# CHECK: %2 = EORWrr %0, %1
-body: |
-  bb.0:
-    liveins: %w0, %w1
-
-    %0(s32) = COPY %w0
-    %1(s32) = COPY %w1
-    %2(s32) = G_XOR %0, %1
-...
-
----
-# Same as add_s64_gpr, for G_XOR operations.
-# CHECK-LABEL: name: xor_s64_gpr
-name: xor_s64_gpr
-legalized: true
-regBankSelected: true
-
-# CHECK: registers:
-# CHECK-NEXT: - { id: 0, class: gpr64 }
-# CHECK-NEXT: - { id: 1, class: gpr64 }
-# CHECK-NEXT: - { id: 2, class: gpr64 }
-registers:
-  - { id: 0, class: gpr }
-  - { id: 1, class: gpr }
-  - { id: 2, class: gpr }
-
-# CHECK: body:
-# CHECK: %0 = COPY %x0
-# CHECK: %1 = COPY %x1
-# CHECK: %2 = EORXrr %0, %1
-body: |
-  bb.0:
-    liveins: %x0, %x1
-
-    %0(s64) = COPY %x0
-    %1(s64) = COPY %x1
-    %2(s64) = G_XOR %0, %1
-...
-
----
 # Same as add_s32_gpr, for G_AND operations.
 # CHECK-LABEL: name: and_s32_gpr
 name: and_s32_gpr
Index: test/TableGen/GlobalISelEmitter.td
===================================================================
--- test/TableGen/GlobalISelEmitter.td
+++ test/TableGen/GlobalISelEmitter.td
@@ -7,7 +7,7 @@
 def MyTargetISA : InstrInfo;
 def MyTarget : Target { let InstructionSet = MyTargetISA; }
 
-def R0 : Register<"r0">;
+def R0 : Register<"r0"> { let Namespace = "MyTarget"; }
 def GPR32 : RegisterClass<"MyTarget", [i32], 32, (add R0)>;
 
 class I<dag OOps, dag IOps, list<dag> Pat>
@@ -23,34 +23,86 @@
 
 // CHECK: bool MyTargetInstructionSelector::selectImpl(MachineInstr &I) const {
 // CHECK: const MachineRegisterInfo &MRI = I.getParent()->getParent()->getRegInfo();
-
 //===- Test a simple pattern with regclass operands. ----------------------===//
 
-// CHECK: if ((I.getOpcode() == TargetOpcode::G_ADD) &&
-// CHECK-NEXT: ((/* Operand 0 */ (MRI.getType(I.getOperand(0).getReg()) == (LLT::scalar(32))) &&
+// CHECK-LABEL: if ((I.getOpcode() == TargetOpcode::G_ADD) &&
+// CHECK-NEXT: ((/* dst */ (MRI.getType(I.getOperand(0).getReg()) == (LLT::scalar(32))) &&
 // CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(I.getOperand(0).getReg(), MRI, TRI))))) &&
-// CHECK-NEXT: ((/* Operand 1 */ (MRI.getType(I.getOperand(1).getReg()) == (LLT::scalar(32))) &&
+// CHECK-NEXT: ((/* src1 */ (MRI.getType(I.getOperand(1).getReg()) == (LLT::scalar(32))) &&
 // CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(I.getOperand(1).getReg(), MRI, TRI))))) &&
-// CHECK-NEXT: ((/* Operand 2 */ (MRI.getType(I.getOperand(2).getReg()) == (LLT::scalar(32))) &&
+// CHECK-NEXT: ((/* src2 */ (MRI.getType(I.getOperand(2).getReg()) == (LLT::scalar(32))) &&
 // CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(I.getOperand(2).getReg(), MRI, TRI)))))) {
 // CHECK-NEXT: // (add:i32 GPR32:i32:$src1, GPR32:i32:$src2) => (ADD:i32 GPR32:i32:$src1, GPR32:i32:$src2)
 // CHECK-NEXT: I.setDesc(TII.get(MyTarget::ADD));
-// CHECK-NEXT: constrainSelectedInstRegOperands(I, TII, TRI, RBI);
+// CHECK-NEXT: MachineInstr &NewI = I;
+// CHECK: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI);
 // CHECK-NEXT: return true;
 // CHECK-NEXT: }
 
 def ADD : I<(outs GPR32:$dst), (ins GPR32:$src1, GPR32:$src2),
             [(set GPR32:$dst, (add GPR32:$src1, GPR32:$src2))]>;
 
+// CHECK-LABEL: if ((I.getOpcode() == TargetOpcode::G_MUL) &&
+// CHECK-NEXT: ((/* dst */ (MRI.getType(I.getOperand(0).getReg()) == (LLT::scalar(32))) &&
+// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(I.getOperand(0).getReg(), MRI, TRI))))) &&
+// CHECK-NEXT: ((/* src1 */ (MRI.getType(I.getOperand(1).getReg()) == (LLT::scalar(32))) &&
+// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(I.getOperand(1).getReg(), MRI, TRI))))) &&
+// CHECK-NEXT: ((/* src2 */ (MRI.getType(I.getOperand(2).getReg()) == (LLT::scalar(32))) &&
+// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(I.getOperand(2).getReg(), MRI, TRI)))))) {
+
+// CHECK-NEXT: // (mul:i32 GPR32:i32:$src1, GPR32:i32:$src2) => (MUL:i32 GPR32:i32:$src2, GPR32:i32:$src1)
+// CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::MUL));
+// CHECK-NEXT: MIB.add(I.getOperand(0)/*dst*/);
+// CHECK-NEXT: MIB.add(I.getOperand(2)/*src2*/);
+// CHECK-NEXT: MIB.add(I.getOperand(1)/*src1*/);
+// CHECK-NEXT: MIB.setMemRefs(I.memoperands_begin(), I.memoperands_end());
+// CHECK-NEXT: I.eraseFromParent();
+// CHECK-NEXT: MachineInstr &NewI = *MIB;
+// CHECK: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI);
+// CHECK-NEXT: return true;
+// CHECK-NEXT: }
+
+def MUL : I<(outs GPR32:$dst), (ins GPR32:$src2, GPR32:$src1),
+            [(set GPR32:$dst, (mul GPR32:$src1, GPR32:$src2))]>;
+
+//===- Test a simple pattern with constant immediate operands. ------------===//
+//
+// This must precede the 3-register variants because constant immediates have
+// priority over register banks.
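[Editorial note, not part of the patch: to make the ordering above concrete, here is a self-contained C++ sketch of the priority mechanism, with illustrative names rather than the emitter's real classes. Kinds with lower enum values are tested first, mirroring OPM_Int sorting ahead of OPM_RegBank in the emitter:

    #include <algorithm>
    #include <iostream>
    #include <vector>

    // Lower values win, so an immediate predicate beats a register-bank one.
    enum PredicateKind { PK_Int, PK_LLT, PK_RegBank, PK_MBB };

    struct Rule {
      const char *Name;
      std::vector<PredicateKind> Preds;

      bool isHigherPriorityThan(const Rule &B) const {
        // More predicates first, then compare kinds lexicographically.
        if (Preds.size() != B.Preds.size())
          return Preds.size() > B.Preds.size();
        return Preds < B.Preds; // std::vector's lexicographic operator<
      }
    };

    int main() {
      std::vector<Rule> Rules = {
          {"xor_rr", {PK_LLT, PK_RegBank}},    // plain 3-register G_XOR
          {"xor_imm_n1", {PK_LLT, PK_Int}},    // G_XOR x, -1 -> ORN
      };
      std::stable_sort(Rules.begin(), Rules.end(),
                       [](const Rule &A, const Rule &B) {
                         return A.isHigherPriorityThan(B);
                       });
      // Prints xor_imm_n1 first: the constant-immediate rule must be tried
      // before the generic register-bank rule can swallow the match.
      for (const Rule &R : Rules)
        std::cout << R.Name << "\n";
    }
]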
+
+// CHECK-LABEL: if ((I.getOpcode() == TargetOpcode::G_XOR) &&
+// CHECK-NEXT: ((/* dst */ (MRI.getType(I.getOperand(0).getReg()) == (LLT::scalar(32))) &&
+// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(I.getOperand(0).getReg(), MRI, TRI))))) &&
+// CHECK-NEXT: ((/* Wm */ (MRI.getType(I.getOperand(1).getReg()) == (LLT::scalar(32))) &&
+// CHECK-NEXT: ((&RBI.getRegBankFromRegClass(MyTarget::GPR32RegClass) == RBI.getRegBank(I.getOperand(1).getReg(), MRI, TRI))))) &&
+// CHECK-NEXT: ((/* Operand 2 */ (MRI.getType(I.getOperand(2).getReg()) == (LLT::scalar(32))) &&
+// CHECK-NEXT: (isOperandImmEqual(I.getOperand(2), -1, MRI))))) {
+
+// CHECK-NEXT: // (xor:i32 GPR32:i32:$Wm, -1:i32) => (ORN:i32 R0:i32, GPR32:i32:$Wm)
+// CHECK-NEXT: MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(MyTarget::ORN));
+// CHECK-NEXT: MIB.add(I.getOperand(0)/*dst*/);
+// CHECK-NEXT: MIB.addReg(MyTarget::R0);
+// CHECK-NEXT: MIB.add(I.getOperand(1)/*Wm*/);
+// CHECK-NEXT: MIB.setMemRefs(I.memoperands_begin(), I.memoperands_end());
+// CHECK-NEXT: I.eraseFromParent();
+// CHECK-NEXT: MachineInstr &NewI = *MIB;
+// CHECK: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI);
+// CHECK-NEXT: return true;
+// CHECK-NEXT: }
+
+def ORN : I<(outs GPR32:$dst), (ins GPR32:$src1, GPR32:$src2), []>;
+def : Pat<(not GPR32:$Wm), (ORN R0, GPR32:$Wm)>;
+
 //===- Test a pattern with an MBB operand. --------------------------------===//
 
-// CHECK: if ((I.getOpcode() == TargetOpcode::G_BR) &&
-// CHECK-NEXT: ((/* Operand 0 */ (I.getOperand(0).isMBB())))) {
+// CHECK-LABEL: if ((I.getOpcode() == TargetOpcode::G_BR) &&
+// CHECK-NEXT: ((/* target */ (I.getOperand(0).isMBB())))) {
 // CHECK-NEXT: // (br (bb:Other):$target) => (BR (bb:Other):$target)
 // CHECK-NEXT: I.setDesc(TII.get(MyTarget::BR));
-// CHECK-NEXT: constrainSelectedInstRegOperands(I, TII, TRI, RBI);
+// CHECK-NEXT: MachineInstr &NewI = I;
+// CHECK: constrainSelectedInstRegOperands(NewI, TII, TRI, RBI);
 // CHECK-NEXT: return true;
 // CHECK-NEXT: }
Index: unittests/CodeGen/LowLevelTypeTest.cpp
===================================================================
--- unittests/CodeGen/LowLevelTypeTest.cpp
+++ unittests/CodeGen/LowLevelTypeTest.cpp
@@ -68,7 +68,7 @@
 
     // Test Type->LLT conversion.
     Type *IRTy = IntegerType::get(C, S);
-    EXPECT_EQ(Ty, LLT(*IRTy, DL));
+    EXPECT_EQ(Ty, getLLTForType(*IRTy, DL));
   }
 }
 
@@ -160,7 +160,7 @@
       // Test Type->LLT conversion.
       Type *IRSTy = IntegerType::get(C, S);
       Type *IRTy = VectorType::get(IRSTy, Elts);
-      EXPECT_EQ(VTy, LLT(*IRTy, DL));
+      EXPECT_EQ(VTy, getLLTForType(*IRTy, DL));
     }
   }
 }
@@ -188,7 +188,7 @@
 
     // Test Type->LLT conversion.
    Type *IRTy = PointerType::get(IntegerType::get(C, 8), AS);
-    EXPECT_EQ(Ty, LLT(*IRTy, DL));
+    EXPECT_EQ(Ty, getLLTForType(*IRTy, DL));
   }
 }
Index: utils/TableGen/GlobalISelEmitter.cpp
===================================================================
--- utils/TableGen/GlobalISelEmitter.cpp
+++ utils/TableGen/GlobalISelEmitter.cpp
@@ -56,8 +56,6 @@
                   "in the GlobalISel selector"),
     cl::init(false));
 
-namespace {
-
 //===- Helper functions ---------------------------------------------------===//
 
 /// Convert an MVT to an equivalent LLT if possible, or the invalid LLT() for
@@ -103,6 +101,7 @@
   iterator_range<typename PredicateVec::const_iterator> predicates() const {
     return make_range(predicates_begin(), predicates_end());
  }
+  typename PredicateVec::size_type predicates_size() const { return Predicates.size(); }
 
   /// Emit a C++ expression that tests whether all the predicates are met.
   template <class... Args>
@@ -130,11 +129,39 @@
 /// * Operand is an MBB.
 class OperandPredicateMatcher {
 public:
+  /// This enum is used for RTTI and also defines the priority that is given to
+  /// the predicate when generating the matcher code. Kinds with higher priority
+  /// must be tested first.
+  ///
+  /// The relative priority of OPM_LLT, OPM_RegBank, and OPM_MBB does not
+  /// matter, but OPM_Int must have priority over OPM_RegBank since constant
+  /// integers are represented by a virtual register defined by a G_CONSTANT
+  /// instruction.
+  enum PredicateKind {
+    OPM_Int,
+    OPM_LLT,
+    OPM_RegBank,
+    OPM_MBB,
+  };
+
+protected:
+  PredicateKind Kind;
+
+public:
+  OperandPredicateMatcher(PredicateKind Kind) : Kind(Kind) {}
   virtual ~OperandPredicateMatcher() {}
 
+  PredicateKind getKind() const { return Kind; }
+
   /// Emit a C++ expression that checks the predicate for the given operand.
   virtual void emitCxxPredicateExpr(raw_ostream &OS,
                                     StringRef OperandExpr) const = 0;
+
+  /// Compare the priority of this object and B.
+  ///
+  /// Returns true if this object is more important than B.
+  virtual bool isHigherPriorityThan(const OperandPredicateMatcher &B) const {
+    return Kind < B.Kind;
+  };
 };
 
 /// Generates code to check that an operand is a particular LLT.
@@ -143,7 +170,12 @@
   std::string Ty;
 
 public:
-  LLTOperandMatcher(std::string Ty) : Ty(Ty) {}
+  LLTOperandMatcher(std::string Ty)
+      : OperandPredicateMatcher(OPM_LLT), Ty(Ty) {}
+
+  static bool classof(const OperandPredicateMatcher *P) {
+    return P->getKind() == OPM_LLT;
+  }
 
   void emitCxxPredicateExpr(raw_ostream &OS,
                             StringRef OperandExpr) const override {
@@ -157,7 +189,12 @@
   const CodeGenRegisterClass &RC;
 
 public:
-  RegisterBankOperandMatcher(const CodeGenRegisterClass &RC) : RC(RC) {}
+  RegisterBankOperandMatcher(const CodeGenRegisterClass &RC)
+      : OperandPredicateMatcher(OPM_RegBank), RC(RC) {}
+
+  static bool classof(const OperandPredicateMatcher *P) {
+    return P->getKind() == OPM_RegBank;
+  }
 
   void emitCxxPredicateExpr(raw_ostream &OS,
                             StringRef OperandExpr) const override {
@@ -170,31 +207,89 @@
 /// Generates code to check that an operand is a basic block.
 class MBBOperandMatcher : public OperandPredicateMatcher {
 public:
+  MBBOperandMatcher() : OperandPredicateMatcher(OPM_MBB) {}
+
+  static bool classof(const OperandPredicateMatcher *P) {
+    return P->getKind() == OPM_MBB;
+  }
+
   void emitCxxPredicateExpr(raw_ostream &OS,
                             StringRef OperandExpr) const override {
     OS << OperandExpr << ".isMBB()";
   }
 };
 
+/// Generates code to check that an operand is a particular int.
+class IntOperandMatcher : public OperandPredicateMatcher {
+protected:
+  int64_t Value;
+
+public:
+  IntOperandMatcher(int64_t Value)
+      : OperandPredicateMatcher(OPM_Int), Value(Value) {}
+
+  static bool classof(const OperandPredicateMatcher *P) {
+    return P->getKind() == OPM_Int;
+  }
+
+  void emitCxxPredicateExpr(raw_ostream &OS,
+                            const StringRef OperandExpr) const override {
+    OS << "isOperandImmEqual(" << OperandExpr << ", " << Value << ", MRI)";
+  }
+};
+
 /// Generates code to check that a set of predicates match for a particular
 /// operand.
class OperandMatcher : public PredicateListMatcher<OperandPredicateMatcher> {
 protected:
   unsigned OpIdx;
+  std::string SymbolicName;
 
 public:
-  OperandMatcher(unsigned OpIdx) : OpIdx(OpIdx) {}
+  OperandMatcher(unsigned OpIdx, const std::string &SymbolicName)
+      : OpIdx(OpIdx), SymbolicName(SymbolicName) {}
+
+  bool hasSymbolicName() const { return !SymbolicName.empty(); }
+  const StringRef getSymbolicName() const { return SymbolicName; }
+  unsigned getOperandIndex() const { return OpIdx; }
+
   std::string getOperandExpr(StringRef InsnVarName) const {
     return (InsnVarName + ".getOperand(" + llvm::to_string(OpIdx) + ")").str();
   }
 
   /// Emit a C++ expression that tests whether the instruction named in
   /// InsnVarName matches all the predicates and all the operands.
-  void emitCxxPredicateExpr(raw_ostream &OS, StringRef InsnVarName) const {
-    OS << "(/* Operand " << OpIdx << " */ ";
+  void emitCxxPredicateExpr(raw_ostream &OS, const StringRef InsnVarName) const {
+    OS << "(/* ";
+    if (SymbolicName.empty())
+      OS << "Operand " << OpIdx;
+    else
+      OS << SymbolicName;
+    OS << " */ ";
     emitCxxPredicateListExpr(OS, getOperandExpr(InsnVarName));
     OS << ")";
   }
+
+  /// Compare the priority of this object and B.
+  ///
+  /// Returns true if this object is more important than B.
+  bool isHigherPriorityThan(const OperandMatcher &B) const {
+    // Operand matchers involving more predicates have higher priority.
+    if (predicates_size() > B.predicates_size())
+      return true;
+    if (predicates_size() < B.predicates_size())
+      return false;
+
+    // This assumes that predicates are added in a consistent order.
+    for (const auto &Predicate : zip(predicates(), B.predicates())) {
+      if (std::get<0>(Predicate)->isHigherPriorityThan(*std::get<1>(Predicate)))
+        return true;
+      if (std::get<1>(Predicate)->isHigherPriorityThan(*std::get<0>(Predicate)))
+        return false;
+    }
+
+    return false;
+  };
 };
 
 /// Generates code to check a predicate on an instruction.
@@ -203,13 +298,33 @@
 /// * The opcode of the instruction is a particular value.
 /// * The nsw/nuw flag is/isn't set.
 class InstructionPredicateMatcher {
+protected:
+  /// This enum is used for RTTI and also defines the priority that is given to
+  /// the predicate when generating the matcher code. Kinds with higher priority
+  /// must be tested first.
+  enum PredicateKind {
+    IPM_Opcode,
+  };
+
+  PredicateKind Kind;
+
 public:
+  InstructionPredicateMatcher(PredicateKind Kind) : Kind(Kind) {}
   virtual ~InstructionPredicateMatcher() {}
 
+  PredicateKind getKind() const { return Kind; }
+
   /// Emit a C++ expression that tests whether the instruction named in
   /// InsnVarName matches the predicate.
   virtual void emitCxxPredicateExpr(raw_ostream &OS,
                                     StringRef InsnVarName) const = 0;
+
+  /// Compare the priority of this object and B.
+  ///
+  /// Returns true if this object is more important than B.
+  virtual bool isHigherPriorityThan(const InstructionPredicateMatcher &B) const {
+    return Kind < B.Kind;
+  };
 };
 
 /// Generates code to check the opcode of an instruction.
@@ -218,13 +333,37 @@
   const CodeGenInstruction *I;
 
 public:
-  InstructionOpcodeMatcher(const CodeGenInstruction *I) : I(I) {}
+  InstructionOpcodeMatcher(const CodeGenInstruction *I)
+      : InstructionPredicateMatcher(IPM_Opcode), I(I) {}
+
+  static bool classof(const InstructionPredicateMatcher *P) {
+    return P->getKind() == IPM_Opcode;
+  }
 
   void emitCxxPredicateExpr(raw_ostream &OS,
                             StringRef InsnVarName) const override {
     OS << InsnVarName << ".getOpcode() == " << I->Namespace
       << "::" << I->TheDef->getName();
   }
+
+  /// Compare the priority of this object and B.
+  ///
+  /// Returns true if this object is more important than B.
+  bool isHigherPriorityThan(const InstructionPredicateMatcher &B) const override {
+    if (InstructionPredicateMatcher::isHigherPriorityThan(B))
+      return true;
+    if (B.InstructionPredicateMatcher::isHigherPriorityThan(*this))
+      return false;
+
+    // Prioritize opcodes for cosmetic reasons in the generated source. Although
+    // this is cosmetic at the moment, we may want to drive a similar ordering
+    // using instruction frequency information to improve compile time.
+    if (const InstructionOpcodeMatcher *BO =
+            dyn_cast<InstructionOpcodeMatcher>(&B))
+      return I->TheDef->getName() < BO->I->TheDef->getName();
+
+    return false;
+  };
 };
 
 /// Generates code to check that a set of predicates and operands match for a
@@ -236,15 +375,41 @@
 class InstructionMatcher
     : public PredicateListMatcher<InstructionPredicateMatcher> {
 protected:
-  std::vector<OperandMatcher> Operands;
+  typedef std::vector<OperandMatcher> OperandVec;
+
+  /// The operands to match. All rendered operands must be present even if the
+  /// condition is always true.
+  OperandVec Operands;
 
 public:
   /// Add an operand to the matcher.
-  OperandMatcher &addOperand(unsigned OpIdx) {
-    Operands.emplace_back(OpIdx);
+  OperandMatcher &addOperand(unsigned OpIdx, const std::string &SymbolicName) {
+    Operands.emplace_back(OpIdx, SymbolicName);
     return Operands.back();
   }
 
+  const OperandMatcher &getOperand(const StringRef SymbolicName) const {
+    assert(!SymbolicName.empty() && "Cannot lookup unnamed operand");
+    const auto &I = std::find_if(Operands.begin(), Operands.end(),
+                                 [&SymbolicName](const OperandMatcher &X) {
+                                   return X.getSymbolicName() == SymbolicName;
+                                 });
+    if (I != Operands.end())
+      return *I;
+    llvm_unreachable("Failed to lookup operand");
+  }
+
+  unsigned getNumOperands() const { return Operands.size(); }
+  OperandVec::const_iterator operands_begin() const {
+    return Operands.begin();
+  }
+  OperandVec::const_iterator operands_end() const {
+    return Operands.end();
  }
+  iterator_range<OperandVec::const_iterator> operands() const {
+    return make_range(operands_begin(), operands_end());
+  }
+
   /// Emit a C++ expression that tests whether the instruction named in
   /// InsnVarName matches all the predicates and all the operands.
   void emitCxxPredicateExpr(raw_ostream &OS, StringRef InsnVarName) const {
@@ -255,10 +420,106 @@
       OS << ")";
     }
   }
+
+  /// Compare the priority of this object and B.
+  ///
+  /// Returns true if this object is more important than B.
+  bool isHigherPriorityThan(const InstructionMatcher &B) const {
+    // Instruction matchers involving more operands have higher priority.
+    if (Operands.size() > B.Operands.size())
+      return true;
+    if (Operands.size() < B.Operands.size())
+      return false;
+
+    for (const auto &Predicate : zip(predicates(), B.predicates())) {
+      if (std::get<0>(Predicate)->isHigherPriorityThan(*std::get<1>(Predicate)))
+        return true;
+      if (std::get<1>(Predicate)->isHigherPriorityThan(*std::get<0>(Predicate)))
+        return false;
+    }
+
+    for (const auto &Operand : zip(Operands, B.Operands)) {
+      if (std::get<0>(Operand).isHigherPriorityThan(std::get<1>(Operand)))
+        return true;
+      if (std::get<1>(Operand).isHigherPriorityThan(std::get<0>(Operand)))
+        return false;
+    }
+
+    return false;
+  };
 };
 
 //===- Actions ------------------------------------------------------------===//
 
+namespace {
+class OperandRenderer {
+public:
+  enum RendererKind { OR_Copy, OR_Register };
+
+protected:
+  RendererKind Kind;
+
+public:
+  OperandRenderer(RendererKind Kind) : Kind(Kind) {}
+  virtual ~OperandRenderer() {}
+
+  RendererKind getKind() const { return Kind; }
+
+  virtual void emitCxxRenderStmts(raw_ostream &OS) const = 0;
+};
+
+/// A CopyRenderer emits code to copy a single operand from an existing
+/// instruction to the one being built.
+class CopyRenderer : public OperandRenderer {
+protected:
+  /// The matcher for the instruction that this operand is copied from.
+  /// This provides the facility for looking up an operand by its name so
+  /// that it can be used as a source for the instruction being built.
+  const InstructionMatcher &Matched;
+  /// The name of the instruction to copy from.
+  const StringRef InsnVarName;
+  /// The name of the operand.
+  const StringRef SymbolicName;
+
+public:
+  CopyRenderer(const InstructionMatcher &Matched, const StringRef InsnVarName,
+               const StringRef SymbolicName)
+      : OperandRenderer(OR_Copy), Matched(Matched), InsnVarName(InsnVarName),
+        SymbolicName(SymbolicName) {}
+
+  static bool classof(const OperandRenderer *R) {
+    return R->getKind() == OR_Copy;
+  }
+
+  const StringRef getSymbolicName() const { return SymbolicName; }
+
+  void emitCxxRenderStmts(raw_ostream &OS) const override {
+    std::string OperandExpr =
+        Matched.getOperand(SymbolicName).getOperandExpr(InsnVarName);
+    OS << "    MIB.add(" << OperandExpr << "/*" << SymbolicName << "*/);\n";
+  }
+};
+
+/// Adds a specific physical register to the instruction being built.
+/// This is typically useful for WZR/XZR on AArch64.
+class AddRegisterRenderer : public OperandRenderer {
+protected:
+  const Record *RegisterDef;
+
+public:
+  AddRegisterRenderer(const Record *RegisterDef)
+      : OperandRenderer(OR_Register), RegisterDef(RegisterDef) {}
+
+  static bool classof(const OperandRenderer *R) {
+    return R->getKind() == OR_Register;
+  }
+
+  void emitCxxRenderStmts(raw_ostream &OS) const override {
+    OS << "    MIB.addReg(" << RegisterDef->getValueAsString("Namespace")
+       << "::" << RegisterDef->getName() << ");\n";
+  }
+};
+
 /// An action taken when all Matcher predicates succeeded for a parent rule.
 ///
 /// Typical actions include:
@@ -267,7 +528,14 @@
 class MatchAction {
 public:
   virtual ~MatchAction() {}
-  virtual void emitCxxActionStmts(raw_ostream &OS) const = 0;
+
+  /// Emit the C++ statements to implement the action.
+  ///
+  /// \param InsnVarName If given, it's an instruction to recycle. The
+  ///                    requirements on the instruction vary from action to
+  ///                    action.
+  virtual void emitCxxActionStmts(raw_ostream &OS,
+                                  const StringRef InsnVarName) const = 0;
 };
 
 /// Generates a comment describing the matched rule being acted upon.
@@ -278,23 +546,65 @@
 public:
   DebugCommentAction(const PatternToMatch &P) : P(P) {}
 
-  virtual void emitCxxActionStmts(raw_ostream &OS) const {
+  void emitCxxActionStmts(raw_ostream &OS,
+                          const StringRef InsnVarName) const override {
     OS << "// " << *P.getSrcPattern() << " => " << *P.getDstPattern();
   }
 };
 
-/// Generates code to set the opcode (really, the MCInstrDesc) of a matched
-/// instruction to a given Instruction.
-class MutateOpcodeAction : public MatchAction {
+/// Generates code to build an instruction or mutate an existing instruction
+/// into the desired instruction when this is possible.
+class BuildMIAction : public MatchAction {
 private:
   const CodeGenInstruction *I;
+  const InstructionMatcher &Matched;
+  std::vector<std::unique_ptr<OperandRenderer>> OperandRenderers;
+
+  /// True if the instruction can be built solely by mutating the opcode.
+  bool canMutate() const {
+    for (const auto &Renderer : enumerate(OperandRenderers)) {
+      if (const auto *Copy = dyn_cast<CopyRenderer>(&*Renderer.Value)) {
+        if (Matched.getOperand(Copy->getSymbolicName()).getOperandIndex() !=
+            Renderer.Index)
+          return false;
+      } else
+        return false;
+    }
+
+    return true;
+  }
 
 public:
-  MutateOpcodeAction(const CodeGenInstruction *I) : I(I) {}
+  BuildMIAction(const CodeGenInstruction *I, const InstructionMatcher &Matched)
+      : I(I), Matched(Matched) {}
 
-  virtual void emitCxxActionStmts(raw_ostream &OS) const {
-    OS << "I.setDesc(TII.get(" << I->Namespace << "::" << I->TheDef->getName()
-       << "));";
+  template <class Kind, class... Args>
+  Kind &addRenderer(Args&&... args) {
+    OperandRenderers.emplace_back(
+        llvm::make_unique<Kind>(std::forward<Args>(args)...));
+    return *static_cast<Kind *>(OperandRenderers.back().get());
+  }
+
+  virtual void emitCxxActionStmts(raw_ostream &OS,
+                                  const StringRef InsnVarName) const {
+    if (canMutate()) {
+      OS << "I.setDesc(TII.get(" << I->Namespace << "::" << I->TheDef->getName()
+         << "));\n";
+      OS << "    MachineInstr &NewI = I;\n";
+      return;
+    }
+
+    // TODO: Simple permutation looks like it could be almost as common as
+    //       mutation due to commutative operations.
+
+    OS << "MachineInstrBuilder MIB = BuildMI(*I.getParent(), I, "
+          "I.getDebugLoc(), TII.get("
+       << I->Namespace << "::" << I->TheDef->getName() << "));\n";
+    for (const auto &Renderer : OperandRenderers)
+      Renderer->emitCxxRenderStmts(OS);
+    OS << "    MIB.setMemRefs(I.memoperands_begin(), I.memoperands_end());\n";
+    OS << "    " << InsnVarName << ".eraseFromParent();\n";
+    OS << "    MachineInstr &NewI = *MIB;\n";
   }
 };
 
@@ -344,14 +654,34 @@
 
     for (const auto &MA : Actions) {
       OS << "    ";
-      MA->emitCxxActionStmts(OS);
+      MA->emitCxxActionStmts(OS, "I");
      OS << "\n";
     }
 
-    OS << "    constrainSelectedInstRegOperands(I, TII, TRI, RBI);\n";
+    OS << "    constrainSelectedInstRegOperands(NewI, TII, TRI, RBI);\n";
     OS << "    return true;\n";
     OS << "  }\n\n";
   }
+
+  /// Compare the priority of this object and B.
+  ///
+  /// Returns true if this object is more important than B.
+  bool isHigherPriorityThan(const RuleMatcher &B) const {
+    // Rules involving more match roots have higher priority.
+    if (Matchers.size() > B.Matchers.size())
+      return true;
+    if (Matchers.size() < B.Matchers.size())
+      return false;
+
+    for (const auto &Matcher : zip(Matchers, B.Matchers)) {
+      if (std::get<0>(Matcher)->isHigherPriorityThan(*std::get<1>(Matcher)))
+        return true;
+      if (std::get<1>(Matcher)->isHigherPriorityThan(*std::get<0>(Matcher)))
+        return false;
+    }
+
+    return false;
+  };
 };
 
 //===- GlobalISelEmitter class --------------------------------------------===//
@@ -437,7 +767,7 @@
   // The operators look good: match the opcode and mutate it to the new one.
   InstructionMatcher &InsnMatcher = M.addInstructionMatcher();
   InsnMatcher.addPredicate<InstructionOpcodeMatcher>(&SrcGI);
-  M.addAction<MutateOpcodeAction>(&DstI);
+  auto &DstMIBuilder = M.addAction<BuildMIAction>(&DstI, InsnMatcher);
 
   // Next, analyze the children, only accepting patterns that don't require
   // any change to operands.
@@ -451,7 +781,8 @@
     return failedImport("Src pattern results and dst MI defs are different");
 
   for (const EEVT::TypeSet &Ty : Src->getExtTypes()) {
-    Record *DstIOpRec = DstI.Operands[OpIdx].Rec;
+    const auto &DstIOperand = DstI.Operands[OpIdx];
+    Record *DstIOpRec = DstIOperand.Rec;
     if (!DstIOpRec->isSubClassOf("RegisterClass"))
       return failedImport("Dst MI def isn't a register class");
 
@@ -459,48 +790,35 @@
     if (!OpTyOrNone)
       return failedImport("Dst operand has an unsupported type");
 
-    OperandMatcher &OM = InsnMatcher.addOperand(OpIdx);
+    OperandMatcher &OM = InsnMatcher.addOperand(OpIdx, DstIOperand.Name);
     OM.addPredicate<LLTOperandMatcher>(*OpTyOrNone);
     OM.addPredicate<RegisterBankOperandMatcher>(
         Target.getRegisterClass(DstIOpRec));
+    DstMIBuilder.addRenderer<CopyRenderer>(InsnMatcher, "I", DstIOperand.Name);
     ++OpIdx;
   }
 
   // Finally match the used operands (i.e., the children of the root operator).
   for (unsigned i = 0, e = Src->getNumChildren(); i != e; ++i) {
     auto *SrcChild = Src->getChild(i);
-    auto *DstChild = Dst->getChild(i);
 
-    // Patterns can reorder operands. Ignore those for now.
-    if (SrcChild->getName() != DstChild->getName())
-      return failedImport("Src/dst pattern children not in same order");
+    OperandMatcher &OM = InsnMatcher.addOperand(OpIdx++, SrcChild->getName());
 
     // The only non-leaf child we accept is 'bb': it's an operator because
     // BasicBlockSDNode isn't inline, but in MI it's just another operand.
     if (!SrcChild->isLeaf()) {
-      if (DstChild->isLeaf() ||
-          SrcChild->getOperator() != DstChild->getOperator())
-        return failedImport("Src/dst pattern child operator mismatch");
-
       if (SrcChild->getOperator()->isSubClassOf("SDNode")) {
         auto &ChildSDNI = CGP.getSDNodeInfo(SrcChild->getOperator());
         if (ChildSDNI.getSDClassName() == "BasicBlockSDNode") {
-          InsnMatcher.addOperand(OpIdx++).addPredicate<MBBOperandMatcher>();
+          OM.addPredicate<MBBOperandMatcher>();
           continue;
         }
       }
-      return failedImport("Src pattern child isn't a leaf node");
+      return failedImport("Src pattern child isn't a leaf node or an MBB");
    }
 
-    if (SrcChild->getLeafValue() != DstChild->getLeafValue())
-      return failedImport("Src/dst pattern child leaf mismatch");
-
-    // Otherwise, we're looking for a bog-standard RegisterClass operand.
     if (SrcChild->hasAnyPredicate())
       return failedImport("Src pattern child has predicate");
-    auto *ChildRec = cast<DefInit>(SrcChild->getLeafValue())->getDef();
-    if (!ChildRec->isSubClassOf("RegisterClass"))
-      return failedImport("Src pattern child isn't a RegisterClass");
 
     ArrayRef<EEVT::TypeSet> ChildTypes = SrcChild->getExtTypes();
     if (ChildTypes.size() != 1)
@@ -509,12 +827,77 @@
     auto OpTyOrNone = MVTToLLT(ChildTypes.front().getConcrete());
     if (!OpTyOrNone)
       return failedImport("Src operand has an unsupported type");
-
-    OperandMatcher &OM = InsnMatcher.addOperand(OpIdx);
     OM.addPredicate<LLTOperandMatcher>(*OpTyOrNone);
-    OM.addPredicate<RegisterBankOperandMatcher>(
-        Target.getRegisterClass(ChildRec));
-    ++OpIdx;
+
+    if (auto *ChildInt = dyn_cast<IntInit>(SrcChild->getLeafValue())) {
+      OM.addPredicate<IntOperandMatcher>(ChildInt->getValue());
+      continue;
+    }
+
+    if (auto *ChildDefInit = dyn_cast<DefInit>(SrcChild->getLeafValue())) {
+      auto *ChildRec = ChildDefInit->getDef();
+
+      // Otherwise, we're looking for a bog-standard RegisterClass operand.
+      if (!ChildRec->isSubClassOf("RegisterClass"))
+        return failedImport("Src pattern child isn't a RegisterClass");
+
+      OM.addPredicate<RegisterBankOperandMatcher>(
+          Target.getRegisterClass(ChildRec));
+      continue;
+    }
+
+    return failedImport("Src pattern child is an unsupported kind");
+  }
+
+  // Finally render the used operands (i.e., the children of the root operator).
+  for (unsigned i = 0, e = Dst->getNumChildren(); i != e; ++i) {
+    auto *DstChild = Dst->getChild(i);
+
+    // The only non-leaf child we accept is 'bb': it's an operator because
+    // BasicBlockSDNode isn't inline, but in MI it's just another operand.
+    if (!DstChild->isLeaf()) {
+      if (DstChild->getOperator()->isSubClassOf("SDNode")) {
+        auto &ChildSDNI = CGP.getSDNodeInfo(DstChild->getOperator());
+        if (ChildSDNI.getSDClassName() == "BasicBlockSDNode") {
+          DstMIBuilder.addRenderer<CopyRenderer>(InsnMatcher, "I",
+                                                 DstChild->getName());
+          continue;
+        }
+      }
+      return failedImport("Dst pattern child isn't a leaf node or an MBB");
+    }
+
+    // Otherwise, we're looking for a bog-standard RegisterClass operand.
+    if (DstChild->hasAnyPredicate())
+      return failedImport("Dst pattern child has predicate");
+
+    if (auto *ChildDefInit = dyn_cast<DefInit>(DstChild->getLeafValue())) {
+      auto *ChildRec = ChildDefInit->getDef();
+
+      ArrayRef<EEVT::TypeSet> ChildTypes = DstChild->getExtTypes();
+      if (ChildTypes.size() != 1)
+        return failedImport("Dst pattern child has multiple results");
+
+      auto OpTyOrNone = MVTToLLT(ChildTypes.front().getConcrete());
+      if (!OpTyOrNone)
+        return failedImport("Dst operand has an unsupported type");
+
+      if (ChildRec->isSubClassOf("Register")) {
+        DstMIBuilder.addRenderer<AddRegisterRenderer>(ChildRec);
+        continue;
+      }
+
+      if (ChildRec->isSubClassOf("RegisterClass")) {
+        DstMIBuilder.addRenderer<CopyRenderer>(InsnMatcher, "I",
+                                               DstChild->getName());
+        continue;
+      }
+
+      return failedImport(
+          "Dst pattern child def is an unsupported tablegen class");
+    }
+
+    return failedImport("Dst pattern child is an unsupported kind");
  }
 
   // We're done with this pattern! It's eligible for GISel emission; return it.
@@ -555,6 +938,17 @@
     Rules.push_back(std::move(MatcherOrErr.get()));
   }
 
+  std::stable_sort(Rules.begin(), Rules.end(),
+                   [&](const RuleMatcher &A, const RuleMatcher &B) {
+                     if (A.isHigherPriorityThan(B)) {
+                       assert(!B.isHigherPriorityThan(A) && "Cannot be more important "
+                                                            "and less important at "
+                                                            "the same time");
+                       return true;
+                     }
+                     return false;
+                   });
+
   for (const auto &Rule : Rules) {
     Rule.emit(OS);
     ++NumPatternEmitted;
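[Editorial note, not part of the patch: BuildMIAction only mutates in place when every renderer copies the matched instruction's operand at the same index. A self-contained C++ sketch of that decision, with illustrative names rather than the emitter's classes:

    #include <cstddef>
    #include <vector>

    struct Renderer {
      bool IsCopy;        // Copies an operand from the matched instruction.
      size_t SourceIndex; // Which source operand it copies.
    };

    // Mirrors BuildMIAction::canMutate(): a render list of {copy 0, copy 1,
    // ...} is the identity permutation, so I.setDesc() suffices. Anything
    // else -- reordered operands as in the commuted MUL pattern, or an added
    // physreg like R0 in the ORN pattern -- needs a fresh BuildMI().
    bool canMutate(const std::vector<Renderer> &Renderers) {
      for (size_t i = 0; i < Renderers.size(); ++i)
        if (!Renderers[i].IsCopy || Renderers[i].SourceIndex != i)
          return false;
      return true;
    }
]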