diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index df4f6429dcb8..4822abc46300 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -1,1788 +1,1787 @@ //===- BasicTTIImpl.h -------------------------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // /// \file /// This file provides a helper that implements much of the TTI interface in /// terms of the target-independent code generator and TargetLowering /// interfaces. // //===----------------------------------------------------------------------===// #ifndef LLVM_CODEGEN_BASICTTIIMPL_H #define LLVM_CODEGEN_BASICTTIIMPL_H #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/TargetTransformInfoImpl.h" #include "llvm/CodeGen/ISDOpcodes.h" #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/CodeGen/ValueTypes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Operator.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/MC/MCSchedule.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MachineValueType.h" #include "llvm/Support/MathExtras.h" #include #include #include #include #include namespace llvm { class Function; class GlobalValue; class LLVMContext; class ScalarEvolution; class SCEV; class TargetMachine; extern cl::opt PartialUnrollingThreshold; /// Base class which can be used to help build a TTI implementation. /// /// This class provides as much implementation of the TTI interface as is /// possible using the target independent parts of the code generator. /// /// In order to subclass it, your class must implement a getST() method to /// return the subtarget, and a getTLI() method to return the target lowering. /// We need these methods implemented in the derived class so that this class /// doesn't have to duplicate storage for them. template class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { private: using BaseT = TargetTransformInfoImplCRTPBase; using TTI = TargetTransformInfo; /// Estimate a cost of Broadcast as an extract and sequence of insert /// operations. unsigned getBroadcastShuffleOverhead(Type *Ty) { auto *VTy = cast(Ty); unsigned Cost = 0; // Broadcast cost is equal to the cost of extracting the zero'th element // plus the cost of inserting it into every element of the result vector. 
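    // As a worked example (assuming unit extract/insert costs, purely for
    // illustration): broadcasting a <4 x float> is modelled as one
    // extractelement at index 0 plus four insertelements, i.e. a cost of 5.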
Cost += static_cast(this)->getVectorInstrCost( Instruction::ExtractElement, VTy, 0); for (int i = 0, e = VTy->getNumElements(); i < e; ++i) { Cost += static_cast(this)->getVectorInstrCost( Instruction::InsertElement, VTy, i); } return Cost; } /// Estimate a cost of shuffle as a sequence of extract and insert /// operations. unsigned getPermuteShuffleOverhead(Type *Ty) { auto *VTy = cast(Ty); unsigned Cost = 0; // Shuffle cost is equal to the cost of extracting element from its argument // plus the cost of inserting them onto the result vector. // e.g. <4 x float> has a mask of <0,5,2,7> i.e we need to extract from // index 0 of first vector, index 1 of second vector,index 2 of first // vector and finally index 3 of second vector and insert them at index // <0,1,2,3> of result vector. for (int i = 0, e = VTy->getNumElements(); i < e; ++i) { Cost += static_cast(this)->getVectorInstrCost( Instruction::InsertElement, VTy, i); Cost += static_cast(this)->getVectorInstrCost( Instruction::ExtractElement, VTy, i); } return Cost; } /// Estimate a cost of subvector extraction as a sequence of extract and /// insert operations. unsigned getExtractSubvectorOverhead(Type *Ty, int Index, Type *SubTy) { assert(Ty && Ty->isVectorTy() && SubTy && SubTy->isVectorTy() && "Can only extract subvectors from vectors"); auto *VTy = cast(Ty); auto *SubVTy = cast(SubTy); int NumSubElts = SubVTy->getNumElements(); assert((Index + NumSubElts) <= (int)VTy->getNumElements() && "SK_ExtractSubvector index out of range"); unsigned Cost = 0; // Subvector extraction cost is equal to the cost of extracting element from // the source type plus the cost of inserting them into the result vector // type. for (int i = 0; i != NumSubElts; ++i) { Cost += static_cast(this)->getVectorInstrCost( Instruction::ExtractElement, VTy, i + Index); Cost += static_cast(this)->getVectorInstrCost( Instruction::InsertElement, SubVTy, i); } return Cost; } /// Estimate a cost of subvector insertion as a sequence of extract and /// insert operations. unsigned getInsertSubvectorOverhead(Type *Ty, int Index, Type *SubTy) { assert(Ty && Ty->isVectorTy() && SubTy && SubTy->isVectorTy() && "Can only insert subvectors into vectors"); auto *VTy = cast(Ty); auto *SubVTy = cast(SubTy); int NumSubElts = SubVTy->getNumElements(); assert((Index + NumSubElts) <= (int)VTy->getNumElements() && "SK_InsertSubvector index out of range"); unsigned Cost = 0; // Subvector insertion cost is equal to the cost of extracting element from // the source type plus the cost of inserting them into the result vector // type. for (int i = 0; i != NumSubElts; ++i) { Cost += static_cast(this)->getVectorInstrCost( Instruction::ExtractElement, SubVTy, i); Cost += static_cast(this)->getVectorInstrCost( Instruction::InsertElement, VTy, i + Index); } return Cost; } /// Local query method delegates up to T which *must* implement this! const TargetSubtargetInfo *getST() const { return static_cast(this)->getST(); } /// Local query method delegates up to T which *must* implement this! 
const TargetLoweringBase *getTLI() const { return static_cast(this)->getTLI(); } static ISD::MemIndexedMode getISDIndexedMode(TTI::MemIndexedMode M) { switch (M) { case TTI::MIM_Unindexed: return ISD::UNINDEXED; case TTI::MIM_PreInc: return ISD::PRE_INC; case TTI::MIM_PreDec: return ISD::PRE_DEC; case TTI::MIM_PostInc: return ISD::POST_INC; case TTI::MIM_PostDec: return ISD::POST_DEC; } llvm_unreachable("Unexpected MemIndexedMode"); } protected: explicit BasicTTIImplBase(const TargetMachine *TM, const DataLayout &DL) : BaseT(DL) {} virtual ~BasicTTIImplBase() = default; using TargetTransformInfoImplBase::DL; public: /// \name Scalar TTI Implementations /// @{ bool allowsMisalignedMemoryAccesses(LLVMContext &Context, unsigned BitWidth, unsigned AddressSpace, unsigned Alignment, bool *Fast) const { EVT E = EVT::getIntegerVT(Context, BitWidth); return getTLI()->allowsMisalignedMemoryAccesses( E, AddressSpace, Alignment, MachineMemOperand::MONone, Fast); } bool hasBranchDivergence() { return false; } bool useGPUDivergenceAnalysis() { return false; } bool isSourceOfDivergence(const Value *V) { return false; } bool isAlwaysUniform(const Value *V) { return false; } unsigned getFlatAddressSpace() { // Return an invalid address space. return -1; } bool collectFlatAddressOperands(SmallVectorImpl &OpIndexes, Intrinsic::ID IID) const { return false; } bool rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV, Value *NewV) const { return false; } bool isLegalAddImmediate(int64_t imm) { return getTLI()->isLegalAddImmediate(imm); } bool isLegalICmpImmediate(int64_t imm) { return getTLI()->isLegalICmpImmediate(imm); } bool isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, unsigned AddrSpace, Instruction *I = nullptr) { TargetLoweringBase::AddrMode AM; AM.BaseGV = BaseGV; AM.BaseOffs = BaseOffset; AM.HasBaseReg = HasBaseReg; AM.Scale = Scale; return getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace, I); } bool isIndexedLoadLegal(TTI::MemIndexedMode M, Type *Ty, const DataLayout &DL) const { EVT VT = getTLI()->getValueType(DL, Ty); return getTLI()->isIndexedLoadLegal(getISDIndexedMode(M), VT); } bool isIndexedStoreLegal(TTI::MemIndexedMode M, Type *Ty, const DataLayout &DL) const { EVT VT = getTLI()->getValueType(DL, Ty); return getTLI()->isIndexedStoreLegal(getISDIndexedMode(M), VT); } bool isLSRCostLess(TTI::LSRCost C1, TTI::LSRCost C2) { return TargetTransformInfoImplBase::isLSRCostLess(C1, C2); } int getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, int64_t Scale, unsigned AddrSpace) { TargetLoweringBase::AddrMode AM; AM.BaseGV = BaseGV; AM.BaseOffs = BaseOffset; AM.HasBaseReg = HasBaseReg; AM.Scale = Scale; return getTLI()->getScalingFactorCost(DL, AM, Ty, AddrSpace); } bool isTruncateFree(Type *Ty1, Type *Ty2) { return getTLI()->isTruncateFree(Ty1, Ty2); } bool isProfitableToHoist(Instruction *I) { return getTLI()->isProfitableToHoist(I); } bool useAA() const { return getST()->useAA(); } bool isTypeLegal(Type *Ty) { EVT VT = getTLI()->getValueType(DL, Ty); return getTLI()->isTypeLegal(VT); } int getGEPCost(Type *PointeeType, const Value *Ptr, ArrayRef Operands) { return BaseT::getGEPCost(PointeeType, Ptr, Operands); } int getExtCost(const Instruction *I, const Value *Src) { if (getTLI()->isExtFree(I)) return TargetTransformInfo::TCC_Free; if (isa(I) || isa(I)) if (const LoadInst *LI = dyn_cast(Src)) if (getTLI()->isExtLoad(LI, I, DL)) return TargetTransformInfo::TCC_Free; return 
TargetTransformInfo::TCC_Basic; } unsigned getIntrinsicCost(Intrinsic::ID IID, Type *RetTy, ArrayRef Arguments, const User *U) { return BaseT::getIntrinsicCost(IID, RetTy, Arguments, U); } unsigned getIntrinsicCost(Intrinsic::ID IID, Type *RetTy, ArrayRef ParamTys, const User *U) { if (IID == Intrinsic::cttz) { if (getTLI()->isCheapToSpeculateCttz()) return TargetTransformInfo::TCC_Basic; return TargetTransformInfo::TCC_Expensive; } if (IID == Intrinsic::ctlz) { if (getTLI()->isCheapToSpeculateCtlz()) return TargetTransformInfo::TCC_Basic; return TargetTransformInfo::TCC_Expensive; } return BaseT::getIntrinsicCost(IID, RetTy, ParamTys, U); } unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI, unsigned &JumpTableSize, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI) { /// Try to find the estimated number of clusters. Note that the number of /// clusters identified in this function could be different from the actual /// numbers found in lowering. This function ignore switches that are /// lowered with a mix of jump table / bit test / BTree. This function was /// initially intended to be used when estimating the cost of switch in /// inline cost heuristic, but it's a generic cost model to be used in other /// places (e.g., in loop unrolling). unsigned N = SI.getNumCases(); const TargetLoweringBase *TLI = getTLI(); const DataLayout &DL = this->getDataLayout(); JumpTableSize = 0; bool IsJTAllowed = TLI->areJTsAllowed(SI.getParent()->getParent()); // Early exit if both a jump table and bit test are not allowed. if (N < 1 || (!IsJTAllowed && DL.getIndexSizeInBits(0u) < N)) return N; APInt MaxCaseVal = SI.case_begin()->getCaseValue()->getValue(); APInt MinCaseVal = MaxCaseVal; for (auto CI : SI.cases()) { const APInt &CaseVal = CI.getCaseValue()->getValue(); if (CaseVal.sgt(MaxCaseVal)) MaxCaseVal = CaseVal; if (CaseVal.slt(MinCaseVal)) MinCaseVal = CaseVal; } // Check if suitable for a bit test if (N <= DL.getIndexSizeInBits(0u)) { SmallPtrSet Dests; for (auto I : SI.cases()) Dests.insert(I.getCaseSuccessor()); if (TLI->isSuitableForBitTests(Dests.size(), N, MinCaseVal, MaxCaseVal, DL)) return 1; } // Check if suitable for a jump table. if (IsJTAllowed) { if (N < 2 || N < TLI->getMinimumJumpTableEntries()) return N; uint64_t Range = (MaxCaseVal - MinCaseVal) .getLimitedValue(std::numeric_limits::max() - 1) + 1; // Check whether a range of clusters is dense enough for a jump table if (TLI->isSuitableForJumpTable(&SI, N, Range, PSI, BFI)) { JumpTableSize = Range; return 1; } } return N; } bool shouldBuildLookupTables() { const TargetLoweringBase *TLI = getTLI(); return TLI->isOperationLegalOrCustom(ISD::BR_JT, MVT::Other) || TLI->isOperationLegalOrCustom(ISD::BRIND, MVT::Other); } bool haveFastSqrt(Type *Ty) { const TargetLoweringBase *TLI = getTLI(); EVT VT = TLI->getValueType(DL, Ty); return TLI->isTypeLegal(VT) && TLI->isOperationLegalOrCustom(ISD::FSQRT, VT); } bool isFCmpOrdCheaperThanFCmpZero(Type *Ty) { return true; } unsigned getFPOpCost(Type *Ty) { // Check whether FADD is available, as a proxy for floating-point in // general. 
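    // For example, a target with a legal f32 FADD reports TCC_Basic for f32
    // arithmetic here, while a type such as fp128 that most targets must
    // expand to a library call reports TCC_Expensive.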
const TargetLoweringBase *TLI = getTLI(); EVT VT = TLI->getValueType(DL, Ty); if (TLI->isOperationLegalOrCustomOrPromote(ISD::FADD, VT)) return TargetTransformInfo::TCC_Basic; return TargetTransformInfo::TCC_Expensive; } unsigned getOperationCost(unsigned Opcode, Type *Ty, Type *OpTy) { const TargetLoweringBase *TLI = getTLI(); switch (Opcode) { default: break; case Instruction::Trunc: if (TLI->isTruncateFree(OpTy, Ty)) return TargetTransformInfo::TCC_Free; return TargetTransformInfo::TCC_Basic; case Instruction::ZExt: if (TLI->isZExtFree(OpTy, Ty)) return TargetTransformInfo::TCC_Free; return TargetTransformInfo::TCC_Basic; case Instruction::AddrSpaceCast: if (TLI->isFreeAddrSpaceCast(OpTy->getPointerAddressSpace(), Ty->getPointerAddressSpace())) return TargetTransformInfo::TCC_Free; return TargetTransformInfo::TCC_Basic; } return BaseT::getOperationCost(Opcode, Ty, OpTy); } unsigned getInliningThresholdMultiplier() { return 1; } int getInlinerVectorBonusPercent() { return 150; } void getUnrollingPreferences(Loop *L, ScalarEvolution &SE, TTI::UnrollingPreferences &UP) { // This unrolling functionality is target independent, but to provide some // motivation for its intended use, for x86: // According to the Intel 64 and IA-32 Architectures Optimization Reference // Manual, Intel Core models and later have a loop stream detector (and // associated uop queue) that can benefit from partial unrolling. // The relevant requirements are: // - The loop must have no more than 4 (8 for Nehalem and later) branches // taken, and none of them may be calls. // - The loop can have no more than 18 (28 for Nehalem and later) uops. // According to the Software Optimization Guide for AMD Family 15h // Processors, models 30h-4fh (Steamroller and later) have a loop predictor // and loop buffer which can benefit from partial unrolling. // The relevant requirements are: // - The loop must have fewer than 16 branches // - The loop must have less than 40 uops in all executed loop branches // The number of taken branches in a loop is hard to estimate here, and // benchmarking has revealed that it is better not to be conservative when // estimating the branch count. As a result, we'll ignore the branch limits // until someone finds a case where it matters in practice. unsigned MaxOps; const TargetSubtargetInfo *ST = getST(); if (PartialUnrollingThreshold.getNumOccurrences() > 0) MaxOps = PartialUnrollingThreshold; else if (ST->getSchedModel().LoopMicroOpBufferSize > 0) MaxOps = ST->getSchedModel().LoopMicroOpBufferSize; else return; // Scan the loop: don't unroll loops with calls. for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); I != E; ++I) { BasicBlock *BB = *I; for (BasicBlock::iterator J = BB->begin(), JE = BB->end(); J != JE; ++J) if (isa(J) || isa(J)) { ImmutableCallSite CS(&*J); if (const Function *F = CS.getCalledFunction()) { if (!static_cast(this)->isLoweredToCall(F)) continue; } return; } } // Enable runtime and partial unrolling up to the specified size. // Enable using trip count upper bound to unroll loops. UP.Partial = UP.Runtime = UP.UpperBound = true; UP.PartialThreshold = MaxOps; // Avoid unrolling when optimizing for size. UP.OptSizeThreshold = 0; UP.PartialOptSizeThreshold = 0; // Set number of instructions optimized when "back edge" // becomes "fall through" to default value of 2. 
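    // Rough illustration (not tied to any particular subtarget): with a
    // LoopMicroOpBufferSize of 28 and no -partial-unrolling-threshold
    // override, UP.PartialThreshold becomes 28, so a call-free loop body of
    // about 7 instructions may be partially unrolled roughly 4x.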
UP.BEInsns = 2; } bool isHardwareLoopProfitable(Loop *L, ScalarEvolution &SE, AssumptionCache &AC, TargetLibraryInfo *LibInfo, HardwareLoopInfo &HWLoopInfo) { return BaseT::isHardwareLoopProfitable(L, SE, AC, LibInfo, HWLoopInfo); } bool preferPredicateOverEpilogue(Loop *L, LoopInfo *LI, ScalarEvolution &SE, AssumptionCache &AC, TargetLibraryInfo *TLI, DominatorTree *DT, const LoopAccessInfo *LAI) { return BaseT::preferPredicateOverEpilogue(L, LI, SE, AC, TLI, DT, LAI); } int getInstructionLatency(const Instruction *I) { if (isa(I)) return getST()->getSchedModel().DefaultLoadLatency; return BaseT::getInstructionLatency(I); } virtual Optional getCacheSize(TargetTransformInfo::CacheLevel Level) const { return Optional( getST()->getCacheSize(static_cast(Level))); } virtual Optional getCacheAssociativity(TargetTransformInfo::CacheLevel Level) const { Optional TargetResult = getST()->getCacheAssociativity(static_cast(Level)); if (TargetResult) return TargetResult; return BaseT::getCacheAssociativity(Level); } virtual unsigned getCacheLineSize() const { return getST()->getCacheLineSize(); } virtual unsigned getPrefetchDistance() const { return getST()->getPrefetchDistance(); } virtual unsigned getMinPrefetchStride(unsigned NumMemAccesses, unsigned NumStridedMemAccesses, unsigned NumPrefetches, bool HasCall) const { return getST()->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses, NumPrefetches, HasCall); } virtual unsigned getMaxPrefetchIterationsAhead() const { return getST()->getMaxPrefetchIterationsAhead(); } virtual bool enableWritePrefetching() const { return getST()->enableWritePrefetching(); } /// @} /// \name Vector TTI Implementations /// @{ unsigned getRegisterBitWidth(bool Vector) const { return 32; } /// Estimate the overhead of scalarizing an instruction. Insert and Extract /// are set if the result needs to be inserted and/or extracted from vectors. unsigned getScalarizationOverhead(Type *Ty, bool Insert, bool Extract) { auto *VTy = cast(Ty); unsigned Cost = 0; for (int i = 0, e = VTy->getNumElements(); i < e; ++i) { if (Insert) Cost += static_cast(this)->getVectorInstrCost( Instruction::InsertElement, VTy, i); if (Extract) Cost += static_cast(this)->getVectorInstrCost( Instruction::ExtractElement, VTy, i); } return Cost; } /// Estimate the overhead of scalarizing an instructions unique /// non-constant operands. The types of the arguments are ordinarily /// scalar, in which case the costs are multiplied with VF. unsigned getOperandsScalarizationOverhead(ArrayRef Args, unsigned VF) { unsigned Cost = 0; SmallPtrSet UniqueOperands; for (const Value *A : Args) { if (!isa(A) && UniqueOperands.insert(A).second) { Type *VecTy = nullptr; if (A->getType()->isVectorTy()) { VecTy = A->getType(); // If A is a vector operand, VF should be 1 or correspond to A. assert((VF == 1 || VF == cast(VecTy)->getNumElements()) && "Vector argument does not match VF"); } else VecTy = VectorType::get(A->getType(), VF); Cost += getScalarizationOverhead(VecTy, false, true); } } return Cost; } unsigned getScalarizationOverhead(Type *VecTy, ArrayRef Args) { unsigned Cost = 0; auto *VecVTy = cast(VecTy); Cost += getScalarizationOverhead(VecVTy, true, false); if (!Args.empty()) Cost += getOperandsScalarizationOverhead(Args, VecVTy->getNumElements()); else // When no information on arguments is provided, we add the cost // associated with one argument as a heuristic. 
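      // E.g. (assuming unit insert/extract costs): for a <4 x i32> result with
      // no argument information, this yields 4 inserts for the result plus
      // 4 extracts for the single assumed vector operand, i.e. 8.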
Cost += getScalarizationOverhead(VecVTy, false, true); return Cost; } unsigned getMaxInterleaveFactor(unsigned VF) { return 1; } unsigned getArithmeticInstrCost( unsigned Opcode, Type *Ty, TTI::OperandValueKind Opd1Info = TTI::OK_AnyValue, TTI::OperandValueKind Opd2Info = TTI::OK_AnyValue, TTI::OperandValueProperties Opd1PropInfo = TTI::OP_None, TTI::OperandValueProperties Opd2PropInfo = TTI::OP_None, ArrayRef Args = ArrayRef(), const Instruction *CxtI = nullptr) { // Check if any of the operands are vector operands. const TargetLoweringBase *TLI = getTLI(); int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); std::pair LT = TLI->getTypeLegalizationCost(DL, Ty); bool IsFloat = Ty->isFPOrFPVectorTy(); // Assume that floating point arithmetic operations cost twice as much as // integer operations. unsigned OpCost = (IsFloat ? 2 : 1); if (TLI->isOperationLegalOrPromote(ISD, LT.second)) { // The operation is legal. Assume it costs 1. // TODO: Once we have extract/insert subvector cost we need to use them. return LT.first * OpCost; } if (!TLI->isOperationExpand(ISD, LT.second)) { // If the operation is custom lowered, then assume that the code is twice // as expensive. return LT.first * 2 * OpCost; } // Else, assume that we need to scalarize this op. // TODO: If one of the types get legalized by splitting, handle this // similarly to what getCastInstrCost() does. if (auto *VTy = dyn_cast(Ty)) { unsigned Num = VTy->getNumElements(); unsigned Cost = static_cast(this)->getArithmeticInstrCost( Opcode, VTy->getScalarType()); // Return the cost of multiple scalar invocation plus the cost of // inserting and extracting the values. return getScalarizationOverhead(VTy, Args) + Num * Cost; } // We don't know anything about this scalar instruction. return OpCost; } unsigned getShuffleCost(TTI::ShuffleKind Kind, Type *Tp, int Index, Type *SubTp) { switch (Kind) { case TTI::SK_Broadcast: return getBroadcastShuffleOverhead(Tp); case TTI::SK_Select: case TTI::SK_Reverse: case TTI::SK_Transpose: case TTI::SK_PermuteSingleSrc: case TTI::SK_PermuteTwoSrc: return getPermuteShuffleOverhead(Tp); case TTI::SK_ExtractSubvector: return getExtractSubvectorOverhead(Tp, Index, SubTp); case TTI::SK_InsertSubvector: return getInsertSubvectorOverhead(Tp, Index, SubTp); } llvm_unreachable("Unknown TTI::ShuffleKind"); } unsigned getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, const Instruction *I = nullptr) { const TargetLoweringBase *TLI = getTLI(); int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); std::pair SrcLT = TLI->getTypeLegalizationCost(DL, Src); std::pair DstLT = TLI->getTypeLegalizationCost(DL, Dst); unsigned SrcSize = SrcLT.second.getSizeInBits(); unsigned DstSize = DstLT.second.getSizeInBits(); switch (Opcode) { default: break; case Instruction::Trunc: // Check for NOOP conversions. if (TLI->isTruncateFree(SrcLT.second, DstLT.second)) return 0; LLVM_FALLTHROUGH; case Instruction::BitCast: // Bitcast between types that are legalized to the same type are free. if (SrcLT.first == DstLT.first && SrcSize == DstSize) return 0; break; case Instruction::ZExt: if (TLI->isZExtFree(SrcLT.second, DstLT.second)) return 0; break; case Instruction::AddrSpaceCast: if (TLI->isFreeAddrSpaceCast(Src->getPointerAddressSpace(), Dst->getPointerAddressSpace())) return 0; break; } // If this is a zext/sext of a load, return 0 if the corresponding // extending load exists on target. 
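    // For instance, on a target where the i8 -> i32 sign-extending load is
    // legal, IR of the form
    //   %b = load i8, i8* %p
    //   %s = sext i8 %b to i32
    // folds into a single extending load, so the extension is costed as free.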
if ((Opcode == Instruction::ZExt || Opcode == Instruction::SExt) && I && isa(I->getOperand(0))) { EVT ExtVT = EVT::getEVT(Dst); EVT LoadVT = EVT::getEVT(Src); unsigned LType = ((Opcode == Instruction::ZExt) ? ISD::ZEXTLOAD : ISD::SEXTLOAD); if (TLI->isLoadExtLegal(LType, ExtVT, LoadVT)) return 0; } // If the cast is marked as legal (or promote) then assume low cost. if (SrcLT.first == DstLT.first && TLI->isOperationLegalOrPromote(ISD, DstLT.second)) return 1; // Handle scalar conversions. if (!Src->isVectorTy() && !Dst->isVectorTy()) { // Scalar bitcasts are usually free. if (Opcode == Instruction::BitCast) return 0; // Just check the op cost. If the operation is legal then assume it costs // 1. if (!TLI->isOperationExpand(ISD, DstLT.second)) return 1; // Assume that illegal scalar instruction are expensive. return 4; } // Check vector-to-vector casts. if (Dst->isVectorTy() && Src->isVectorTy()) { auto *SrcVTy = cast(Src); auto *DstVTy = cast(Dst); // If the cast is between same-sized registers, then the check is simple. if (SrcLT.first == DstLT.first && SrcLT.second.getSizeInBits() == DstLT.second.getSizeInBits()) { // Assume that Zext is done using AND. if (Opcode == Instruction::ZExt) return 1; // Assume that sext is done using SHL and SRA. if (Opcode == Instruction::SExt) return 2; // Just check the op cost. If the operation is legal then assume it // costs // 1 and multiply by the type-legalization overhead. if (!TLI->isOperationExpand(ISD, DstLT.second)) return SrcLT.first * 1; } // If we are legalizing by splitting, query the concrete TTI for the cost // of casting the original vector twice. We also need to factor in the // cost of the split itself. Count that as 1, to be consistent with // TLI->getTypeLegalizationCost(). if ((TLI->getTypeAction(Src->getContext(), TLI->getValueType(DL, Src)) == TargetLowering::TypeSplitVector || TLI->getTypeAction(Dst->getContext(), TLI->getValueType(DL, Dst)) == TargetLowering::TypeSplitVector) && SrcVTy->getNumElements() > 1 && DstVTy->getNumElements() > 1) { Type *SplitDst = VectorType::get(DstVTy->getElementType(), DstVTy->getNumElements() / 2); Type *SplitSrc = VectorType::get(SrcVTy->getElementType(), SrcVTy->getNumElements() / 2); T *TTI = static_cast(this); return TTI->getVectorSplitCost() + (2 * TTI->getCastInstrCost(Opcode, SplitDst, SplitSrc, I)); } // In other cases where the source or destination are illegal, assume // the operation will get scalarized. unsigned Num = DstVTy->getNumElements(); unsigned Cost = static_cast(this)->getCastInstrCost( Opcode, Dst->getScalarType(), Src->getScalarType(), I); // Return the cost of multiple scalar invocation plus the cost of // inserting and extracting the values. return getScalarizationOverhead(Dst, true, true) + Num * Cost; } // We already handled vector-to-vector and scalar-to-scalar conversions. // This // is where we handle bitcast between vectors and scalars. We need to assume // that the conversion is scalarized in one way or another. if (Opcode == Instruction::BitCast) // Illegal bitcasts are done by storing and loading from a stack slot. return (Src->isVectorTy() ? getScalarizationOverhead(Src, false, true) : 0) + (Dst->isVectorTy() ? 
getScalarizationOverhead(Dst, true, false) : 0); llvm_unreachable("Unhandled cast"); } unsigned getExtractWithExtendCost(unsigned Opcode, Type *Dst, VectorType *VecTy, unsigned Index) { return static_cast(this)->getVectorInstrCost( Instruction::ExtractElement, VecTy, Index) + static_cast(this)->getCastInstrCost(Opcode, Dst, VecTy->getElementType()); } unsigned getCFInstrCost(unsigned Opcode) { // Branches are assumed to be predicted. return 0; } unsigned getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, const Instruction *I) { const TargetLoweringBase *TLI = getTLI(); int ISD = TLI->InstructionOpcodeToISD(Opcode); assert(ISD && "Invalid opcode"); // Selects on vectors are actually vector selects. if (ISD == ISD::SELECT) { assert(CondTy && "CondTy must exist"); if (CondTy->isVectorTy()) ISD = ISD::VSELECT; } std::pair LT = TLI->getTypeLegalizationCost(DL, ValTy); if (!(ValTy->isVectorTy() && !LT.second.isVector()) && !TLI->isOperationExpand(ISD, LT.second)) { // The operation is legal. Assume it costs 1. Multiply // by the type-legalization overhead. return LT.first * 1; } // Otherwise, assume that the cast is scalarized. // TODO: If one of the types get legalized by splitting, handle this // similarly to what getCastInstrCost() does. if (auto *ValVTy = dyn_cast(ValTy)) { unsigned Num = ValVTy->getNumElements(); if (CondTy) CondTy = CondTy->getScalarType(); unsigned Cost = static_cast(this)->getCmpSelInstrCost( Opcode, ValVTy->getScalarType(), CondTy, I); // Return the cost of multiple scalar invocation plus the cost of // inserting and extracting the values. return getScalarizationOverhead(ValVTy, true, false) + Num * Cost; } // Unknown scalar opcode. return 1; } unsigned getVectorInstrCost(unsigned Opcode, Type *Val, unsigned Index) { std::pair LT = getTLI()->getTypeLegalizationCost(DL, Val->getScalarType()); return LT.first; } unsigned getMemoryOpCost(unsigned Opcode, Type *Src, MaybeAlign Alignment, unsigned AddressSpace, const Instruction *I = nullptr) { assert(!Src->isVoidTy() && "Invalid type"); std::pair LT = getTLI()->getTypeLegalizationCost(DL, Src); // Assuming that all loads of legal types cost 1. unsigned Cost = LT.first; if (Src->isVectorTy() && Src->getPrimitiveSizeInBits() < LT.second.getSizeInBits()) { // This is a vector load that legalizes to a larger type than the vector // itself. Unless the corresponding extending load or truncating store is // legal, then this will scalarize. TargetLowering::LegalizeAction LA = TargetLowering::Expand; EVT MemVT = getTLI()->getValueType(DL, Src); if (Opcode == Instruction::Store) LA = getTLI()->getTruncStoreAction(LT.second, MemVT); else LA = getTLI()->getLoadExtAction(ISD::EXTLOAD, LT.second, MemVT); if (LA != TargetLowering::Legal && LA != TargetLowering::Custom) { // This is a vector load/store for some illegal type that is scalarized. // We must account for the cost of building or decomposing the vector. Cost += getScalarizationOverhead(Src, Opcode != Instruction::Store, Opcode == Instruction::Store); } } return Cost; } unsigned getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef Indices, unsigned Alignment, unsigned AddressSpace, bool UseMaskForCond = false, bool UseMaskForGaps = false) { auto *VT = cast(VecTy); unsigned NumElts = VT->getNumElements(); assert(Factor > 1 && NumElts % Factor == 0 && "Invalid interleave factor"); unsigned NumSubElts = NumElts / Factor; VectorType *SubVT = VectorType::get(VT->getElementType(), NumSubElts); // Firstly, the cost of load/store operation. 
unsigned Cost; if (UseMaskForCond || UseMaskForGaps) Cost = static_cast(this)->getMaskedMemoryOpCost( Opcode, VecTy, Alignment, AddressSpace); else Cost = static_cast(this)->getMemoryOpCost( Opcode, VecTy, MaybeAlign(Alignment), AddressSpace); // Legalize the vector type, and get the legalized and unlegalized type // sizes. MVT VecTyLT = getTLI()->getTypeLegalizationCost(DL, VecTy).second; unsigned VecTySize = static_cast(this)->getDataLayout().getTypeStoreSize(VecTy); unsigned VecTyLTSize = VecTyLT.getStoreSize(); // Return the ceiling of dividing A by B. auto ceil = [](unsigned A, unsigned B) { return (A + B - 1) / B; }; // Scale the cost of the memory operation by the fraction of legalized // instructions that will actually be used. We shouldn't account for the // cost of dead instructions since they will be removed. // // E.g., An interleaved load of factor 8: // %vec = load <16 x i64>, <16 x i64>* %ptr // %v0 = shufflevector %vec, undef, <0, 8> // // If <16 x i64> is legalized to 8 v2i64 loads, only 2 of the loads will be // used (those corresponding to elements [0:1] and [8:9] of the unlegalized // type). The other loads are unused. // // We only scale the cost of loads since interleaved store groups aren't // allowed to have gaps. if (Opcode == Instruction::Load && VecTySize > VecTyLTSize) { // The number of loads of a legal type it will take to represent a load // of the unlegalized vector type. unsigned NumLegalInsts = ceil(VecTySize, VecTyLTSize); // The number of elements of the unlegalized type that correspond to a // single legal instruction. unsigned NumEltsPerLegalInst = ceil(NumElts, NumLegalInsts); // Determine which legal instructions will be used. BitVector UsedInsts(NumLegalInsts, false); for (unsigned Index : Indices) for (unsigned Elt = 0; Elt < NumSubElts; ++Elt) UsedInsts.set((Index + Elt * Factor) / NumEltsPerLegalInst); // Scale the cost of the load by the fraction of legal instructions that // will be used. Cost *= UsedInsts.count() / NumLegalInsts; } // Then plus the cost of interleave operation. if (Opcode == Instruction::Load) { // The interleave cost is similar to extract sub vectors' elements // from the wide vector, and insert them into sub vectors. // // E.g. An interleaved load of factor 2 (with one member of index 0): // %vec = load <8 x i32>, <8 x i32>* %ptr // %v0 = shuffle %vec, undef, <0, 2, 4, 6> ; Index 0 // The cost is estimated as extract elements at 0, 2, 4, 6 from the // <8 x i32> vector and insert them into a <4 x i32> vector. assert(Indices.size() <= Factor && "Interleaved memory op has too many members"); for (unsigned Index : Indices) { assert(Index < Factor && "Invalid index for interleaved memory op"); // Extract elements from loaded vector for each sub vector. for (unsigned i = 0; i < NumSubElts; i++) Cost += static_cast(this)->getVectorInstrCost( Instruction::ExtractElement, VT, Index + i * Factor); } unsigned InsSubCost = 0; for (unsigned i = 0; i < NumSubElts; i++) InsSubCost += static_cast(this)->getVectorInstrCost( Instruction::InsertElement, SubVT, i); Cost += Indices.size() * InsSubCost; } else { // The interleave cost is extract all elements from sub vectors, and // insert them into the wide vector. // // E.g. An interleaved store of factor 2: // %v0_v1 = shuffle %v0, %v1, <0, 4, 1, 5, 2, 6, 3, 7> // store <8 x i32> %interleaved.vec, <8 x i32>* %ptr // The cost is estimated as extract all elements from both <4 x i32> // vectors and insert into the <8 x i32> vector. 
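      // Worked example (assuming unit element costs): an interleaved store of
      // factor 2 producing <8 x i32> is costed as 2 * 4 extractelements from
      // the <4 x i32> operands plus 8 insertelements into the wide vector,
      // i.e. 16 element operations on top of the store itself.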
unsigned ExtSubCost = 0; for (unsigned i = 0; i < NumSubElts; i++) ExtSubCost += static_cast(this)->getVectorInstrCost( Instruction::ExtractElement, SubVT, i); Cost += ExtSubCost * Factor; for (unsigned i = 0; i < NumElts; i++) Cost += static_cast(this) ->getVectorInstrCost(Instruction::InsertElement, VT, i); } if (!UseMaskForCond) return Cost; Type *I8Type = Type::getInt8Ty(VT->getContext()); VectorType *MaskVT = VectorType::get(I8Type, NumElts); SubVT = VectorType::get(I8Type, NumSubElts); // The Mask shuffling cost is extract all the elements of the Mask // and insert each of them Factor times into the wide vector: // // E.g. an interleaved group with factor 3: // %mask = icmp ult <8 x i32> %vec1, %vec2 // %interleaved.mask = shufflevector <8 x i1> %mask, <8 x i1> undef, // <24 x i32> <0,0,0,1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6,7,7,7> // The cost is estimated as extract all mask elements from the <8xi1> mask // vector and insert them factor times into the <24xi1> shuffled mask // vector. for (unsigned i = 0; i < NumSubElts; i++) Cost += static_cast(this)->getVectorInstrCost( Instruction::ExtractElement, SubVT, i); for (unsigned i = 0; i < NumElts; i++) Cost += static_cast(this)->getVectorInstrCost( Instruction::InsertElement, MaskVT, i); // The Gaps mask is invariant and created outside the loop, therefore the // cost of creating it is not accounted for here. However if we have both // a MaskForGaps and some other mask that guards the execution of the // memory access, we need to account for the cost of And-ing the two masks // inside the loop. if (UseMaskForGaps) Cost += static_cast(this)->getArithmeticInstrCost( BinaryOperator::And, MaskVT); return Cost; } /// Get intrinsic cost based on arguments. unsigned getIntrinsicInstrCost(Intrinsic::ID IID, Type *RetTy, ArrayRef Args, FastMathFlags FMF, unsigned VF = 1, const Instruction *I = nullptr) { unsigned RetVF = (RetTy->isVectorTy() ? cast(RetTy)->getNumElements() : 1); assert((RetVF == 1 || VF == 1) && "VF > 1 and RetVF is a vector type"); auto *ConcreteTTI = static_cast(this); switch (IID) { default: { // Assume that we need to scalarize this intrinsic. SmallVector Types; for (Value *Op : Args) { Type *OpTy = Op->getType(); assert(VF == 1 || !OpTy->isVectorTy()); Types.push_back(VF == 1 ? OpTy : VectorType::get(OpTy, VF)); } if (VF > 1 && !RetTy->isVoidTy()) RetTy = VectorType::get(RetTy, VF); // Compute the scalarization overhead based on Args for a vector // intrinsic. A vectorizer will pass a scalar RetTy and VF > 1, while // CostModel will pass a vector RetTy and VF is 1. 
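      // Illustration: for a call such as llvm.fabs.f32 vectorized at VF=4,
      // the vectorizer queries with the scalar RetTy (float) and VF=4,
      // whereas the cost-model analysis queries the already-vectorized form
      // with RetTy = <4 x float> and VF=1; both describe the same operation.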
unsigned ScalarizationCost = std::numeric_limits::max(); if (RetVF > 1 || VF > 1) { ScalarizationCost = 0; if (!RetTy->isVoidTy()) ScalarizationCost += getScalarizationOverhead(RetTy, true, false); ScalarizationCost += getOperandsScalarizationOverhead(Args, VF); } return ConcreteTTI->getIntrinsicInstrCost(IID, RetTy, Types, FMF, ScalarizationCost); } case Intrinsic::masked_scatter: { assert(VF == 1 && "Can't vectorize types here."); Value *Mask = Args[3]; bool VarMask = !isa(Mask); unsigned Alignment = cast(Args[2])->getZExtValue(); return ConcreteTTI->getGatherScatterOpCost(Instruction::Store, Args[0]->getType(), Args[1], VarMask, Alignment, I); } case Intrinsic::masked_gather: { assert(VF == 1 && "Can't vectorize types here."); Value *Mask = Args[2]; bool VarMask = !isa(Mask); unsigned Alignment = cast(Args[1])->getZExtValue(); return ConcreteTTI->getGatherScatterOpCost( Instruction::Load, RetTy, Args[0], VarMask, Alignment, I); } case Intrinsic::experimental_vector_reduce_add: case Intrinsic::experimental_vector_reduce_mul: case Intrinsic::experimental_vector_reduce_and: case Intrinsic::experimental_vector_reduce_or: case Intrinsic::experimental_vector_reduce_xor: case Intrinsic::experimental_vector_reduce_v2_fadd: case Intrinsic::experimental_vector_reduce_v2_fmul: case Intrinsic::experimental_vector_reduce_smax: case Intrinsic::experimental_vector_reduce_smin: case Intrinsic::experimental_vector_reduce_fmax: case Intrinsic::experimental_vector_reduce_fmin: case Intrinsic::experimental_vector_reduce_umax: case Intrinsic::experimental_vector_reduce_umin: return getIntrinsicInstrCost(IID, RetTy, Args[0]->getType(), FMF); case Intrinsic::fshl: case Intrinsic::fshr: { Value *X = Args[0]; Value *Y = Args[1]; Value *Z = Args[2]; TTI::OperandValueProperties OpPropsX, OpPropsY, OpPropsZ, OpPropsBW; TTI::OperandValueKind OpKindX = TTI::getOperandInfo(X, OpPropsX); TTI::OperandValueKind OpKindY = TTI::getOperandInfo(Y, OpPropsY); TTI::OperandValueKind OpKindZ = TTI::getOperandInfo(Z, OpPropsZ); TTI::OperandValueKind OpKindBW = TTI::OK_UniformConstantValue; OpPropsBW = isPowerOf2_32(RetTy->getScalarSizeInBits()) ? TTI::OP_PowerOf2 : TTI::OP_None; // fshl: (X << (Z % BW)) | (Y >> (BW - (Z % BW))) // fshr: (X << (BW - (Z % BW))) | (Y >> (Z % BW)) unsigned Cost = 0; Cost += ConcreteTTI->getArithmeticInstrCost(BinaryOperator::Or, RetTy); Cost += ConcreteTTI->getArithmeticInstrCost(BinaryOperator::Sub, RetTy); Cost += ConcreteTTI->getArithmeticInstrCost(BinaryOperator::Shl, RetTy, OpKindX, OpKindZ, OpPropsX); Cost += ConcreteTTI->getArithmeticInstrCost(BinaryOperator::LShr, RetTy, OpKindY, OpKindZ, OpPropsY); // Non-constant shift amounts requires a modulo. if (OpKindZ != TTI::OK_UniformConstantValue && OpKindZ != TTI::OK_NonUniformConstantValue) Cost += ConcreteTTI->getArithmeticInstrCost(BinaryOperator::URem, RetTy, OpKindZ, OpKindBW, OpPropsZ, OpPropsBW); // For non-rotates (X != Y) we must add shift-by-zero handling costs. if (X != Y) { Type *CondTy = RetTy->getWithNewBitWidth(1); Cost += ConcreteTTI->getCmpSelInstrCost(BinaryOperator::ICmp, RetTy, CondTy, nullptr); Cost += ConcreteTTI->getCmpSelInstrCost(BinaryOperator::Select, RetTy, CondTy, nullptr); } return Cost; } } } /// Get intrinsic cost based on argument types. /// If ScalarizationCostPassed is std::numeric_limits::max(), the /// cost of scalarizing the arguments and the return value will be computed /// based on types. 
unsigned getIntrinsicInstrCost( Intrinsic::ID IID, Type *RetTy, ArrayRef Tys, FastMathFlags FMF, unsigned ScalarizationCostPassed = std::numeric_limits::max(), const Instruction *I = nullptr) { auto *ConcreteTTI = static_cast(this); SmallVector ISDs; unsigned SingleCallCost = 10; // Library call cost. Make it expensive. switch (IID) { default: { // Assume that we need to scalarize this intrinsic. unsigned ScalarizationCost = ScalarizationCostPassed; unsigned ScalarCalls = 1; Type *ScalarRetTy = RetTy; if (RetTy->isVectorTy()) { if (ScalarizationCostPassed == std::numeric_limits::max()) ScalarizationCost = getScalarizationOverhead(RetTy, true, false); - ScalarCalls = std::max( - ScalarCalls, (unsigned)cast(RetTy)->getNumElements()); + ScalarCalls = + std::max(ScalarCalls, cast(RetTy)->getNumElements()); ScalarRetTy = RetTy->getScalarType(); } SmallVector ScalarTys; for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) { Type *Ty = Tys[i]; if (Ty->isVectorTy()) { if (ScalarizationCostPassed == std::numeric_limits::max()) ScalarizationCost += getScalarizationOverhead(Ty, false, true); - ScalarCalls = std::max( - ScalarCalls, (unsigned)cast(Ty)->getNumElements()); + ScalarCalls = + std::max(ScalarCalls, cast(Ty)->getNumElements()); Ty = Ty->getScalarType(); } ScalarTys.push_back(Ty); } if (ScalarCalls == 1) return 1; // Return cost of a scalar intrinsic. Assume it to be cheap. unsigned ScalarCost = ConcreteTTI->getIntrinsicInstrCost(IID, ScalarRetTy, ScalarTys, FMF); return ScalarCalls * ScalarCost + ScalarizationCost; } // Look for intrinsics that can be lowered directly or turned into a scalar // intrinsic call. case Intrinsic::sqrt: ISDs.push_back(ISD::FSQRT); break; case Intrinsic::sin: ISDs.push_back(ISD::FSIN); break; case Intrinsic::cos: ISDs.push_back(ISD::FCOS); break; case Intrinsic::exp: ISDs.push_back(ISD::FEXP); break; case Intrinsic::exp2: ISDs.push_back(ISD::FEXP2); break; case Intrinsic::log: ISDs.push_back(ISD::FLOG); break; case Intrinsic::log10: ISDs.push_back(ISD::FLOG10); break; case Intrinsic::log2: ISDs.push_back(ISD::FLOG2); break; case Intrinsic::fabs: ISDs.push_back(ISD::FABS); break; case Intrinsic::canonicalize: ISDs.push_back(ISD::FCANONICALIZE); break; case Intrinsic::minnum: ISDs.push_back(ISD::FMINNUM); if (FMF.noNaNs()) ISDs.push_back(ISD::FMINIMUM); break; case Intrinsic::maxnum: ISDs.push_back(ISD::FMAXNUM); if (FMF.noNaNs()) ISDs.push_back(ISD::FMAXIMUM); break; case Intrinsic::copysign: ISDs.push_back(ISD::FCOPYSIGN); break; case Intrinsic::floor: ISDs.push_back(ISD::FFLOOR); break; case Intrinsic::ceil: ISDs.push_back(ISD::FCEIL); break; case Intrinsic::trunc: ISDs.push_back(ISD::FTRUNC); break; case Intrinsic::nearbyint: ISDs.push_back(ISD::FNEARBYINT); break; case Intrinsic::rint: ISDs.push_back(ISD::FRINT); break; case Intrinsic::round: ISDs.push_back(ISD::FROUND); break; case Intrinsic::pow: ISDs.push_back(ISD::FPOW); break; case Intrinsic::fma: ISDs.push_back(ISD::FMA); break; case Intrinsic::fmuladd: ISDs.push_back(ISD::FMA); break; case Intrinsic::experimental_constrained_fmuladd: ISDs.push_back(ISD::STRICT_FMA); break; // FIXME: We should return 0 whenever getIntrinsicCost == TCC_Free. 
case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: case Intrinsic::sideeffect: return 0; case Intrinsic::masked_store: return ConcreteTTI->getMaskedMemoryOpCost(Instruction::Store, Tys[0], 0, 0); case Intrinsic::masked_load: return ConcreteTTI->getMaskedMemoryOpCost(Instruction::Load, RetTy, 0, 0); case Intrinsic::experimental_vector_reduce_add: return ConcreteTTI->getArithmeticReductionCost(Instruction::Add, Tys[0], /*IsPairwiseForm=*/false); case Intrinsic::experimental_vector_reduce_mul: return ConcreteTTI->getArithmeticReductionCost(Instruction::Mul, Tys[0], /*IsPairwiseForm=*/false); case Intrinsic::experimental_vector_reduce_and: return ConcreteTTI->getArithmeticReductionCost(Instruction::And, Tys[0], /*IsPairwiseForm=*/false); case Intrinsic::experimental_vector_reduce_or: return ConcreteTTI->getArithmeticReductionCost(Instruction::Or, Tys[0], /*IsPairwiseForm=*/false); case Intrinsic::experimental_vector_reduce_xor: return ConcreteTTI->getArithmeticReductionCost(Instruction::Xor, Tys[0], /*IsPairwiseForm=*/false); case Intrinsic::experimental_vector_reduce_v2_fadd: return ConcreteTTI->getArithmeticReductionCost( Instruction::FAdd, Tys[0], /*IsPairwiseForm=*/false); // FIXME: Add new flag for cost of strict // reductions. case Intrinsic::experimental_vector_reduce_v2_fmul: return ConcreteTTI->getArithmeticReductionCost( Instruction::FMul, Tys[0], /*IsPairwiseForm=*/false); // FIXME: Add new flag for cost of strict // reductions. case Intrinsic::experimental_vector_reduce_smax: case Intrinsic::experimental_vector_reduce_smin: case Intrinsic::experimental_vector_reduce_fmax: case Intrinsic::experimental_vector_reduce_fmin: return ConcreteTTI->getMinMaxReductionCost( Tys[0], CmpInst::makeCmpResultType(Tys[0]), /*IsPairwiseForm=*/false, /*IsUnsigned=*/false); case Intrinsic::experimental_vector_reduce_umax: case Intrinsic::experimental_vector_reduce_umin: return ConcreteTTI->getMinMaxReductionCost( Tys[0], CmpInst::makeCmpResultType(Tys[0]), /*IsPairwiseForm=*/false, /*IsUnsigned=*/true); case Intrinsic::sadd_sat: case Intrinsic::ssub_sat: { Type *CondTy = RetTy->getWithNewBitWidth(1); Type *OpTy = StructType::create({RetTy, CondTy}); Intrinsic::ID OverflowOp = IID == Intrinsic::sadd_sat ? Intrinsic::sadd_with_overflow : Intrinsic::ssub_with_overflow; // SatMax -> Overflow && SumDiff < 0 // SatMin -> Overflow && SumDiff >= 0 unsigned Cost = 0; Cost += ConcreteTTI->getIntrinsicInstrCost( OverflowOp, OpTy, {RetTy, RetTy}, FMF, ScalarizationCostPassed); Cost += ConcreteTTI->getCmpSelInstrCost(BinaryOperator::ICmp, RetTy, CondTy, nullptr); Cost += 2 * ConcreteTTI->getCmpSelInstrCost(BinaryOperator::Select, RetTy, CondTy, nullptr); return Cost; } case Intrinsic::uadd_sat: case Intrinsic::usub_sat: { Type *CondTy = RetTy->getWithNewBitWidth(1); Type *OpTy = StructType::create({RetTy, CondTy}); Intrinsic::ID OverflowOp = IID == Intrinsic::uadd_sat ? Intrinsic::uadd_with_overflow : Intrinsic::usub_with_overflow; unsigned Cost = 0; Cost += ConcreteTTI->getIntrinsicInstrCost( OverflowOp, OpTy, {RetTy, RetTy}, FMF, ScalarizationCostPassed); Cost += ConcreteTTI->getCmpSelInstrCost(BinaryOperator::Select, RetTy, CondTy, nullptr); return Cost; } case Intrinsic::smul_fix: case Intrinsic::umul_fix: { unsigned ExtSize = RetTy->getScalarSizeInBits() * 2; Type *ExtTy = RetTy->getWithNewBitWidth(ExtSize); unsigned ExtOp = IID == Intrinsic::smul_fix ? 
Instruction::SExt : Instruction::ZExt; unsigned Cost = 0; Cost += 2 * ConcreteTTI->getCastInstrCost(ExtOp, ExtTy, RetTy); Cost += ConcreteTTI->getArithmeticInstrCost(Instruction::Mul, ExtTy); Cost += 2 * ConcreteTTI->getCastInstrCost(Instruction::Trunc, RetTy, ExtTy); Cost += ConcreteTTI->getArithmeticInstrCost(Instruction::LShr, RetTy, TTI::OK_AnyValue, TTI::OK_UniformConstantValue); Cost += ConcreteTTI->getArithmeticInstrCost(Instruction::Shl, RetTy, TTI::OK_AnyValue, TTI::OK_UniformConstantValue); Cost += ConcreteTTI->getArithmeticInstrCost(Instruction::Or, RetTy); return Cost; } case Intrinsic::sadd_with_overflow: case Intrinsic::ssub_with_overflow: { Type *SumTy = RetTy->getContainedType(0); Type *OverflowTy = RetTy->getContainedType(1); unsigned Opcode = IID == Intrinsic::sadd_with_overflow ? BinaryOperator::Add : BinaryOperator::Sub; // LHSSign -> LHS >= 0 // RHSSign -> RHS >= 0 // SumSign -> Sum >= 0 // // Add: // Overflow -> (LHSSign == RHSSign) && (LHSSign != SumSign) // Sub: // Overflow -> (LHSSign != RHSSign) && (LHSSign != SumSign) unsigned Cost = 0; Cost += ConcreteTTI->getArithmeticInstrCost(Opcode, SumTy); Cost += 3 * ConcreteTTI->getCmpSelInstrCost(BinaryOperator::ICmp, SumTy, OverflowTy, nullptr); Cost += 2 * ConcreteTTI->getCmpSelInstrCost( BinaryOperator::ICmp, OverflowTy, OverflowTy, nullptr); Cost += ConcreteTTI->getArithmeticInstrCost(BinaryOperator::And, OverflowTy); return Cost; } case Intrinsic::uadd_with_overflow: case Intrinsic::usub_with_overflow: { Type *SumTy = RetTy->getContainedType(0); Type *OverflowTy = RetTy->getContainedType(1); unsigned Opcode = IID == Intrinsic::uadd_with_overflow ? BinaryOperator::Add : BinaryOperator::Sub; unsigned Cost = 0; Cost += ConcreteTTI->getArithmeticInstrCost(Opcode, SumTy); Cost += ConcreteTTI->getCmpSelInstrCost(BinaryOperator::ICmp, SumTy, OverflowTy, nullptr); return Cost; } case Intrinsic::smul_with_overflow: case Intrinsic::umul_with_overflow: { Type *MulTy = RetTy->getContainedType(0); Type *OverflowTy = RetTy->getContainedType(1); unsigned ExtSize = MulTy->getScalarSizeInBits() * 2; Type *ExtTy = MulTy->getWithNewBitWidth(ExtSize); unsigned ExtOp = IID == Intrinsic::smul_fix ? Instruction::SExt : Instruction::ZExt; unsigned Cost = 0; Cost += 2 * ConcreteTTI->getCastInstrCost(ExtOp, ExtTy, MulTy); Cost += ConcreteTTI->getArithmeticInstrCost(Instruction::Mul, ExtTy); Cost += 2 * ConcreteTTI->getCastInstrCost(Instruction::Trunc, MulTy, ExtTy); Cost += ConcreteTTI->getArithmeticInstrCost(Instruction::LShr, MulTy, TTI::OK_AnyValue, TTI::OK_UniformConstantValue); if (IID == Intrinsic::smul_with_overflow) Cost += ConcreteTTI->getArithmeticInstrCost( Instruction::AShr, MulTy, TTI::OK_AnyValue, TTI::OK_UniformConstantValue); Cost += ConcreteTTI->getCmpSelInstrCost(BinaryOperator::ICmp, MulTy, OverflowTy, nullptr); return Cost; } case Intrinsic::ctpop: ISDs.push_back(ISD::CTPOP); // In case of legalization use TCC_Expensive. This is cheaper than a // library call but still not a cheap instruction. SingleCallCost = TargetTransformInfo::TCC_Expensive; break; // FIXME: ctlz, cttz, ... 
case Intrinsic::bswap: ISDs.push_back(ISD::BSWAP); break; case Intrinsic::bitreverse: ISDs.push_back(ISD::BITREVERSE); break; } const TargetLoweringBase *TLI = getTLI(); std::pair LT = TLI->getTypeLegalizationCost(DL, RetTy); SmallVector LegalCost; SmallVector CustomCost; for (unsigned ISD : ISDs) { if (TLI->isOperationLegalOrPromote(ISD, LT.second)) { if (IID == Intrinsic::fabs && LT.second.isFloatingPoint() && TLI->isFAbsFree(LT.second)) { return 0; } // The operation is legal. Assume it costs 1. // If the type is split to multiple registers, assume that there is some // overhead to this. // TODO: Once we have extract/insert subvector cost we need to use them. if (LT.first > 1) LegalCost.push_back(LT.first * 2); else LegalCost.push_back(LT.first * 1); } else if (!TLI->isOperationExpand(ISD, LT.second)) { // If the operation is custom lowered then assume // that the code is twice as expensive. CustomCost.push_back(LT.first * 2); } } auto MinLegalCostI = std::min_element(LegalCost.begin(), LegalCost.end()); if (MinLegalCostI != LegalCost.end()) return *MinLegalCostI; auto MinCustomCostI = std::min_element(CustomCost.begin(), CustomCost.end()); if (MinCustomCostI != CustomCost.end()) return *MinCustomCostI; // If we can't lower fmuladd into an FMA estimate the cost as a floating // point mul followed by an add. if (IID == Intrinsic::fmuladd) return ConcreteTTI->getArithmeticInstrCost(BinaryOperator::FMul, RetTy) + ConcreteTTI->getArithmeticInstrCost(BinaryOperator::FAdd, RetTy); if (IID == Intrinsic::experimental_constrained_fmuladd) return ConcreteTTI->getIntrinsicCost( Intrinsic::experimental_constrained_fmul, RetTy, Tys, nullptr) + ConcreteTTI->getIntrinsicCost( Intrinsic::experimental_constrained_fadd, RetTy, Tys, nullptr); // Else, assume that we need to scalarize this intrinsic. For math builtins // this will emit a costly libcall, adding call overhead and spills. Make it // very expensive. if (RetTy->isVectorTy()) { unsigned ScalarizationCost = ((ScalarizationCostPassed != std::numeric_limits::max()) ? ScalarizationCostPassed : getScalarizationOverhead(RetTy, true, false)); unsigned ScalarCalls = cast(RetTy)->getNumElements(); SmallVector ScalarTys; for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) { Type *Ty = Tys[i]; if (Ty->isVectorTy()) Ty = Ty->getScalarType(); ScalarTys.push_back(Ty); } unsigned ScalarCost = ConcreteTTI->getIntrinsicInstrCost( IID, RetTy->getScalarType(), ScalarTys, FMF); for (unsigned i = 0, ie = Tys.size(); i != ie; ++i) { if (Tys[i]->isVectorTy()) { if (ScalarizationCostPassed == std::numeric_limits::max()) ScalarizationCost += getScalarizationOverhead(Tys[i], false, true); ScalarCalls = - std::max(ScalarCalls, - (unsigned)cast(Tys[i])->getNumElements()); + std::max(ScalarCalls, cast(Tys[i])->getNumElements()); } } return ScalarCalls * ScalarCost + ScalarizationCost; } // This is going to be turned into a library call, make it expensive. return SingleCallCost; } /// Compute a cost of the given call instruction. /// /// Compute the cost of calling function F with return type RetTy and /// argument types Tys. F might be nullptr, in this case the cost of an /// arbitrary call with the specified signature will be returned. /// This is used, for instance, when we estimate call of a vector /// counterpart of the given function. /// \param F Called function, might be nullptr. /// \param RetTy Return value types. /// \param Tys Argument types. /// \returns The cost of Call instruction. 
unsigned getCallInstrCost(Function *F, Type *RetTy, ArrayRef Tys) { return 10; } unsigned getNumberOfParts(Type *Tp) { std::pair LT = getTLI()->getTypeLegalizationCost(DL, Tp); return LT.first; } unsigned getAddressComputationCost(Type *Ty, ScalarEvolution *, const SCEV *) { return 0; } /// Try to calculate arithmetic and shuffle op costs for reduction operations. /// We're assuming that reduction operation are performing the following way: /// 1. Non-pairwise reduction /// %val1 = shufflevector %val, %undef, /// /// \----------------v-------------/ \----------v------------/ /// n/2 elements n/2 elements /// %red1 = op %val, val1 /// After this operation we have a vector %red1 where only the first n/2 /// elements are meaningful, the second n/2 elements are undefined and can be /// dropped. All other operations are actually working with the vector of /// length n/2, not n, though the real vector length is still n. /// %val2 = shufflevector %red1, %undef, /// /// \----------------v-------------/ \----------v------------/ /// n/4 elements 3*n/4 elements /// %red2 = op %red1, val2 - working with the vector of /// length n/2, the resulting vector has length n/4 etc. /// 2. Pairwise reduction: /// Everything is the same except for an additional shuffle operation which /// is used to produce operands for pairwise kind of reductions. /// %val1 = shufflevector %val, %undef, /// /// \-------------v----------/ \----------v------------/ /// n/2 elements n/2 elements /// %val2 = shufflevector %val, %undef, /// /// \-------------v----------/ \----------v------------/ /// n/2 elements n/2 elements /// %red1 = op %val1, val2 /// Again, the operation is performed on vector, but the resulting /// vector %red1 is vector. /// /// The cost model should take into account that the actual length of the /// vector is reduced on each iteration. unsigned getArithmeticReductionCost(unsigned Opcode, Type *Ty, bool IsPairwise) { assert(Ty->isVectorTy() && "Expect a vector type"); Type *ScalarTy = cast(Ty)->getElementType(); unsigned NumVecElts = cast(Ty)->getNumElements(); unsigned NumReduxLevels = Log2_32(NumVecElts); unsigned ArithCost = 0; unsigned ShuffleCost = 0; auto *ConcreteTTI = static_cast(this); std::pair LT = ConcreteTTI->getTLI()->getTypeLegalizationCost(DL, Ty); unsigned LongVectorCount = 0; unsigned MVTLen = LT.second.isVector() ? LT.second.getVectorNumElements() : 1; while (NumVecElts > MVTLen) { NumVecElts /= 2; Type *SubTy = VectorType::get(ScalarTy, NumVecElts); // Assume the pairwise shuffles add a cost. ShuffleCost += (IsPairwise + 1) * ConcreteTTI->getShuffleCost(TTI::SK_ExtractSubvector, Ty, NumVecElts, SubTy); ArithCost += ConcreteTTI->getArithmeticInstrCost(Opcode, SubTy); Ty = SubTy; ++LongVectorCount; } NumReduxLevels -= LongVectorCount; // The minimal length of the vector is limited by the real length of vector // operations performed on the current platform. That's why several final // reduction operations are performed on the vectors with the same // architecture-dependent length. // Non pairwise reductions need one shuffle per reduction level. Pairwise // reductions need two shuffles on every level, but the last one. On that // level one of the shuffles is <0, u, u, ...> which is identity. 
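    // Worked example (unit shuffle/arithmetic/extract costs, to illustrate the
    // counting only): a non-pairwise add reduction of <8 x i32> on a target
    // whose widest legal vector is v4i32 takes one halving step (1 shuffle +
    // 1 add on <4 x i32>), then 2 remaining levels (2 shuffles + 2 adds), and
    // a final extractelement, i.e. 3 + 3 + 1 = 7.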
unsigned NumShuffles = NumReduxLevels; if (IsPairwise && NumReduxLevels >= 1) NumShuffles += NumReduxLevels - 1; ShuffleCost += NumShuffles * ConcreteTTI->getShuffleCost(TTI::SK_PermuteSingleSrc, Ty, 0, Ty); ArithCost += NumReduxLevels * ConcreteTTI->getArithmeticInstrCost(Opcode, Ty); return ShuffleCost + ArithCost + ConcreteTTI->getVectorInstrCost(Instruction::ExtractElement, Ty, 0); } /// Try to calculate op costs for min/max reduction operations. /// \param CondTy Conditional type for the Select instruction. unsigned getMinMaxReductionCost(Type *Ty, Type *CondTy, bool IsPairwise, bool) { assert(Ty->isVectorTy() && "Expect a vector type"); Type *ScalarTy = cast(Ty)->getElementType(); Type *ScalarCondTy = cast(CondTy)->getElementType(); unsigned NumVecElts = cast(Ty)->getNumElements(); unsigned NumReduxLevels = Log2_32(NumVecElts); unsigned CmpOpcode; if (Ty->isFPOrFPVectorTy()) { CmpOpcode = Instruction::FCmp; } else { assert(Ty->isIntOrIntVectorTy() && "expecting floating point or integer type for min/max reduction"); CmpOpcode = Instruction::ICmp; } unsigned MinMaxCost = 0; unsigned ShuffleCost = 0; auto *ConcreteTTI = static_cast(this); std::pair LT = ConcreteTTI->getTLI()->getTypeLegalizationCost(DL, Ty); unsigned LongVectorCount = 0; unsigned MVTLen = LT.second.isVector() ? LT.second.getVectorNumElements() : 1; while (NumVecElts > MVTLen) { NumVecElts /= 2; Type *SubTy = VectorType::get(ScalarTy, NumVecElts); CondTy = VectorType::get(ScalarCondTy, NumVecElts); // Assume the pairwise shuffles add a cost. ShuffleCost += (IsPairwise + 1) * ConcreteTTI->getShuffleCost(TTI::SK_ExtractSubvector, Ty, NumVecElts, SubTy); MinMaxCost += ConcreteTTI->getCmpSelInstrCost(CmpOpcode, SubTy, CondTy, nullptr) + ConcreteTTI->getCmpSelInstrCost(Instruction::Select, SubTy, CondTy, nullptr); Ty = SubTy; ++LongVectorCount; } NumReduxLevels -= LongVectorCount; // The minimal length of the vector is limited by the real length of vector // operations performed on the current platform. That's why several final // reduction opertions are perfomed on the vectors with the same // architecture-dependent length. // Non pairwise reductions need one shuffle per reduction level. Pairwise // reductions need two shuffles on every level, but the last one. On that // level one of the shuffles is <0, u, u, ...> which is identity. unsigned NumShuffles = NumReduxLevels; if (IsPairwise && NumReduxLevels >= 1) NumShuffles += NumReduxLevels - 1; ShuffleCost += NumShuffles * ConcreteTTI->getShuffleCost(TTI::SK_PermuteSingleSrc, Ty, 0, Ty); MinMaxCost += NumReduxLevels * (ConcreteTTI->getCmpSelInstrCost(CmpOpcode, Ty, CondTy, nullptr) + ConcreteTTI->getCmpSelInstrCost(Instruction::Select, Ty, CondTy, nullptr)); // The last min/max should be in vector registers and we counted it above. // So just need a single extractelement. return ShuffleCost + MinMaxCost + ConcreteTTI->getVectorInstrCost(Instruction::ExtractElement, Ty, 0); } unsigned getVectorSplitCost() { return 1; } /// @} }; /// Concrete BasicTTIImpl that can be used if no further customization /// is needed. 
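// Illustrative sketch only (MyTargetTTIImpl is a hypothetical out-of-tree
// target, not part of this header): a target-specific TTI hooks into
// BasicTTIImplBase via CRTP and exposes getST()/getTLI(), exactly as the
// concrete BasicTTIImpl below does.
class MyTargetTTIImpl : public BasicTTIImplBase<MyTargetTTIImpl> {
  using BaseT = BasicTTIImplBase<MyTargetTTIImpl>;
  friend BaseT;

  const TargetSubtargetInfo *ST;
  const TargetLoweringBase *TLI;

  // Required by BasicTTIImplBase; resolved statically through the CRTP cast.
  const TargetSubtargetInfo *getST() const { return ST; }
  const TargetLoweringBase *getTLI() const { return TLI; }

public:
  MyTargetTTIImpl(const TargetMachine *TM, const Function &F,
                  const TargetSubtargetInfo *ST, const TargetLoweringBase *TLI)
      : BaseT(TM, F.getParent()->getDataLayout()), ST(ST), TLI(TLI) {}
};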
class BasicTTIImpl : public BasicTTIImplBase { using BaseT = BasicTTIImplBase; friend class BasicTTIImplBase; const TargetSubtargetInfo *ST; const TargetLoweringBase *TLI; const TargetSubtargetInfo *getST() const { return ST; } const TargetLoweringBase *getTLI() const { return TLI; } public: explicit BasicTTIImpl(const TargetMachine *TM, const Function &F); }; } // end namespace llvm #endif // LLVM_CODEGEN_BASICTTIIMPL_H diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h index 1ce5cedf8522..92017448fe0d 100644 --- a/llvm/include/llvm/IR/DerivedTypes.h +++ b/llvm/include/llvm/IR/DerivedTypes.h @@ -1,611 +1,611 @@ //===- llvm/DerivedTypes.h - Classes for handling data types ----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file contains the declarations of classes that represent "derived // types". These are things like "arrays of x" or "structure of x, y, z" or // "function returning x taking (y,z) as parameters", etc... // // The implementations of these classes live in the Type.cpp file. // //===----------------------------------------------------------------------===// #ifndef LLVM_IR_DERIVEDTYPES_H #define LLVM_IR_DERIVEDTYPES_H #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/Type.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/TypeSize.h" #include #include namespace llvm { class Value; class APInt; class LLVMContext; /// Class to represent integer types. Note that this class is also used to /// represent the built-in integer types: Int1Ty, Int8Ty, Int16Ty, Int32Ty and /// Int64Ty. /// Integer representation type class IntegerType : public Type { friend class LLVMContextImpl; protected: explicit IntegerType(LLVMContext &C, unsigned NumBits) : Type(C, IntegerTyID){ setSubclassData(NumBits); } public: /// This enum is just used to hold constants we need for IntegerType. enum { MIN_INT_BITS = 1, ///< Minimum number of bits that can be specified MAX_INT_BITS = (1<<24)-1 ///< Maximum number of bits that can be specified ///< Note that bit width is stored in the Type classes SubclassData field ///< which has 24 bits. This yields a maximum bit width of 16,777,215 ///< bits. }; /// This static method is the primary way of constructing an IntegerType. /// If an IntegerType with the same NumBits value was previously instantiated, /// that instance will be returned. Otherwise a new one will be created. Only /// one instance with a given NumBits value is ever created. /// Get or create an IntegerType instance. static IntegerType *get(LLVMContext &C, unsigned NumBits); /// Returns type twice as wide the input type. IntegerType *getExtendedType() const { return Type::getIntNTy(getContext(), 2 * getScalarSizeInBits()); } /// Get the number of bits in this IntegerType unsigned getBitWidth() const { return getSubclassData(); } /// Return a bitmask with ones set for all of the bits that can be set by an /// unsigned version of this type. This is 0xFF for i8, 0xFFFF for i16, etc. uint64_t getBitMask() const { return ~uint64_t(0UL) >> (64-getBitWidth()); } /// Return a uint64_t with just the most significant bit set (the sign bit, if /// the value is treated as a signed number). 
uint64_t getSignBit() const { return 1ULL << (getBitWidth()-1); } /// For example, this is 0xFF for an 8 bit integer, 0xFFFF for i16, etc. /// @returns a bit mask with ones set for all the bits of this type. /// Get a bit mask for this type. APInt getMask() const; /// This method determines if the width of this IntegerType is a power-of-2 /// in terms of 8 bit bytes. /// @returns true if this is a power-of-2 byte width. /// Is this a power-of-2 byte-width IntegerType ? bool isPowerOf2ByteWidth() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Type *T) { return T->getTypeID() == IntegerTyID; } }; unsigned Type::getIntegerBitWidth() const { return cast(this)->getBitWidth(); } /// Class to represent function types /// class FunctionType : public Type { FunctionType(Type *Result, ArrayRef Params, bool IsVarArgs); public: FunctionType(const FunctionType &) = delete; FunctionType &operator=(const FunctionType &) = delete; /// This static method is the primary way of constructing a FunctionType. static FunctionType *get(Type *Result, ArrayRef Params, bool isVarArg); /// Create a FunctionType taking no parameters. static FunctionType *get(Type *Result, bool isVarArg); /// Return true if the specified type is valid as a return type. static bool isValidReturnType(Type *RetTy); /// Return true if the specified type is valid as an argument type. static bool isValidArgumentType(Type *ArgTy); bool isVarArg() const { return getSubclassData()!=0; } Type *getReturnType() const { return ContainedTys[0]; } using param_iterator = Type::subtype_iterator; param_iterator param_begin() const { return ContainedTys + 1; } param_iterator param_end() const { return &ContainedTys[NumContainedTys]; } ArrayRef params() const { return makeArrayRef(param_begin(), param_end()); } /// Parameter type accessors. Type *getParamType(unsigned i) const { return ContainedTys[i+1]; } /// Return the number of fixed parameters this function type requires. /// This does not consider varargs. unsigned getNumParams() const { return NumContainedTys - 1; } /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Type *T) { return T->getTypeID() == FunctionTyID; } }; static_assert(alignof(FunctionType) >= alignof(Type *), "Alignment sufficient for objects appended to FunctionType"); bool Type::isFunctionVarArg() const { return cast(this)->isVarArg(); } Type *Type::getFunctionParamType(unsigned i) const { return cast(this)->getParamType(i); } unsigned Type::getFunctionNumParams() const { return cast(this)->getNumParams(); } /// A handy container for a FunctionType+Callee-pointer pair, which can be /// passed around as a single entity. This assists in replacing the use of /// PointerType::getElementType() to access the function's type, since that's /// slated for removal as part of the [opaque pointer types] project. class FunctionCallee { public: // Allow implicit conversion from types which have a getFunctionType member // (e.g. Function and InlineAsm). template FunctionCallee(T *Fn) : FnTy(Fn ? 
Fn->getFunctionType() : nullptr), Callee(Fn) {} FunctionCallee(FunctionType *FnTy, Value *Callee) : FnTy(FnTy), Callee(Callee) { assert((FnTy == nullptr) == (Callee == nullptr)); } FunctionCallee(std::nullptr_t) {} FunctionCallee() = default; FunctionType *getFunctionType() { return FnTy; } Value *getCallee() { return Callee; } explicit operator bool() { return Callee; } private: FunctionType *FnTy = nullptr; Value *Callee = nullptr; }; /// Class to represent struct types. There are two different kinds of struct /// types: Literal structs and Identified structs. /// /// Literal struct types (e.g. { i32, i32 }) are uniqued structurally, and must /// always have a body when created. You can get one of these by using one of /// the StructType::get() forms. /// /// Identified structs (e.g. %foo or %42) may optionally have a name and are not /// uniqued. The names for identified structs are managed at the LLVMContext /// level, so there can only be a single identified struct with a given name in /// a particular LLVMContext. Identified structs may also optionally be opaque /// (have no body specified). You get one of these by using one of the /// StructType::create() forms. /// /// Independent of what kind of struct you have, the body of a struct type are /// laid out in memory consecutively with the elements directly one after the /// other (if the struct is packed) or (if not packed) with padding between the /// elements as defined by DataLayout (which is required to match what the code /// generator for a target expects). /// class StructType : public Type { StructType(LLVMContext &C) : Type(C, StructTyID) {} enum { /// This is the contents of the SubClassData field. SCDB_HasBody = 1, SCDB_Packed = 2, SCDB_IsLiteral = 4, SCDB_IsSized = 8 }; /// For a named struct that actually has a name, this is a pointer to the /// symbol table entry (maintained by LLVMContext) for the struct. /// This is null if the type is an literal struct or if it is a identified /// type that has an empty name. void *SymbolTableEntry = nullptr; public: StructType(const StructType &) = delete; StructType &operator=(const StructType &) = delete; /// This creates an identified struct. static StructType *create(LLVMContext &Context, StringRef Name); static StructType *create(LLVMContext &Context); static StructType *create(ArrayRef Elements, StringRef Name, bool isPacked = false); static StructType *create(ArrayRef Elements); static StructType *create(LLVMContext &Context, ArrayRef Elements, StringRef Name, bool isPacked = false); static StructType *create(LLVMContext &Context, ArrayRef Elements); template static std::enable_if_t::value, StructType *> create(StringRef Name, Type *elt1, Tys *... elts) { assert(elt1 && "Cannot create a struct type with no elements with this"); SmallVector StructFields({elt1, elts...}); return create(StructFields, Name); } /// This static method is the primary way to create a literal StructType. static StructType *get(LLVMContext &Context, ArrayRef Elements, bool isPacked = false); /// Create an empty structure type. static StructType *get(LLVMContext &Context, bool isPacked = false); /// This static method is a convenience method for creating structure types by /// specifying the elements as arguments. Note that this method always returns /// a non-packed struct, and requires at least one element type. template static std::enable_if_t::value, StructType *> get(Type *elt1, Tys *... 
elts) { assert(elt1 && "Cannot create a struct type with no elements with this"); LLVMContext &Ctx = elt1->getContext(); SmallVector StructFields({elt1, elts...}); return llvm::StructType::get(Ctx, StructFields); } bool isPacked() const { return (getSubclassData() & SCDB_Packed) != 0; } /// Return true if this type is uniqued by structural equivalence, false if it /// is a struct definition. bool isLiteral() const { return (getSubclassData() & SCDB_IsLiteral) != 0; } /// Return true if this is a type with an identity that has no body specified /// yet. These prints as 'opaque' in .ll files. bool isOpaque() const { return (getSubclassData() & SCDB_HasBody) == 0; } /// isSized - Return true if this is a sized type. bool isSized(SmallPtrSetImpl *Visited = nullptr) const; /// Return true if this is a named struct that has a non-empty name. bool hasName() const { return SymbolTableEntry != nullptr; } /// Return the name for this struct type if it has an identity. /// This may return an empty string for an unnamed struct type. Do not call /// this on an literal type. StringRef getName() const; /// Change the name of this type to the specified name, or to a name with a /// suffix if there is a collision. Do not call this on an literal type. void setName(StringRef Name); /// Specify a body for an opaque identified type. void setBody(ArrayRef Elements, bool isPacked = false); template std::enable_if_t::value, void> setBody(Type *elt1, Tys *... elts) { assert(elt1 && "Cannot create a struct type with no elements with this"); SmallVector StructFields({elt1, elts...}); setBody(StructFields); } /// Return true if the specified type is valid as a element type. static bool isValidElementType(Type *ElemTy); // Iterator access to the elements. using element_iterator = Type::subtype_iterator; element_iterator element_begin() const { return ContainedTys; } element_iterator element_end() const { return &ContainedTys[NumContainedTys];} ArrayRef const elements() const { return makeArrayRef(element_begin(), element_end()); } /// Return true if this is layout identical to the specified struct. bool isLayoutIdentical(StructType *Other) const; /// Random access to the elements unsigned getNumElements() const { return NumContainedTys; } Type *getElementType(unsigned N) const { assert(N < NumContainedTys && "Element number out of range!"); return ContainedTys[N]; } /// Given an index value into the type, return the type of the element. Type *getTypeAtIndex(const Value *V) const; Type *getTypeAtIndex(unsigned N) const { return getElementType(N); } bool indexValid(const Value *V) const; bool indexValid(unsigned Idx) const { return Idx < getNumElements(); } /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Type *T) { return T->getTypeID() == StructTyID; } }; StringRef Type::getStructName() const { return cast(this)->getName(); } unsigned Type::getStructNumElements() const { return cast(this)->getNumElements(); } Type *Type::getStructElementType(unsigned N) const { return cast(this)->getElementType(N); } /// Class to represent array types. class ArrayType : public Type { /// The element type of the array. Type *ContainedType; /// Number of elements in the array. 
uint64_t NumElements; ArrayType(Type *ElType, uint64_t NumEl); public: ArrayType(const ArrayType &) = delete; ArrayType &operator=(const ArrayType &) = delete; uint64_t getNumElements() const { return NumElements; } Type *getElementType() const { return ContainedType; } /// This static method is the primary way to construct an ArrayType static ArrayType *get(Type *ElementType, uint64_t NumElements); /// Return true if the specified type is valid as a element type. static bool isValidElementType(Type *ElemTy); /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Type *T) { return T->getTypeID() == ArrayTyID; } }; uint64_t Type::getArrayNumElements() const { return cast(this)->getNumElements(); } /// Class to represent vector types. class VectorType : public Type { /// A fully specified VectorType is of the form . 'n' is the /// minimum number of elements of type Ty contained within the vector, and /// 'vscale x' indicates that the total element count is an integer multiple /// of 'n', where the multiple is either guaranteed to be one, or is /// statically unknown at compile time. /// /// If the multiple is known to be 1, then the extra term is discarded in /// textual IR: /// /// <4 x i32> - a vector containing 4 i32s /// - a vector containing an unknown integer multiple /// of 4 i32s /// The element type of the vector. Type *ContainedType; /// Minumum number of elements in the vector. uint64_t NumElements; VectorType(Type *ElType, unsigned NumEl, bool Scalable = false); VectorType(Type *ElType, ElementCount EC); // If true, the total number of elements is an unknown multiple of the // minimum 'NumElements'. Otherwise the total number of elements is exactly // equal to 'NumElements'. bool Scalable; public: VectorType(const VectorType &) = delete; VectorType &operator=(const VectorType &) = delete; /// For scalable vectors, this will return the minimum number of elements /// in the vector. - uint64_t getNumElements() const { return NumElements; } + unsigned getNumElements() const { return NumElements; } Type *getElementType() const { return ContainedType; } /// This static method is the primary way to construct an VectorType. static VectorType *get(Type *ElementType, ElementCount EC); static VectorType *get(Type *ElementType, unsigned NumElements, bool Scalable = false) { return VectorType::get(ElementType, {NumElements, Scalable}); } /// This static method gets a VectorType with the same number of elements as /// the input type, and the element type is an integer type of the same width /// as the input element type. static VectorType *getInteger(VectorType *VTy) { unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits(); assert(EltBits && "Element size must be of a non-zero size"); Type *EltTy = IntegerType::get(VTy->getContext(), EltBits); return VectorType::get(EltTy, VTy->getElementCount()); } /// This static method is like getInteger except that the element types are /// twice as wide as the elements in the input type. static VectorType *getExtendedElementVectorType(VectorType *VTy) { assert(VTy->isIntOrIntVectorTy() && "VTy expected to be a vector of ints."); auto *EltTy = cast(VTy->getElementType()); return VectorType::get(EltTy->getExtendedType(), VTy->getElementCount()); } // This static method gets a VectorType with the same number of elements as // the input type, and the element type is an integer or float type which // is half as wide as the elements in the input type. 
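Taken together, the element-width and element-count helpers here relate a vector type to its widened, narrowed, halved, and doubled variants. For example (a sketch, assuming an LLVMContext Ctx is in scope):

// Illustrative only: derived vector types for <8 x i16>.
VectorType *V8I16  = VectorType::get(Type::getInt16Ty(Ctx), 8);          // <8 x i16>
VectorType *Wide   = VectorType::getExtendedElementVectorType(V8I16);    // <8 x i32>
VectorType *Narrow = VectorType::getTruncatedElementVectorType(V8I16);   // <8 x i8>
VectorType *Half   = VectorType::getHalfElementsVectorType(V8I16);       // <4 x i16>
VectorType *Twice  = VectorType::getDoubleElementsVectorType(V8I16);     // <16 x i16>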
static VectorType *getTruncatedElementVectorType(VectorType *VTy) { Type *EltTy; if (VTy->getElementType()->isFloatingPointTy()) { switch(VTy->getElementType()->getTypeID()) { case DoubleTyID: EltTy = Type::getFloatTy(VTy->getContext()); break; case FloatTyID: EltTy = Type::getHalfTy(VTy->getContext()); break; default: llvm_unreachable("Cannot create narrower fp vector element type"); } } else { unsigned EltBits = VTy->getElementType()->getPrimitiveSizeInBits(); assert((EltBits & 1) == 0 && "Cannot truncate vector element with odd bit-width"); EltTy = IntegerType::get(VTy->getContext(), EltBits / 2); } return VectorType::get(EltTy, VTy->getElementCount()); } // This static method returns a VectorType with a smaller number of elements // of a larger type than the input element type. For example, a <16 x i8> // subdivided twice would return <4 x i32> static VectorType *getSubdividedVectorType(VectorType *VTy, int NumSubdivs) { for (int i = 0; i < NumSubdivs; ++i) { VTy = VectorType::getDoubleElementsVectorType(VTy); VTy = VectorType::getTruncatedElementVectorType(VTy); } return VTy; } /// This static method returns a VectorType with half as many elements as the /// input type and the same element type. static VectorType *getHalfElementsVectorType(VectorType *VTy) { auto EltCnt = VTy->getElementCount(); assert ((EltCnt.Min & 1) == 0 && "Cannot halve vector with odd number of elements."); return VectorType::get(VTy->getElementType(), EltCnt/2); } /// This static method returns a VectorType with twice as many elements as the /// input type and the same element type. static VectorType *getDoubleElementsVectorType(VectorType *VTy) { auto EltCnt = VTy->getElementCount(); assert((VTy->getNumElements() * 2ull) <= UINT_MAX && "Too many elements in vector"); return VectorType::get(VTy->getElementType(), EltCnt*2); } /// Return true if the specified type is valid as a element type. static bool isValidElementType(Type *ElemTy); /// Return an ElementCount instance to represent the (possibly scalable) /// number of elements in the vector. ElementCount getElementCount() const { uint64_t MinimumEltCnt = getNumElements(); assert(MinimumEltCnt <= UINT_MAX && "Too many elements in vector"); return { (unsigned)MinimumEltCnt, Scalable }; } /// Returns whether or not this is a scalable vector (meaning the total /// element count is a multiple of the minimum). bool isScalable() const { return Scalable; } /// Return the minimum number of bits in the Vector type. /// Returns zero when the vector is a vector of pointers. unsigned getBitWidth() const { return getNumElements() * getElementType()->getPrimitiveSizeInBits(); } /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Type *T) { return T->getTypeID() == VectorTyID; } }; unsigned Type::getVectorNumElements() const { return cast(this)->getNumElements(); } bool Type::getVectorIsScalable() const { return cast(this)->isScalable(); } ElementCount Type::getVectorElementCount() const { return cast(this)->getElementCount(); } bool Type::isVectorTy() const { return isa(this); } /// Class to represent pointers. class PointerType : public Type { explicit PointerType(Type *ElType, unsigned AddrSpace); Type *PointeeTy; public: PointerType(const PointerType &) = delete; PointerType &operator=(const PointerType &) = delete; /// This constructs a pointer to an object of the specified type in a numbered /// address space. 
static PointerType *get(Type *ElementType, unsigned AddressSpace); /// This constructs a pointer to an object of the specified type in the /// generic address space (address space zero). static PointerType *getUnqual(Type *ElementType) { return PointerType::get(ElementType, 0); } Type *getElementType() const { return PointeeTy; } /// Return true if the specified type is valid as a element type. static bool isValidElementType(Type *ElemTy); /// Return true if we can load or store from a pointer to this type. static bool isLoadableOrStorableType(Type *ElemTy); /// Return the address space of the Pointer type. inline unsigned getAddressSpace() const { return getSubclassData(); } /// Implement support type inquiry through isa, cast, and dyn_cast. static bool classof(const Type *T) { return T->getTypeID() == PointerTyID; } }; Type *Type::getExtendedType() const { assert( isIntOrIntVectorTy() && "Original type expected to be a vector of integers or a scalar integer."); if (auto *VTy = dyn_cast(this)) return VectorType::getExtendedElementVectorType( const_cast(VTy)); return cast(this)->getExtendedType(); } Type *Type::getWithNewBitWidth(unsigned NewBitWidth) const { assert( isIntOrIntVectorTy() && "Original type expected to be a vector of integers or a scalar integer."); Type *NewType = getIntNTy(getContext(), NewBitWidth); if (isVectorTy()) NewType = VectorType::get(NewType, getVectorElementCount()); return NewType; } unsigned Type::getPointerAddressSpace() const { return cast(getScalarType())->getAddressSpace(); } } // end namespace llvm #endif // LLVM_IR_DERIVEDTYPES_H diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp index 7791d26e0379..ec390b36b3a9 100644 --- a/llvm/lib/IR/Function.cpp +++ b/llvm/lib/IR/Function.cpp @@ -1,1651 +1,1650 @@ //===- Function.cpp - Implement the Global object classes -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements the Function class for the IR library. 
// //===----------------------------------------------------------------------===// #include "llvm/IR/Function.h" #include "SymbolTableListTraitsImpl.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/None.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsAArch64.h" #include "llvm/IR/IntrinsicsAMDGPU.h" #include "llvm/IR/IntrinsicsARM.h" #include "llvm/IR/IntrinsicsBPF.h" #include "llvm/IR/IntrinsicsHexagon.h" #include "llvm/IR/IntrinsicsMips.h" #include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/IntrinsicsPowerPC.h" #include "llvm/IR/IntrinsicsR600.h" #include "llvm/IR/IntrinsicsRISCV.h" #include "llvm/IR/IntrinsicsS390.h" #include "llvm/IR/IntrinsicsWebAssembly.h" #include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/IntrinsicsXCore.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/SymbolTableListTraits.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueSymbolTable.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/ErrorHandling.h" #include #include #include #include #include #include using namespace llvm; using ProfileCount = Function::ProfileCount; // Explicit instantiations of SymbolTableListTraits since some of the methods // are not in the public header file... 
template class llvm::SymbolTableListTraits; //===----------------------------------------------------------------------===// // Argument Implementation //===----------------------------------------------------------------------===// Argument::Argument(Type *Ty, const Twine &Name, Function *Par, unsigned ArgNo) : Value(Ty, Value::ArgumentVal), Parent(Par), ArgNo(ArgNo) { setName(Name); } void Argument::setParent(Function *parent) { Parent = parent; } bool Argument::hasNonNullAttr() const { if (!getType()->isPointerTy()) return false; if (getParent()->hasParamAttribute(getArgNo(), Attribute::NonNull)) return true; else if (getDereferenceableBytes() > 0 && !NullPointerIsDefined(getParent(), getType()->getPointerAddressSpace())) return true; return false; } bool Argument::hasByValAttr() const { if (!getType()->isPointerTy()) return false; return hasAttribute(Attribute::ByVal); } bool Argument::hasSwiftSelfAttr() const { return getParent()->hasParamAttribute(getArgNo(), Attribute::SwiftSelf); } bool Argument::hasSwiftErrorAttr() const { return getParent()->hasParamAttribute(getArgNo(), Attribute::SwiftError); } bool Argument::hasInAllocaAttr() const { if (!getType()->isPointerTy()) return false; return hasAttribute(Attribute::InAlloca); } bool Argument::hasByValOrInAllocaAttr() const { if (!getType()->isPointerTy()) return false; AttributeList Attrs = getParent()->getAttributes(); return Attrs.hasParamAttribute(getArgNo(), Attribute::ByVal) || Attrs.hasParamAttribute(getArgNo(), Attribute::InAlloca); } unsigned Argument::getParamAlignment() const { assert(getType()->isPointerTy() && "Only pointers have alignments"); return getParent()->getParamAlignment(getArgNo()); } MaybeAlign Argument::getParamAlign() const { assert(getType()->isPointerTy() && "Only pointers have alignments"); return getParent()->getParamAlign(getArgNo()); } Type *Argument::getParamByValType() const { assert(getType()->isPointerTy() && "Only pointers have byval types"); return getParent()->getParamByValType(getArgNo()); } uint64_t Argument::getDereferenceableBytes() const { assert(getType()->isPointerTy() && "Only pointers have dereferenceable bytes"); return getParent()->getParamDereferenceableBytes(getArgNo()); } uint64_t Argument::getDereferenceableOrNullBytes() const { assert(getType()->isPointerTy() && "Only pointers have dereferenceable bytes"); return getParent()->getParamDereferenceableOrNullBytes(getArgNo()); } bool Argument::hasNestAttr() const { if (!getType()->isPointerTy()) return false; return hasAttribute(Attribute::Nest); } bool Argument::hasNoAliasAttr() const { if (!getType()->isPointerTy()) return false; return hasAttribute(Attribute::NoAlias); } bool Argument::hasNoCaptureAttr() const { if (!getType()->isPointerTy()) return false; return hasAttribute(Attribute::NoCapture); } bool Argument::hasStructRetAttr() const { if (!getType()->isPointerTy()) return false; return hasAttribute(Attribute::StructRet); } bool Argument::hasInRegAttr() const { return hasAttribute(Attribute::InReg); } bool Argument::hasReturnedAttr() const { return hasAttribute(Attribute::Returned); } bool Argument::hasZExtAttr() const { return hasAttribute(Attribute::ZExt); } bool Argument::hasSExtAttr() const { return hasAttribute(Attribute::SExt); } bool Argument::onlyReadsMemory() const { AttributeList Attrs = getParent()->getAttributes(); return Attrs.hasParamAttribute(getArgNo(), Attribute::ReadOnly) || Attrs.hasParamAttribute(getArgNo(), Attribute::ReadNone); } void Argument::addAttrs(AttrBuilder &B) { AttributeList AL = 
getParent()->getAttributes(); AL = AL.addParamAttributes(Parent->getContext(), getArgNo(), B); getParent()->setAttributes(AL); } void Argument::addAttr(Attribute::AttrKind Kind) { getParent()->addParamAttr(getArgNo(), Kind); } void Argument::addAttr(Attribute Attr) { getParent()->addParamAttr(getArgNo(), Attr); } void Argument::removeAttr(Attribute::AttrKind Kind) { getParent()->removeParamAttr(getArgNo(), Kind); } bool Argument::hasAttribute(Attribute::AttrKind Kind) const { return getParent()->hasParamAttribute(getArgNo(), Kind); } Attribute Argument::getAttribute(Attribute::AttrKind Kind) const { return getParent()->getParamAttribute(getArgNo(), Kind); } //===----------------------------------------------------------------------===// // Helper Methods in Function //===----------------------------------------------------------------------===// LLVMContext &Function::getContext() const { return getType()->getContext(); } unsigned Function::getInstructionCount() const { unsigned NumInstrs = 0; for (const BasicBlock &BB : BasicBlocks) NumInstrs += std::distance(BB.instructionsWithoutDebug().begin(), BB.instructionsWithoutDebug().end()); return NumInstrs; } Function *Function::Create(FunctionType *Ty, LinkageTypes Linkage, const Twine &N, Module &M) { return Create(Ty, Linkage, M.getDataLayout().getProgramAddressSpace(), N, &M); } void Function::removeFromParent() { getParent()->getFunctionList().remove(getIterator()); } void Function::eraseFromParent() { getParent()->getFunctionList().erase(getIterator()); } //===----------------------------------------------------------------------===// // Function Implementation //===----------------------------------------------------------------------===// static unsigned computeAddrSpace(unsigned AddrSpace, Module *M) { // If AS == -1 and we are passed a valid module pointer we place the function // in the program address space. Otherwise we default to AS0. if (AddrSpace == static_cast(-1)) return M ? M->getDataLayout().getProgramAddressSpace() : 0; return AddrSpace; } Function::Function(FunctionType *Ty, LinkageTypes Linkage, unsigned AddrSpace, const Twine &name, Module *ParentModule) : GlobalObject(Ty, Value::FunctionVal, OperandTraits::op_begin(this), 0, Linkage, name, computeAddrSpace(AddrSpace, ParentModule)), NumArgs(Ty->getNumParams()) { assert(FunctionType::isValidReturnType(getReturnType()) && "invalid return type"); setGlobalObjectSubClassData(0); // We only need a symbol table for a function if the context keeps value names if (!getContext().shouldDiscardValueNames()) SymTab = std::make_unique(); // If the function has arguments, mark them as lazily built. if (Ty->getNumParams()) setValueSubclassData(1); // Set the "has lazy arguments" bit. if (ParentModule) ParentModule->getFunctionList().push_back(this); HasLLVMReservedName = getName().startswith("llvm."); // Ensure intrinsics have the right parameter attributes. // Note, the IntID field will have been set in Value::setName if this function // name is a valid intrinsic ID. if (IntID) setAttributes(Intrinsic::getAttributes(getContext(), IntID)); } Function::~Function() { dropAllReferences(); // After this it is safe to delete instructions. // Delete all of the method arguments and unlink from symbol table... if (Arguments) clearArguments(); // Remove the function from the on-the-side GC table. clearGC(); } void Function::BuildLazyArguments() const { // Create the arguments vector, all arguments start out unnamed. 
auto *FT = getFunctionType(); if (NumArgs > 0) { Arguments = std::allocator().allocate(NumArgs); for (unsigned i = 0, e = NumArgs; i != e; ++i) { Type *ArgTy = FT->getParamType(i); assert(!ArgTy->isVoidTy() && "Cannot have void typed arguments!"); new (Arguments + i) Argument(ArgTy, "", const_cast(this), i); } } // Clear the lazy arguments bit. unsigned SDC = getSubclassDataFromValue(); SDC &= ~(1 << 0); const_cast(this)->setValueSubclassData(SDC); assert(!hasLazyArguments()); } static MutableArrayRef makeArgArray(Argument *Args, size_t Count) { return MutableArrayRef(Args, Count); } bool Function::isConstrainedFPIntrinsic() const { switch (getIntrinsicID()) { #define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC) \ case Intrinsic::INTRINSIC: #include "llvm/IR/ConstrainedOps.def" return true; #undef INSTRUCTION default: return false; } } void Function::clearArguments() { for (Argument &A : makeArgArray(Arguments, NumArgs)) { A.setName(""); A.~Argument(); } std::allocator().deallocate(Arguments, NumArgs); Arguments = nullptr; } void Function::stealArgumentListFrom(Function &Src) { assert(isDeclaration() && "Expected no references to current arguments"); // Drop the current arguments, if any, and set the lazy argument bit. if (!hasLazyArguments()) { assert(llvm::all_of(makeArgArray(Arguments, NumArgs), [](const Argument &A) { return A.use_empty(); }) && "Expected arguments to be unused in declaration"); clearArguments(); setValueSubclassData(getSubclassDataFromValue() | (1 << 0)); } // Nothing to steal if Src has lazy arguments. if (Src.hasLazyArguments()) return; // Steal arguments from Src, and fix the lazy argument bits. assert(arg_size() == Src.arg_size()); Arguments = Src.Arguments; Src.Arguments = nullptr; for (Argument &A : makeArgArray(Arguments, NumArgs)) { // FIXME: This does the work of transferNodesFromList inefficiently. SmallString<128> Name; if (A.hasName()) Name = A.getName(); if (!Name.empty()) A.setName(""); A.setParent(this); if (!Name.empty()) A.setName(Name); } setValueSubclassData(getSubclassDataFromValue() & ~(1 << 0)); assert(!hasLazyArguments()); Src.setValueSubclassData(Src.getSubclassDataFromValue() | (1 << 0)); } // dropAllReferences() - This function causes all the subinstructions to "let // go" of all references that they are maintaining. This allows one to // 'delete' a whole class at a time, even though there may be circular // references... first all references are dropped, and all use counts go to // zero. Then everything is deleted for real. Note that no operations are // valid on an object that has "dropped all references", except operator // delete. // void Function::dropAllReferences() { setIsMaterializable(false); for (BasicBlock &BB : *this) BB.dropAllReferences(); // Delete all basic blocks. They are now unused, except possibly by // blockaddresses, but BasicBlock's destructor takes care of those. while (!BasicBlocks.empty()) BasicBlocks.begin()->eraseFromParent(); // Drop uses of any optional data (real or placeholder). if (getNumOperands()) { User::dropAllReferences(); setNumHungOffUseOperands(0); setValueSubclassData(getSubclassDataFromValue() & ~0xe); } // Metadata is stored in a side-table. 
clearMetadata(); } void Function::addAttribute(unsigned i, Attribute::AttrKind Kind) { AttributeList PAL = getAttributes(); PAL = PAL.addAttribute(getContext(), i, Kind); setAttributes(PAL); } void Function::addAttribute(unsigned i, Attribute Attr) { AttributeList PAL = getAttributes(); PAL = PAL.addAttribute(getContext(), i, Attr); setAttributes(PAL); } void Function::addAttributes(unsigned i, const AttrBuilder &Attrs) { AttributeList PAL = getAttributes(); PAL = PAL.addAttributes(getContext(), i, Attrs); setAttributes(PAL); } void Function::addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) { AttributeList PAL = getAttributes(); PAL = PAL.addParamAttribute(getContext(), ArgNo, Kind); setAttributes(PAL); } void Function::addParamAttr(unsigned ArgNo, Attribute Attr) { AttributeList PAL = getAttributes(); PAL = PAL.addParamAttribute(getContext(), ArgNo, Attr); setAttributes(PAL); } void Function::addParamAttrs(unsigned ArgNo, const AttrBuilder &Attrs) { AttributeList PAL = getAttributes(); PAL = PAL.addParamAttributes(getContext(), ArgNo, Attrs); setAttributes(PAL); } void Function::removeAttribute(unsigned i, Attribute::AttrKind Kind) { AttributeList PAL = getAttributes(); PAL = PAL.removeAttribute(getContext(), i, Kind); setAttributes(PAL); } void Function::removeAttribute(unsigned i, StringRef Kind) { AttributeList PAL = getAttributes(); PAL = PAL.removeAttribute(getContext(), i, Kind); setAttributes(PAL); } void Function::removeAttributes(unsigned i, const AttrBuilder &Attrs) { AttributeList PAL = getAttributes(); PAL = PAL.removeAttributes(getContext(), i, Attrs); setAttributes(PAL); } void Function::removeParamAttr(unsigned ArgNo, Attribute::AttrKind Kind) { AttributeList PAL = getAttributes(); PAL = PAL.removeParamAttribute(getContext(), ArgNo, Kind); setAttributes(PAL); } void Function::removeParamAttr(unsigned ArgNo, StringRef Kind) { AttributeList PAL = getAttributes(); PAL = PAL.removeParamAttribute(getContext(), ArgNo, Kind); setAttributes(PAL); } void Function::removeParamAttrs(unsigned ArgNo, const AttrBuilder &Attrs) { AttributeList PAL = getAttributes(); PAL = PAL.removeParamAttributes(getContext(), ArgNo, Attrs); setAttributes(PAL); } void Function::addDereferenceableAttr(unsigned i, uint64_t Bytes) { AttributeList PAL = getAttributes(); PAL = PAL.addDereferenceableAttr(getContext(), i, Bytes); setAttributes(PAL); } void Function::addDereferenceableParamAttr(unsigned ArgNo, uint64_t Bytes) { AttributeList PAL = getAttributes(); PAL = PAL.addDereferenceableParamAttr(getContext(), ArgNo, Bytes); setAttributes(PAL); } void Function::addDereferenceableOrNullAttr(unsigned i, uint64_t Bytes) { AttributeList PAL = getAttributes(); PAL = PAL.addDereferenceableOrNullAttr(getContext(), i, Bytes); setAttributes(PAL); } void Function::addDereferenceableOrNullParamAttr(unsigned ArgNo, uint64_t Bytes) { AttributeList PAL = getAttributes(); PAL = PAL.addDereferenceableOrNullParamAttr(getContext(), ArgNo, Bytes); setAttributes(PAL); } const std::string &Function::getGC() const { assert(hasGC() && "Function has no collector"); return getContext().getGC(*this); } void Function::setGC(std::string Str) { setValueSubclassDataBit(14, !Str.empty()); getContext().setGC(*this, std::move(Str)); } void Function::clearGC() { if (!hasGC()) return; getContext().deleteGC(*this); setValueSubclassDataBit(14, false); } /// Copy all additional attributes (those not needed to create a Function) from /// the Function Src to this one. 
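As a usage sketch of the attribute mutators above (assuming a Function *F; the particular attributes and indices are illustrative):

F->addParamAttr(0, Attribute::NonNull);            // attach an attribute to parameter 0
F->addDereferenceableParamAttr(0, 16);             // parameter 0 points to >= 16 valid bytes
F->addAttribute(AttributeList::FunctionIndex,
                Attribute::NoUnwind);              // function-level attribute
F->removeParamAttr(0, Attribute::NonNull);         // and the corresponding removal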
void Function::copyAttributesFrom(const Function *Src) { GlobalObject::copyAttributesFrom(Src); setCallingConv(Src->getCallingConv()); setAttributes(Src->getAttributes()); if (Src->hasGC()) setGC(Src->getGC()); else clearGC(); if (Src->hasPersonalityFn()) setPersonalityFn(Src->getPersonalityFn()); if (Src->hasPrefixData()) setPrefixData(Src->getPrefixData()); if (Src->hasPrologueData()) setPrologueData(Src->getPrologueData()); } /// Table of string intrinsic names indexed by enum value. static const char * const IntrinsicNameTable[] = { "not_intrinsic", #define GET_INTRINSIC_NAME_TABLE #include "llvm/IR/IntrinsicImpl.inc" #undef GET_INTRINSIC_NAME_TABLE }; /// Table of per-target intrinsic name tables. #define GET_INTRINSIC_TARGET_DATA #include "llvm/IR/IntrinsicImpl.inc" #undef GET_INTRINSIC_TARGET_DATA /// Find the segment of \c IntrinsicNameTable for intrinsics with the same /// target as \c Name, or the generic table if \c Name is not target specific. /// /// Returns the relevant slice of \c IntrinsicNameTable static ArrayRef findTargetSubtable(StringRef Name) { assert(Name.startswith("llvm.")); ArrayRef Targets(TargetInfos); // Drop "llvm." and take the first dotted component. That will be the target // if this is target specific. StringRef Target = Name.drop_front(5).split('.').first; auto It = partition_point( Targets, [=](const IntrinsicTargetInfo &TI) { return TI.Name < Target; }); // We've either found the target or just fall back to the generic set, which // is always first. const auto &TI = It != Targets.end() && It->Name == Target ? *It : Targets[0]; return makeArrayRef(&IntrinsicNameTable[1] + TI.Offset, TI.Count); } /// This does the actual lookup of an intrinsic ID which /// matches the given function name. Intrinsic::ID Function::lookupIntrinsicID(StringRef Name) { ArrayRef NameTable = findTargetSubtable(Name); int Idx = Intrinsic::lookupLLVMIntrinsicByName(NameTable, Name); if (Idx == -1) return Intrinsic::not_intrinsic; // Intrinsic IDs correspond to the location in IntrinsicNameTable, but we have // an index into a sub-table. int Adjust = NameTable.data() - IntrinsicNameTable; Intrinsic::ID ID = static_cast(Idx + Adjust); // If the intrinsic is not overloaded, require an exact match. If it is // overloaded, require either exact or prefix match. const auto MatchSize = strlen(NameTable[Idx]); assert(Name.size() >= MatchSize && "Expected either exact or prefix match"); bool IsExactMatch = Name.size() == MatchSize; return IsExactMatch || Intrinsic::isOverloaded(ID) ? ID : Intrinsic::not_intrinsic; } void Function::recalculateIntrinsicID() { StringRef Name = getName(); if (!Name.startswith("llvm.")) { HasLLVMReservedName = false; IntID = Intrinsic::not_intrinsic; return; } HasLLVMReservedName = true; IntID = lookupIntrinsicID(Name); } /// Returns a stable mangling for the type specified for use in the name /// mangling scheme used by 'any' types in intrinsic signatures. The mangling /// of named types is simply their name. Manglings for unnamed types consist /// of a prefix ('p' for pointers, 'a' for arrays, 'f_' for functions) /// combined with the mangling of their component types. A vararg function /// type will have a suffix of 'vararg'. Since function types can contain /// other function types, we close a function type mangling with suffix 'f' /// which can't be confused with it's prefix. This ensures we don't have /// collisions between two unrelated function types. Otherwise, you might /// parse ffXX as f(fXX) or f(fX)X. (X is a placeholder for any other type.) 
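The mangled type strings produced below are what Intrinsic::getName appends, one per overloaded type. For example, requesting memcpy over (i8*, i8*, i64) yields the familiar suffixed name (a sketch, assuming an LLVMContext Ctx):

// Illustrative only: "p0i8" = pointer in address space 0 to i8, "i64" = 64-bit int.
Type *I8Ptr = Type::getInt8PtrTy(Ctx);
Type *I64 = Type::getInt64Ty(Ctx);
std::string Name =
    Intrinsic::getName(Intrinsic::memcpy, {I8Ptr, I8Ptr, I64});
// Name == "llvm.memcpy.p0i8.p0i8.i64"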
/// static std::string getMangledTypeStr(Type* Ty) { std::string Result; if (PointerType* PTyp = dyn_cast(Ty)) { Result += "p" + utostr(PTyp->getAddressSpace()) + getMangledTypeStr(PTyp->getElementType()); } else if (ArrayType* ATyp = dyn_cast(Ty)) { Result += "a" + utostr(ATyp->getNumElements()) + getMangledTypeStr(ATyp->getElementType()); } else if (StructType *STyp = dyn_cast(Ty)) { if (!STyp->isLiteral()) { Result += "s_"; Result += STyp->getName(); } else { Result += "sl_"; for (auto Elem : STyp->elements()) Result += getMangledTypeStr(Elem); } // Ensure nested structs are distinguishable. Result += "s"; } else if (FunctionType *FT = dyn_cast(Ty)) { Result += "f_" + getMangledTypeStr(FT->getReturnType()); for (size_t i = 0; i < FT->getNumParams(); i++) Result += getMangledTypeStr(FT->getParamType(i)); if (FT->isVarArg()) Result += "vararg"; // Ensure nested function types are distinguishable. Result += "f"; } else if (VectorType* VTy = dyn_cast(Ty)) { if (VTy->isScalable()) Result += "nx"; Result += "v" + utostr(VTy->getNumElements()) + getMangledTypeStr(VTy->getElementType()); } else if (Ty) { switch (Ty->getTypeID()) { default: llvm_unreachable("Unhandled type"); case Type::VoidTyID: Result += "isVoid"; break; case Type::MetadataTyID: Result += "Metadata"; break; case Type::HalfTyID: Result += "f16"; break; case Type::FloatTyID: Result += "f32"; break; case Type::DoubleTyID: Result += "f64"; break; case Type::X86_FP80TyID: Result += "f80"; break; case Type::FP128TyID: Result += "f128"; break; case Type::PPC_FP128TyID: Result += "ppcf128"; break; case Type::X86_MMXTyID: Result += "x86mmx"; break; case Type::IntegerTyID: Result += "i" + utostr(cast(Ty)->getBitWidth()); break; } } return Result; } StringRef Intrinsic::getName(ID id) { assert(id < num_intrinsics && "Invalid intrinsic ID!"); assert(!Intrinsic::isOverloaded(id) && "This version of getName does not support overloading"); return IntrinsicNameTable[id]; } std::string Intrinsic::getName(ID id, ArrayRef Tys) { assert(id < num_intrinsics && "Invalid intrinsic ID!"); std::string Result(IntrinsicNameTable[id]); for (Type *Ty : Tys) { Result += "." + getMangledTypeStr(Ty); } return Result; } /// IIT_Info - These are enumerators that describe the entries returned by the /// getIntrinsicInfoTableEntries function. /// /// NOTE: This must be kept in synch with the copy in TblGen/IntrinsicEmitter! enum IIT_Info { // Common values should be encoded with 0-15. IIT_Done = 0, IIT_I1 = 1, IIT_I8 = 2, IIT_I16 = 3, IIT_I32 = 4, IIT_I64 = 5, IIT_F16 = 6, IIT_F32 = 7, IIT_F64 = 8, IIT_V2 = 9, IIT_V4 = 10, IIT_V8 = 11, IIT_V16 = 12, IIT_V32 = 13, IIT_PTR = 14, IIT_ARG = 15, // Values from 16+ are only encodable with the inefficient encoding. 
IIT_V64 = 16, IIT_MMX = 17, IIT_TOKEN = 18, IIT_METADATA = 19, IIT_EMPTYSTRUCT = 20, IIT_STRUCT2 = 21, IIT_STRUCT3 = 22, IIT_STRUCT4 = 23, IIT_STRUCT5 = 24, IIT_EXTEND_ARG = 25, IIT_TRUNC_ARG = 26, IIT_ANYPTR = 27, IIT_V1 = 28, IIT_VARARG = 29, IIT_HALF_VEC_ARG = 30, IIT_SAME_VEC_WIDTH_ARG = 31, IIT_PTR_TO_ARG = 32, IIT_PTR_TO_ELT = 33, IIT_VEC_OF_ANYPTRS_TO_ELT = 34, IIT_I128 = 35, IIT_V512 = 36, IIT_V1024 = 37, IIT_STRUCT6 = 38, IIT_STRUCT7 = 39, IIT_STRUCT8 = 40, IIT_F128 = 41, IIT_VEC_ELEMENT = 42, IIT_SCALABLE_VEC = 43, IIT_SUBDIVIDE2_ARG = 44, IIT_SUBDIVIDE4_ARG = 45, IIT_VEC_OF_BITCASTS_TO_INT = 46, IIT_V128 = 47 }; static void DecodeIITType(unsigned &NextElt, ArrayRef Infos, SmallVectorImpl &OutputTable) { using namespace Intrinsic; IIT_Info Info = IIT_Info(Infos[NextElt++]); unsigned StructElts = 2; switch (Info) { case IIT_Done: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Void, 0)); return; case IIT_VARARG: OutputTable.push_back(IITDescriptor::get(IITDescriptor::VarArg, 0)); return; case IIT_MMX: OutputTable.push_back(IITDescriptor::get(IITDescriptor::MMX, 0)); return; case IIT_TOKEN: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Token, 0)); return; case IIT_METADATA: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Metadata, 0)); return; case IIT_F16: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Half, 0)); return; case IIT_F32: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Float, 0)); return; case IIT_F64: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Double, 0)); return; case IIT_F128: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Quad, 0)); return; case IIT_I1: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Integer, 1)); return; case IIT_I8: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Integer, 8)); return; case IIT_I16: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Integer,16)); return; case IIT_I32: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Integer, 32)); return; case IIT_I64: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Integer, 64)); return; case IIT_I128: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Integer, 128)); return; case IIT_V1: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Vector, 1)); DecodeIITType(NextElt, Infos, OutputTable); return; case IIT_V2: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Vector, 2)); DecodeIITType(NextElt, Infos, OutputTable); return; case IIT_V4: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Vector, 4)); DecodeIITType(NextElt, Infos, OutputTable); return; case IIT_V8: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Vector, 8)); DecodeIITType(NextElt, Infos, OutputTable); return; case IIT_V16: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Vector, 16)); DecodeIITType(NextElt, Infos, OutputTable); return; case IIT_V32: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Vector, 32)); DecodeIITType(NextElt, Infos, OutputTable); return; case IIT_V64: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Vector, 64)); DecodeIITType(NextElt, Infos, OutputTable); return; case IIT_V128: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Vector, 128)); DecodeIITType(NextElt, Infos, OutputTable); return; case IIT_V512: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Vector, 512)); DecodeIITType(NextElt, Infos, OutputTable); return; case IIT_V1024: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Vector, 1024)); DecodeIITType(NextElt, Infos, 
OutputTable); return; case IIT_PTR: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Pointer, 0)); DecodeIITType(NextElt, Infos, OutputTable); return; case IIT_ANYPTR: { // [ANYPTR addrspace, subtype] OutputTable.push_back(IITDescriptor::get(IITDescriptor::Pointer, Infos[NextElt++])); DecodeIITType(NextElt, Infos, OutputTable); return; } case IIT_ARG: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::Argument, ArgInfo)); return; } case IIT_EXTEND_ARG: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::ExtendArgument, ArgInfo)); return; } case IIT_TRUNC_ARG: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::TruncArgument, ArgInfo)); return; } case IIT_HALF_VEC_ARG: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::HalfVecArgument, ArgInfo)); return; } case IIT_SAME_VEC_WIDTH_ARG: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::SameVecWidthArgument, ArgInfo)); return; } case IIT_PTR_TO_ARG: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::PtrToArgument, ArgInfo)); return; } case IIT_PTR_TO_ELT: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::PtrToElt, ArgInfo)); return; } case IIT_VEC_OF_ANYPTRS_TO_ELT: { unsigned short ArgNo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); unsigned short RefNo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back( IITDescriptor::get(IITDescriptor::VecOfAnyPtrsToElt, ArgNo, RefNo)); return; } case IIT_EMPTYSTRUCT: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Struct, 0)); return; case IIT_STRUCT8: ++StructElts; LLVM_FALLTHROUGH; case IIT_STRUCT7: ++StructElts; LLVM_FALLTHROUGH; case IIT_STRUCT6: ++StructElts; LLVM_FALLTHROUGH; case IIT_STRUCT5: ++StructElts; LLVM_FALLTHROUGH; case IIT_STRUCT4: ++StructElts; LLVM_FALLTHROUGH; case IIT_STRUCT3: ++StructElts; LLVM_FALLTHROUGH; case IIT_STRUCT2: { OutputTable.push_back(IITDescriptor::get(IITDescriptor::Struct,StructElts)); for (unsigned i = 0; i != StructElts; ++i) DecodeIITType(NextElt, Infos, OutputTable); return; } case IIT_SUBDIVIDE2_ARG: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::Subdivide2Argument, ArgInfo)); return; } case IIT_SUBDIVIDE4_ARG: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::Subdivide4Argument, ArgInfo)); return; } case IIT_VEC_ELEMENT: { unsigned ArgInfo = (NextElt == Infos.size() ? 0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::VecElementArgument, ArgInfo)); return; } case IIT_SCALABLE_VEC: { OutputTable.push_back(IITDescriptor::get(IITDescriptor::ScalableVecArgument, 0)); DecodeIITType(NextElt, Infos, OutputTable); return; } case IIT_VEC_OF_BITCASTS_TO_INT: { unsigned ArgInfo = (NextElt == Infos.size() ? 
0 : Infos[NextElt++]); OutputTable.push_back(IITDescriptor::get(IITDescriptor::VecOfBitcastsToInt, ArgInfo)); return; } } llvm_unreachable("unhandled"); } #define GET_INTRINSIC_GENERATOR_GLOBAL #include "llvm/IR/IntrinsicImpl.inc" #undef GET_INTRINSIC_GENERATOR_GLOBAL void Intrinsic::getIntrinsicInfoTableEntries(ID id, SmallVectorImpl &T){ // Check to see if the intrinsic's type was expressible by the table. unsigned TableVal = IIT_Table[id-1]; // Decode the TableVal into an array of IITValues. SmallVector IITValues; ArrayRef IITEntries; unsigned NextElt = 0; if ((TableVal >> 31) != 0) { // This is an offset into the IIT_LongEncodingTable. IITEntries = IIT_LongEncodingTable; // Strip sentinel bit. NextElt = (TableVal << 1) >> 1; } else { // Decode the TableVal into an array of IITValues. If the entry was encoded // into a single word in the table itself, decode it now. do { IITValues.push_back(TableVal & 0xF); TableVal >>= 4; } while (TableVal); IITEntries = IITValues; NextElt = 0; } // Okay, decode the table into the output vector of IITDescriptors. DecodeIITType(NextElt, IITEntries, T); while (NextElt != IITEntries.size() && IITEntries[NextElt] != 0) DecodeIITType(NextElt, IITEntries, T); } static Type *DecodeFixedType(ArrayRef &Infos, ArrayRef Tys, LLVMContext &Context) { using namespace Intrinsic; IITDescriptor D = Infos.front(); Infos = Infos.slice(1); switch (D.Kind) { case IITDescriptor::Void: return Type::getVoidTy(Context); case IITDescriptor::VarArg: return Type::getVoidTy(Context); case IITDescriptor::MMX: return Type::getX86_MMXTy(Context); case IITDescriptor::Token: return Type::getTokenTy(Context); case IITDescriptor::Metadata: return Type::getMetadataTy(Context); case IITDescriptor::Half: return Type::getHalfTy(Context); case IITDescriptor::Float: return Type::getFloatTy(Context); case IITDescriptor::Double: return Type::getDoubleTy(Context); case IITDescriptor::Quad: return Type::getFP128Ty(Context); case IITDescriptor::Integer: return IntegerType::get(Context, D.Integer_Width); case IITDescriptor::Vector: return VectorType::get(DecodeFixedType(Infos, Tys, Context),D.Vector_Width); case IITDescriptor::Pointer: return PointerType::get(DecodeFixedType(Infos, Tys, Context), D.Pointer_AddressSpace); case IITDescriptor::Struct: { SmallVector Elts; for (unsigned i = 0, e = D.Struct_NumElements; i != e; ++i) Elts.push_back(DecodeFixedType(Infos, Tys, Context)); return StructType::get(Context, Elts); } case IITDescriptor::Argument: return Tys[D.getArgumentNumber()]; case IITDescriptor::ExtendArgument: { Type *Ty = Tys[D.getArgumentNumber()]; if (VectorType *VTy = dyn_cast(Ty)) return VectorType::getExtendedElementVectorType(VTy); return IntegerType::get(Context, 2 * cast(Ty)->getBitWidth()); } case IITDescriptor::TruncArgument: { Type *Ty = Tys[D.getArgumentNumber()]; if (VectorType *VTy = dyn_cast(Ty)) return VectorType::getTruncatedElementVectorType(VTy); IntegerType *ITy = cast(Ty); assert(ITy->getBitWidth() % 2 == 0); return IntegerType::get(Context, ITy->getBitWidth() / 2); } case IITDescriptor::Subdivide2Argument: case IITDescriptor::Subdivide4Argument: { Type *Ty = Tys[D.getArgumentNumber()]; VectorType *VTy = dyn_cast(Ty); assert(VTy && "Expected an argument of Vector Type"); int SubDivs = D.Kind == IITDescriptor::Subdivide2Argument ? 
1 : 2; return VectorType::getSubdividedVectorType(VTy, SubDivs); } case IITDescriptor::HalfVecArgument: return VectorType::getHalfElementsVectorType(cast( Tys[D.getArgumentNumber()])); case IITDescriptor::SameVecWidthArgument: { Type *EltTy = DecodeFixedType(Infos, Tys, Context); Type *Ty = Tys[D.getArgumentNumber()]; if (auto *VTy = dyn_cast(Ty)) return VectorType::get(EltTy, VTy->getElementCount()); return EltTy; } case IITDescriptor::PtrToArgument: { Type *Ty = Tys[D.getArgumentNumber()]; return PointerType::getUnqual(Ty); } case IITDescriptor::PtrToElt: { Type *Ty = Tys[D.getArgumentNumber()]; VectorType *VTy = dyn_cast(Ty); if (!VTy) llvm_unreachable("Expected an argument of Vector Type"); Type *EltTy = VTy->getElementType(); return PointerType::getUnqual(EltTy); } case IITDescriptor::VecElementArgument: { Type *Ty = Tys[D.getArgumentNumber()]; if (VectorType *VTy = dyn_cast(Ty)) return VTy->getElementType(); llvm_unreachable("Expected an argument of Vector Type"); } case IITDescriptor::VecOfBitcastsToInt: { Type *Ty = Tys[D.getArgumentNumber()]; VectorType *VTy = dyn_cast(Ty); assert(VTy && "Expected an argument of Vector Type"); return VectorType::getInteger(VTy); } case IITDescriptor::VecOfAnyPtrsToElt: // Return the overloaded type (which determines the pointers address space) return Tys[D.getOverloadArgNumber()]; case IITDescriptor::ScalableVecArgument: { auto *Ty = cast(DecodeFixedType(Infos, Tys, Context)); - return VectorType::get(Ty->getElementType(), - {(unsigned)Ty->getNumElements(), true}); + return VectorType::get(Ty->getElementType(), {Ty->getNumElements(), true}); } } llvm_unreachable("unhandled"); } FunctionType *Intrinsic::getType(LLVMContext &Context, ID id, ArrayRef Tys) { SmallVector Table; getIntrinsicInfoTableEntries(id, Table); ArrayRef TableRef = Table; Type *ResultTy = DecodeFixedType(TableRef, Tys, Context); SmallVector ArgTys; while (!TableRef.empty()) ArgTys.push_back(DecodeFixedType(TableRef, Tys, Context)); // DecodeFixedType returns Void for IITDescriptor::Void and IITDescriptor::VarArg // If we see void type as the type of the last argument, it is vararg intrinsic if (!ArgTys.empty() && ArgTys.back()->isVoidTy()) { ArgTys.pop_back(); return FunctionType::get(ResultTy, ArgTys, true); } return FunctionType::get(ResultTy, ArgTys, false); } bool Intrinsic::isOverloaded(ID id) { #define GET_INTRINSIC_OVERLOAD_TABLE #include "llvm/IR/IntrinsicImpl.inc" #undef GET_INTRINSIC_OVERLOAD_TABLE } bool Intrinsic::isLeaf(ID id) { switch (id) { default: return true; case Intrinsic::experimental_gc_statepoint: case Intrinsic::experimental_patchpoint_void: case Intrinsic::experimental_patchpoint_i64: return false; } } /// This defines the "Intrinsic::getAttributes(ID id)" method. #define GET_INTRINSIC_ATTRIBUTES #include "llvm/IR/IntrinsicImpl.inc" #undef GET_INTRINSIC_ATTRIBUTES Function *Intrinsic::getDeclaration(Module *M, ID id, ArrayRef Tys) { // There can never be multiple globals with the same name of different types, // because intrinsics must be a specific type. return cast( M->getOrInsertFunction(getName(id, Tys), getType(M->getContext(), id, Tys)) .getCallee()); } // This defines the "Intrinsic::getIntrinsicForGCCBuiltin()" method. #define GET_LLVM_INTRINSIC_FOR_GCC_BUILTIN #include "llvm/IR/IntrinsicImpl.inc" #undef GET_LLVM_INTRINSIC_FOR_GCC_BUILTIN // This defines the "Intrinsic::getIntrinsicForMSBuiltin()" method. 
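Continuing the memcpy example above, Intrinsic::getDeclaration combines that name mangling with Intrinsic::getType to materialize the declaration in a module (a sketch, assuming a Module *M and the I8Ptr/I64 types from earlier):

// Declares (or finds) void @llvm.memcpy.p0i8.p0i8.i64(i8*, i8*, i64, i1) in M.
Function *MemCpy =
    Intrinsic::getDeclaration(M, Intrinsic::memcpy, {I8Ptr, I8Ptr, I64});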
#define GET_LLVM_INTRINSIC_FOR_MS_BUILTIN #include "llvm/IR/IntrinsicImpl.inc" #undef GET_LLVM_INTRINSIC_FOR_MS_BUILTIN using DeferredIntrinsicMatchPair = std::pair>; static bool matchIntrinsicType( Type *Ty, ArrayRef &Infos, SmallVectorImpl &ArgTys, SmallVectorImpl &DeferredChecks, bool IsDeferredCheck) { using namespace Intrinsic; // If we ran out of descriptors, there are too many arguments. if (Infos.empty()) return true; // Do this before slicing off the 'front' part auto InfosRef = Infos; auto DeferCheck = [&DeferredChecks, &InfosRef](Type *T) { DeferredChecks.emplace_back(T, InfosRef); return false; }; IITDescriptor D = Infos.front(); Infos = Infos.slice(1); switch (D.Kind) { case IITDescriptor::Void: return !Ty->isVoidTy(); case IITDescriptor::VarArg: return true; case IITDescriptor::MMX: return !Ty->isX86_MMXTy(); case IITDescriptor::Token: return !Ty->isTokenTy(); case IITDescriptor::Metadata: return !Ty->isMetadataTy(); case IITDescriptor::Half: return !Ty->isHalfTy(); case IITDescriptor::Float: return !Ty->isFloatTy(); case IITDescriptor::Double: return !Ty->isDoubleTy(); case IITDescriptor::Quad: return !Ty->isFP128Ty(); case IITDescriptor::Integer: return !Ty->isIntegerTy(D.Integer_Width); case IITDescriptor::Vector: { VectorType *VT = dyn_cast(Ty); return !VT || VT->getNumElements() != D.Vector_Width || matchIntrinsicType(VT->getElementType(), Infos, ArgTys, DeferredChecks, IsDeferredCheck); } case IITDescriptor::Pointer: { PointerType *PT = dyn_cast(Ty); return !PT || PT->getAddressSpace() != D.Pointer_AddressSpace || matchIntrinsicType(PT->getElementType(), Infos, ArgTys, DeferredChecks, IsDeferredCheck); } case IITDescriptor::Struct: { StructType *ST = dyn_cast(Ty); if (!ST || ST->getNumElements() != D.Struct_NumElements) return true; for (unsigned i = 0, e = D.Struct_NumElements; i != e; ++i) if (matchIntrinsicType(ST->getElementType(i), Infos, ArgTys, DeferredChecks, IsDeferredCheck)) return true; return false; } case IITDescriptor::Argument: // If this is the second occurrence of an argument, // verify that the later instance matches the previous instance. if (D.getArgumentNumber() < ArgTys.size()) return Ty != ArgTys[D.getArgumentNumber()]; if (D.getArgumentNumber() > ArgTys.size() || D.getArgumentKind() == IITDescriptor::AK_MatchType) return IsDeferredCheck || DeferCheck(Ty); assert(D.getArgumentNumber() == ArgTys.size() && !IsDeferredCheck && "Table consistency error"); ArgTys.push_back(Ty); switch (D.getArgumentKind()) { case IITDescriptor::AK_Any: return false; // Success case IITDescriptor::AK_AnyInteger: return !Ty->isIntOrIntVectorTy(); case IITDescriptor::AK_AnyFloat: return !Ty->isFPOrFPVectorTy(); case IITDescriptor::AK_AnyVector: return !isa(Ty); case IITDescriptor::AK_AnyPointer: return !isa(Ty); default: break; } llvm_unreachable("all argument kinds not covered"); case IITDescriptor::ExtendArgument: { // If this is a forward reference, defer the check for later. if (D.getArgumentNumber() >= ArgTys.size()) return IsDeferredCheck || DeferCheck(Ty); Type *NewTy = ArgTys[D.getArgumentNumber()]; if (VectorType *VTy = dyn_cast(NewTy)) NewTy = VectorType::getExtendedElementVectorType(VTy); else if (IntegerType *ITy = dyn_cast(NewTy)) NewTy = IntegerType::get(ITy->getContext(), 2 * ITy->getBitWidth()); else return true; return Ty != NewTy; } case IITDescriptor::TruncArgument: { // If this is a forward reference, defer the check for later. 
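  // A "forward reference" means this descriptor is expressed in terms of an
  // overloaded argument that has not been recorded in ArgTys yet (e.g. a
  // return type constrained by a later parameter). DeferCheck() records the
  // (type, remaining-descriptors) pair in DeferredChecks, and
  // matchIntrinsicSignature re-runs those pairs with IsDeferredCheck=true
  // once all parameters have been visited, at which point the referenced
  // ArgTys entry exists and the comparison below can be performed.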
if (D.getArgumentNumber() >= ArgTys.size()) return IsDeferredCheck || DeferCheck(Ty); Type *NewTy = ArgTys[D.getArgumentNumber()]; if (VectorType *VTy = dyn_cast(NewTy)) NewTy = VectorType::getTruncatedElementVectorType(VTy); else if (IntegerType *ITy = dyn_cast(NewTy)) NewTy = IntegerType::get(ITy->getContext(), ITy->getBitWidth() / 2); else return true; return Ty != NewTy; } case IITDescriptor::HalfVecArgument: // If this is a forward reference, defer the check for later. if (D.getArgumentNumber() >= ArgTys.size()) return IsDeferredCheck || DeferCheck(Ty); return !isa(ArgTys[D.getArgumentNumber()]) || VectorType::getHalfElementsVectorType( cast(ArgTys[D.getArgumentNumber()])) != Ty; case IITDescriptor::SameVecWidthArgument: { if (D.getArgumentNumber() >= ArgTys.size()) { // Defer check and subsequent check for the vector element type. Infos = Infos.slice(1); return IsDeferredCheck || DeferCheck(Ty); } auto *ReferenceType = dyn_cast(ArgTys[D.getArgumentNumber()]); auto *ThisArgType = dyn_cast(Ty); // Both must be vectors of the same number of elements or neither. if ((ReferenceType != nullptr) != (ThisArgType != nullptr)) return true; Type *EltTy = Ty; if (ThisArgType) { if (ReferenceType->getElementCount() != ThisArgType->getElementCount()) return true; EltTy = ThisArgType->getElementType(); } return matchIntrinsicType(EltTy, Infos, ArgTys, DeferredChecks, IsDeferredCheck); } case IITDescriptor::PtrToArgument: { if (D.getArgumentNumber() >= ArgTys.size()) return IsDeferredCheck || DeferCheck(Ty); Type * ReferenceType = ArgTys[D.getArgumentNumber()]; PointerType *ThisArgType = dyn_cast(Ty); return (!ThisArgType || ThisArgType->getElementType() != ReferenceType); } case IITDescriptor::PtrToElt: { if (D.getArgumentNumber() >= ArgTys.size()) return IsDeferredCheck || DeferCheck(Ty); VectorType * ReferenceType = dyn_cast (ArgTys[D.getArgumentNumber()]); PointerType *ThisArgType = dyn_cast(Ty); return (!ThisArgType || !ReferenceType || ThisArgType->getElementType() != ReferenceType->getElementType()); } case IITDescriptor::VecOfAnyPtrsToElt: { unsigned RefArgNumber = D.getRefArgNumber(); if (RefArgNumber >= ArgTys.size()) { if (IsDeferredCheck) return true; // If forward referencing, already add the pointer-vector type and // defer the checks for later. ArgTys.push_back(Ty); return DeferCheck(Ty); } if (!IsDeferredCheck){ assert(D.getOverloadArgNumber() == ArgTys.size() && "Table consistency error"); ArgTys.push_back(Ty); } // Verify the overloaded type "matches" the Ref type. // i.e. Ty is a vector with the same width as Ref. // Composed of pointers to the same element type as Ref. VectorType *ReferenceType = dyn_cast(ArgTys[RefArgNumber]); VectorType *ThisArgVecTy = dyn_cast(Ty); if (!ThisArgVecTy || !ReferenceType || (ReferenceType->getNumElements() != ThisArgVecTy->getNumElements())) return true; PointerType *ThisArgEltTy = dyn_cast(ThisArgVecTy->getElementType()); if (!ThisArgEltTy) return true; return ThisArgEltTy->getElementType() != ReferenceType->getElementType(); } case IITDescriptor::VecElementArgument: { if (D.getArgumentNumber() >= ArgTys.size()) return IsDeferredCheck ? true : DeferCheck(Ty); auto *ReferenceType = dyn_cast(ArgTys[D.getArgumentNumber()]); return !ReferenceType || Ty != ReferenceType->getElementType(); } case IITDescriptor::Subdivide2Argument: case IITDescriptor::Subdivide4Argument: { // If this is a forward reference, defer the check for later. 
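  // Subdivide2Argument/Subdivide4Argument expect `Ty` to be the reference
  // vector with its element type narrowed by a factor of 2 or 4 and the
  // element count scaled up accordingly (total bit width preserved). For
  // example, a reference type of <vscale x 4 x i32> would be expected to
  // match <vscale x 8 x i16> for Subdivide2Argument; this follows the
  // semantics of VectorType::getSubdividedVectorType used below.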
if (D.getArgumentNumber() >= ArgTys.size()) return IsDeferredCheck || DeferCheck(Ty); Type *NewTy = ArgTys[D.getArgumentNumber()]; if (auto *VTy = dyn_cast(NewTy)) { int SubDivs = D.Kind == IITDescriptor::Subdivide2Argument ? 1 : 2; NewTy = VectorType::getSubdividedVectorType(VTy, SubDivs); return Ty != NewTy; } return true; } case IITDescriptor::ScalableVecArgument: { VectorType *VTy = dyn_cast(Ty); if (!VTy || !VTy->isScalable()) return true; return matchIntrinsicType(VTy, Infos, ArgTys, DeferredChecks, IsDeferredCheck); } case IITDescriptor::VecOfBitcastsToInt: { if (D.getArgumentNumber() >= ArgTys.size()) return IsDeferredCheck || DeferCheck(Ty); auto *ReferenceType = dyn_cast(ArgTys[D.getArgumentNumber()]); auto *ThisArgVecTy = dyn_cast(Ty); if (!ThisArgVecTy || !ReferenceType) return true; return ThisArgVecTy != VectorType::getInteger(ReferenceType); } } llvm_unreachable("unhandled"); } Intrinsic::MatchIntrinsicTypesResult Intrinsic::matchIntrinsicSignature(FunctionType *FTy, ArrayRef &Infos, SmallVectorImpl &ArgTys) { SmallVector DeferredChecks; if (matchIntrinsicType(FTy->getReturnType(), Infos, ArgTys, DeferredChecks, false)) return MatchIntrinsicTypes_NoMatchRet; unsigned NumDeferredReturnChecks = DeferredChecks.size(); for (auto Ty : FTy->params()) if (matchIntrinsicType(Ty, Infos, ArgTys, DeferredChecks, false)) return MatchIntrinsicTypes_NoMatchArg; for (unsigned I = 0, E = DeferredChecks.size(); I != E; ++I) { DeferredIntrinsicMatchPair &Check = DeferredChecks[I]; if (matchIntrinsicType(Check.first, Check.second, ArgTys, DeferredChecks, true)) return I < NumDeferredReturnChecks ? MatchIntrinsicTypes_NoMatchRet : MatchIntrinsicTypes_NoMatchArg; } return MatchIntrinsicTypes_Match; } bool Intrinsic::matchIntrinsicVarArg(bool isVarArg, ArrayRef &Infos) { // If there are no descriptors left, then it can't be a vararg. if (Infos.empty()) return isVarArg; // There should be only one descriptor remaining at this point. if (Infos.size() != 1) return true; // Check and verify the descriptor. IITDescriptor D = Infos.front(); Infos = Infos.slice(1); if (D.Kind == IITDescriptor::VarArg) return !isVarArg; return true; } Optional Intrinsic::remangleIntrinsicFunction(Function *F) { Intrinsic::ID ID = F->getIntrinsicID(); if (!ID) return None; FunctionType *FTy = F->getFunctionType(); // Accumulate an array of overloaded types for the given intrinsic SmallVector ArgTys; { SmallVector Table; getIntrinsicInfoTableEntries(ID, Table); ArrayRef TableRef = Table; if (Intrinsic::matchIntrinsicSignature(FTy, TableRef, ArgTys)) return None; if (Intrinsic::matchIntrinsicVarArg(FTy->isVarArg(), TableRef)) return None; } StringRef Name = F->getName(); if (Name == Intrinsic::getName(ID, ArgTys)) return None; auto NewDecl = Intrinsic::getDeclaration(F->getParent(), ID, ArgTys); NewDecl->setCallingConv(F->getCallingConv()); assert(NewDecl->getFunctionType() == FTy && "Shouldn't change the signature"); return NewDecl; } /// hasAddressTaken - returns true if there are any uses of this function /// other than direct calls or invokes to it. 
bool Function::hasAddressTaken(const User* *PutOffender) const { for (const Use &U : uses()) { const User *FU = U.getUser(); if (isa(FU)) continue; const auto *Call = dyn_cast(FU); if (!Call) { if (PutOffender) *PutOffender = FU; return true; } if (!Call->isCallee(&U)) { if (PutOffender) *PutOffender = FU; return true; } } return false; } bool Function::isDefTriviallyDead() const { // Check the linkage if (!hasLinkOnceLinkage() && !hasLocalLinkage() && !hasAvailableExternallyLinkage()) return false; // Check if the function is used by anything other than a blockaddress. for (const User *U : users()) if (!isa(U)) return false; return true; } /// callsFunctionThatReturnsTwice - Return true if the function has a call to /// setjmp or other function that gcc recognizes as "returning twice". bool Function::callsFunctionThatReturnsTwice() const { for (const Instruction &I : instructions(this)) if (const auto *Call = dyn_cast(&I)) if (Call->hasFnAttr(Attribute::ReturnsTwice)) return true; return false; } Constant *Function::getPersonalityFn() const { assert(hasPersonalityFn() && getNumOperands()); return cast(Op<0>()); } void Function::setPersonalityFn(Constant *Fn) { setHungoffOperand<0>(Fn); setValueSubclassDataBit(3, Fn != nullptr); } Constant *Function::getPrefixData() const { assert(hasPrefixData() && getNumOperands()); return cast(Op<1>()); } void Function::setPrefixData(Constant *PrefixData) { setHungoffOperand<1>(PrefixData); setValueSubclassDataBit(1, PrefixData != nullptr); } Constant *Function::getPrologueData() const { assert(hasPrologueData() && getNumOperands()); return cast(Op<2>()); } void Function::setPrologueData(Constant *PrologueData) { setHungoffOperand<2>(PrologueData); setValueSubclassDataBit(2, PrologueData != nullptr); } void Function::allocHungoffUselist() { // If we've already allocated a uselist, stop here. if (getNumOperands()) return; allocHungoffUses(3, /*IsPhi=*/ false); setNumHungOffUseOperands(3); // Initialize the uselist with placeholder operands to allow traversal. 
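  // The three hung-off slots correspond to the personality function (Op<0>),
  // prefix data (Op<1>) and prologue data (Op<2>). A null i1* constant acts
  // as the "unset" placeholder so all three operands exist and can be
  // traversed even when only one carries a real value; the subclass data bits
  // set in setPersonalityFn/setPrefixData/setPrologueData record which slots
  // are actually populated.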
auto *CPN = ConstantPointerNull::get(Type::getInt1PtrTy(getContext(), 0)); Op<0>().set(CPN); Op<1>().set(CPN); Op<2>().set(CPN); } template void Function::setHungoffOperand(Constant *C) { if (C) { allocHungoffUselist(); Op().set(C); } else if (getNumOperands()) { Op().set( ConstantPointerNull::get(Type::getInt1PtrTy(getContext(), 0))); } } void Function::setValueSubclassDataBit(unsigned Bit, bool On) { assert(Bit < 16 && "SubclassData contains only 16 bits"); if (On) setValueSubclassData(getSubclassDataFromValue() | (1 << Bit)); else setValueSubclassData(getSubclassDataFromValue() & ~(1 << Bit)); } void Function::setEntryCount(ProfileCount Count, const DenseSet *S) { assert(Count.hasValue()); #if !defined(NDEBUG) auto PrevCount = getEntryCount(); assert(!PrevCount.hasValue() || PrevCount.getType() == Count.getType()); #endif auto ImportGUIDs = getImportGUIDs(); if (S == nullptr && ImportGUIDs.size()) S = &ImportGUIDs; MDBuilder MDB(getContext()); setMetadata( LLVMContext::MD_prof, MDB.createFunctionEntryCount(Count.getCount(), Count.isSynthetic(), S)); } void Function::setEntryCount(uint64_t Count, Function::ProfileCountType Type, const DenseSet *Imports) { setEntryCount(ProfileCount(Count, Type), Imports); } ProfileCount Function::getEntryCount(bool AllowSynthetic) const { MDNode *MD = getMetadata(LLVMContext::MD_prof); if (MD && MD->getOperand(0)) if (MDString *MDS = dyn_cast(MD->getOperand(0))) { if (MDS->getString().equals("function_entry_count")) { ConstantInt *CI = mdconst::extract(MD->getOperand(1)); uint64_t Count = CI->getValue().getZExtValue(); // A value of -1 is used for SamplePGO when there were no samples. // Treat this the same as unknown. if (Count == (uint64_t)-1) return ProfileCount::getInvalid(); return ProfileCount(Count, PCT_Real); } else if (AllowSynthetic && MDS->getString().equals("synthetic_function_entry_count")) { ConstantInt *CI = mdconst::extract(MD->getOperand(1)); uint64_t Count = CI->getValue().getZExtValue(); return ProfileCount(Count, PCT_Synthetic); } } return ProfileCount::getInvalid(); } DenseSet Function::getImportGUIDs() const { DenseSet R; if (MDNode *MD = getMetadata(LLVMContext::MD_prof)) if (MDString *MDS = dyn_cast(MD->getOperand(0))) if (MDS->getString().equals("function_entry_count")) for (unsigned i = 2; i < MD->getNumOperands(); i++) R.insert(mdconst::extract(MD->getOperand(i)) ->getValue() .getZExtValue()); return R; } void Function::setSectionPrefix(StringRef Prefix) { MDBuilder MDB(getContext()); setMetadata(LLVMContext::MD_section_prefix, MDB.createFunctionSectionPrefix(Prefix)); } Optional Function::getSectionPrefix() const { if (MDNode *MD = getMetadata(LLVMContext::MD_section_prefix)) { assert(cast(MD->getOperand(0)) ->getString() .equals("function_section_prefix") && "Metadata not match"); return cast(MD->getOperand(1))->getString(); } return None; } bool Function::nullPointerIsDefined() const { return getFnAttribute("null-pointer-is-valid") .getValueAsString() .equals("true"); } bool llvm::NullPointerIsDefined(const Function *F, unsigned AS) { if (F && F->nullPointerIsDefined()) return true; if (AS != 0) return true; return false; } diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index acc84f9e9a46..abb3fd74dfa9 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1,2982 +1,2982 @@ //===- StandardToLLVM.cpp - Standard to LLVM dialect conversion -----------===// // // Part of the LLVM 
Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to convert MLIR standard and builtin dialects // into the LLVM IR dialect. // //===----------------------------------------------------------------------===// #include "../PassDetail.h" #include "mlir/ADT/TypeSwitch.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include using namespace mlir; #define PASS_NAME "convert-std-to-llvm" // Extract an LLVM IR type from the LLVM IR dialect type. static LLVM::LLVMType unwrap(Type type) { if (!type) return nullptr; auto *mlirContext = type.getContext(); auto wrappedLLVMType = type.dyn_cast(); if (!wrappedLLVMType) emitError(UnknownLoc::get(mlirContext), "conversion resulted in a non-LLVM type"); return wrappedLLVMType; } /// Initialize customization to default callbacks. LLVMTypeConverterCustomization::LLVMTypeConverterCustomization() : funcArgConverter(structFuncArgTypeConverter), indexBitwidth(kDeriveIndexBitwidthFromDataLayout) {} /// Callback to convert function argument types. It converts a MemRef function /// argument to a list of non-aggregate types containing descriptor /// information, and an UnrankedmemRef function argument to a list containing /// the rank and a pointer to a descriptor struct. LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result) { if (auto memref = type.dyn_cast()) { auto converted = converter.convertMemRefSignature(memref); if (converted.empty()) return failure(); result.append(converted.begin(), converted.end()); return success(); } if (type.isa()) { auto converted = converter.convertUnrankedMemRefSignature(); if (converted.empty()) return failure(); result.append(converted.begin(), converted.end()); return success(); } auto converted = converter.convertType(type); if (!converted) return failure(); result.push_back(converted); return success(); } /// Convert a MemRef type to a bare pointer to the MemRef element type. static Type convertMemRefTypeToBarePtr(LLVMTypeConverter &converter, MemRefType type) { int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(type, strides, offset))) return {}; LLVM::LLVMType elementType = unwrap(converter.convertType(type.getElementType())); if (!elementType) return {}; return elementType.getPointerTo(type.getMemorySpace()); } /// Callback to convert function argument types. It converts MemRef function /// arguments to bare pointers to the MemRef element type. LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result) { // TODO: Add support for unranked memref. 
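  // Under the bare-pointer convention a ranked memref argument lowers to a
  // single pointer to its element type in the memref's memory space, e.g. a
  // memref<16xf32> parameter becomes a plain float* at the boundary
  // (illustrative example). The shape/stride information is not passed; for
  // functions with bodies it is reconstructed from the static type by
  // BarePtrFuncOpConversion further below via MemRefDescriptor::fromStaticShape.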
if (auto memrefTy = type.dyn_cast()) { auto llvmTy = convertMemRefTypeToBarePtr(converter, memrefTy); if (!llvmTy) return failure(); result.push_back(llvmTy); return success(); } auto llvmTy = converter.convertType(type); if (!llvmTy) return failure(); result.push_back(llvmTy); return success(); } /// Create an LLVMTypeConverter using default LLVMTypeConverterCustomization. LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) : LLVMTypeConverter(ctx, LLVMTypeConverterCustomization()) {} /// Create an LLVMTypeConverter using 'custom' customizations. LLVMTypeConverter::LLVMTypeConverter( MLIRContext *ctx, const LLVMTypeConverterCustomization &customs) : llvmDialect(ctx->getRegisteredDialect()), customizations(customs) { assert(llvmDialect && "LLVM IR dialect is not registered"); module = &llvmDialect->getLLVMModule(); if (customizations.indexBitwidth == kDeriveIndexBitwidthFromDataLayout) customizations.indexBitwidth = module->getDataLayout().getPointerSizeInBits(); // Register conversions for the standard types. addConversion([&](FloatType type) { return convertFloatType(type); }); addConversion([&](FunctionType type) { return convertFunctionType(type); }); addConversion([&](IndexType type) { return convertIndexType(type); }); addConversion([&](IntegerType type) { return convertIntegerType(type); }); addConversion([&](MemRefType type) { return convertMemRefType(type); }); addConversion( [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); addConversion([&](VectorType type) { return convertVectorType(type); }); // LLVMType is legal, so add a pass-through conversion. addConversion([](LLVM::LLVMType type) { return type; }); } /// Returns the MLIR context. MLIRContext &LLVMTypeConverter::getContext() { return *getDialect()->getContext(); } /// Get the LLVM context. llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() { return module->getContext(); } LLVM::LLVMType LLVMTypeConverter::getIndexType() { return LLVM::LLVMType::getIntNTy(llvmDialect, getIndexTypeBitwidth()); } Type LLVMTypeConverter::convertIndexType(IndexType type) { return getIndexType(); } Type LLVMTypeConverter::convertIntegerType(IntegerType type) { return LLVM::LLVMType::getIntNTy(llvmDialect, type.getWidth()); } Type LLVMTypeConverter::convertFloatType(FloatType type) { switch (type.getKind()) { case mlir::StandardTypes::F32: return LLVM::LLVMType::getFloatTy(llvmDialect); case mlir::StandardTypes::F64: return LLVM::LLVMType::getDoubleTy(llvmDialect); case mlir::StandardTypes::F16: return LLVM::LLVMType::getHalfTy(llvmDialect); case mlir::StandardTypes::BF16: { auto *mlirContext = llvmDialect->getContext(); return emitError(UnknownLoc::get(mlirContext), "unsupported type: BF16"), Type(); } default: llvm_unreachable("non-float type in convertFloatType"); } } // Except for signatures, MLIR function types are converted into LLVM // pointer-to-function types. Type LLVMTypeConverter::convertFunctionType(FunctionType type) { SignatureConversion conversion(type.getNumInputs()); LLVM::LLVMType converted = convertFunctionSignature(type, /*isVariadic=*/false, conversion); return converted.getPointerTo(); } /// In signatures, MemRef descriptors are expanded into lists of non-aggregate /// values. 
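/// For example (illustrative, assuming the default 64-bit index type): a
/// memref<?x42xf32> argument expands to the seven scalar components
///   (float*, float*, i64, i64, i64, i64, i64)
/// i.e. allocated pointer, aligned pointer, offset, then one size and one
/// stride per dimension, in that order.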
SmallVector LLVMTypeConverter::convertMemRefSignature(MemRefType type) { SmallVector results; assert(isStrided(type) && "Non-strided layout maps must have been normalized away"); LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; auto indexTy = getIndexType(); results.insert(results.begin(), 2, elementType.getPointerTo(type.getMemorySpace())); results.push_back(indexTy); auto rank = type.getRank(); results.insert(results.end(), 2 * rank, indexTy); return results; } /// In signatures, unranked MemRef descriptors are expanded into a pair "rank, /// pointer to descriptor". SmallVector LLVMTypeConverter::convertUnrankedMemRefSignature() { return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(llvmDialect)}; } // Function types are converted to LLVM Function types by recursively converting // argument and result types. If MLIR Function has zero results, the LLVM // Function has one VoidType result. If MLIR Function has more than one result, // they are into an LLVM StructType in their order of appearance. LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature( FunctionType type, bool isVariadic, LLVMTypeConverter::SignatureConversion &result) { // Convert argument types one by one and check for errors. for (auto &en : llvm::enumerate(type.getInputs())) { Type type = en.value(); SmallVector converted; if (failed(customizations.funcArgConverter(*this, type, converted))) return {}; result.addInputs(en.index(), converted); } SmallVector argTypes; argTypes.reserve(llvm::size(result.getConvertedTypes())); for (Type type : result.getConvertedTypes()) argTypes.push_back(unwrap(type)); // If function does not return anything, create the void result type, // if it returns on element, convert it, otherwise pack the result types into // a struct. LLVM::LLVMType resultType = type.getNumResults() == 0 ? LLVM::LLVMType::getVoidTy(llvmDialect) : unwrap(packFunctionResults(type.getResults())); if (!resultType) return {}; return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic); } /// Converts the function type to a C-compatible format, in particular using /// pointers to memref descriptors for arguments. LLVM::LLVMType LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) { SmallVector inputs; for (Type t : type.getInputs()) { auto converted = convertType(t).dyn_cast_or_null(); if (!converted) return {}; if (t.isa() || t.isa()) converted = converted.getPointerTo(); inputs.push_back(converted); } LLVM::LLVMType resultType = type.getNumResults() == 0 ? LLVM::LLVMType::getVoidTy(llvmDialect) : unwrap(packFunctionResults(type.getResults())); if (!resultType) return {}; return LLVM::LLVMType::getFunctionTy(resultType, inputs, false); } /// Creates descriptor structs from individual values constituting them. Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter, Type type, ArrayRef values, Location loc) { if (auto unrankedMemRefType = type.dyn_cast()) return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, unrankedMemRefType, values) .getDefiningOp(); auto memRefType = type.dyn_cast(); assert(memRefType && "1->N conversion is only supported for memrefs"); return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values) .getDefiningOp(); } // Convert a MemRef to an LLVM type. The result is a MemRef descriptor which // contains: // 1. the pointer to the data buffer, followed by // 2. 
a lowered `index`-type integer containing the distance between the // beginning of the buffer and the first element to be accessed through the // view, followed by // 3. an array containing as many `index`-type integers as the rank of the // MemRef: the array represents the size, in number of elements, of the memref // along the given dimension. For constant MemRef dimensions, the // corresponding size entry is a constant whose runtime value must match the // static value, followed by // 4. a second array containing as many `index`-type integers as the rank of // the MemRef: the second array represents the "stride" (in tensor abstraction // sense), i.e. the number of consecutive elements of the underlying buffer. // TODO(ntv, zinenko): add assertions for the static cases. // // template // struct { // Elem *allocatedPtr; // Elem *alignedPtr; // int64_t offset; // int64_t sizes[Rank]; // omitted when rank == 0 // int64_t strides[Rank]; // omitted when rank == 0 // }; static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0; static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1; static constexpr unsigned kOffsetPosInMemRefDescriptor = 2; static constexpr unsigned kSizePosInMemRefDescriptor = 3; static constexpr unsigned kStridePosInMemRefDescriptor = 4; Type LLVMTypeConverter::convertMemRefType(MemRefType type) { int64_t offset; SmallVector strides; bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset)); assert(strideSuccess && "Non-strided layout maps must have been normalized away"); (void)strideSuccess; LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; auto ptrTy = elementType.getPointerTo(type.getMemorySpace()); auto indexTy = getIndexType(); auto rank = type.getRank(); if (rank > 0) { auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank()); return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy, arrayTy, arrayTy); } return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy); } // Converts UnrankedMemRefType to LLVMType. The result is a descriptor which // contains: // 1. int64_t rank, the dynamic rank of this MemRef // 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be // stack allocated (alloca) copy of a MemRef descriptor that got casted to // be unranked. static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0; static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1; Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { auto rankTy = LLVM::LLVMType::getInt64Ty(llvmDialect); auto ptrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect); return LLVM::LLVMType::getStructTy(rankTy, ptrTy); } // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when // n > 1. // For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and // `vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`. 
Type LLVMTypeConverter::convertVectorType(VectorType type) { auto elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; auto vectorType = LLVM::LLVMType::getVectorTy(elementType, type.getShape().back()); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]); return vectorType; } ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter_, PatternBenefit benefit) : ConversionPattern(rootOpName, benefit, context), typeConverter(typeConverter_) {} /*============================================================================*/ /* StructBuilder implementation */ /*============================================================================*/ StructBuilder::StructBuilder(Value v) : value(v) { assert(value != nullptr && "value cannot be null"); structType = value.getType().dyn_cast(); assert(structType && "expected llvm type"); } Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, unsigned pos) { Type type = structType.cast().getStructElementType(pos); return builder.create(loc, type, value, builder.getI64ArrayAttr(pos)); } void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr) { value = builder.create(loc, structType, value, ptr, builder.getI64ArrayAttr(pos)); } /*============================================================================*/ /* MemRefDescriptor implementation */ /*============================================================================*/ /// Construct a helper for the given descriptor value. MemRefDescriptor::MemRefDescriptor(Value descriptor) : StructBuilder(descriptor) { assert(value != nullptr && "value cannot be null"); indexType = value.getType().cast().getStructElementType( kOffsetPosInMemRefDescriptor); } /// Builds IR creating an `undef` value of the descriptor type. MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { Value descriptor = builder.create(loc, descriptorType.cast()); return MemRefDescriptor(descriptor); } /// Builds IR creating a MemRef descriptor that represents `type` and /// populates it with static shape and stride information extracted from the /// type. MemRefDescriptor MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory) { assert(type.hasStaticShape() && "unexpected dynamic shape"); // Extract all strides and offsets and verify they are static. 
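  // For a contiguous row-major type such as memref<4x8xf32> (illustrative
  // example) this yields offset 0 and strides [8, 1]; those constants, along
  // with the static sizes [4, 8], are written into the descriptor fields
  // below.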
int64_t offset; SmallVector strides; auto result = getStridesAndOffset(type, strides, offset); (void)result; assert(succeeded(result) && "unexpected failure in stride computation"); assert(offset != MemRefType::getDynamicStrideOrOffset() && "expected static offset"); assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) && "expected static strides"); auto convertedType = typeConverter.convertType(type); assert(convertedType && "unexpected failure in memref type conversion"); auto descr = MemRefDescriptor::undef(builder, loc, convertedType); descr.setAllocatedPtr(builder, loc, memory); descr.setAlignedPtr(builder, loc, memory); descr.setConstantOffset(builder, loc, offset); // Fill in sizes and strides for (unsigned i = 0, e = type.getRank(); i != e; ++i) { descr.setConstantSize(builder, loc, i, type.getDimSize(i)); descr.setConstantStride(builder, loc, i, strides[i]); } return descr; } /// Builds IR extracting the allocated pointer from the descriptor. Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); } /// Builds IR inserting the allocated pointer into the descriptor. void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr) { setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); } /// Builds IR extracting the aligned pointer from the descriptor. Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); } /// Builds IR inserting the aligned pointer into the descriptor. void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, Value ptr) { setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); } // Creates a constant Op producing a value of `resultType` from an index-typed // integer attribute. static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value) { return builder.create( loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); } /// Builds IR extracting the offset from the descriptor. Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); } /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, Value offset) { value = builder.create( loc, structType, value, offset, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); } /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset) { setOffset(builder, loc, createIndexAttrConstant(builder, loc, indexType, offset)); } /// Builds IR extracting the pos-th size from the descriptor. 
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
  return builder.create(
      loc, indexType, value,
      builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}

/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
                               Value size) {
  value = builder.create(
      loc, structType, value, size,
      builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}

/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
                                       unsigned pos, uint64_t size) {
  setSize(builder, loc, pos,
          createIndexAttrConstant(builder, loc, indexType, size));
}

/// Builds IR extracting the pos-th stride from the descriptor.
Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) {
  return builder.create(
      loc, indexType, value,
      builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}

/// Builds IR inserting the pos-th stride into the descriptor
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos,
                                 Value stride) {
  value = builder.create(
      loc, structType, value, stride,
      builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}

/// Builds IR inserting the pos-th stride into the descriptor
void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
                                         unsigned pos, uint64_t stride) {
  setStride(builder, loc, pos,
            createIndexAttrConstant(builder, loc, indexType, stride));
}

LLVM::LLVMType MemRefDescriptor::getElementType() {
  return value.getType().cast().getStructElementType(
      kAlignedPtrPosInMemRefDescriptor);
}

/// Creates a MemRef descriptor structure from a list of individual values
/// composing that descriptor, in the following order:
/// - allocated pointer;
/// - aligned pointer;
/// - offset;
/// - sizes;
/// - strides;
/// where the number of sizes and strides equals the MemRef rank as provided
/// in `type`.
Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
                             LLVMTypeConverter &converter, MemRefType type,
                             ValueRange values) {
  Type llvmType = converter.convertType(type);
  auto d = MemRefDescriptor::undef(builder, loc, llvmType);

  d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
  d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
  d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);

  int64_t rank = type.getRank();
  for (unsigned i = 0; i < rank; ++i) {
    d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
    d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]);
  }

  return d;
}

/// Builds IR extracting individual elements of a MemRef descriptor structure
/// and returning them as `results` list.
void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
                              MemRefType type,
                              SmallVectorImpl &results) {
  int64_t rank = type.getRank();
  results.reserve(results.size() + getNumUnpackedValues(type));

  MemRefDescriptor d(packed);
  results.push_back(d.allocatedPtr(builder, loc));
  results.push_back(d.alignedPtr(builder, loc));
  results.push_back(d.offset(builder, loc));
  for (int64_t i = 0; i < rank; ++i)
    results.push_back(d.size(builder, loc, i));
  for (int64_t i = 0; i < rank; ++i)
    results.push_back(d.stride(builder, loc, i));
}

/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
  // Two pointers, offset, sizes, strides.
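  // E.g. a rank-0 memref unpacks to 3 values (two pointers plus the offset)
  // and a rank-2 memref to 7 (the 3 scalars plus 2 sizes and 2 strides),
  // matching the order produced by unpack() above.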
return 3 + 2 * type.getRank(); } /*============================================================================*/ /* MemRefDescriptorView implementation. */ /*============================================================================*/ MemRefDescriptorView::MemRefDescriptorView(ValueRange range) : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {} Value MemRefDescriptorView::allocatedPtr() { return elements[kAllocatedPtrPosInMemRefDescriptor]; } Value MemRefDescriptorView::alignedPtr() { return elements[kAlignedPtrPosInMemRefDescriptor]; } Value MemRefDescriptorView::offset() { return elements[kOffsetPosInMemRefDescriptor]; } Value MemRefDescriptorView::size(unsigned pos) { return elements[kSizePosInMemRefDescriptor + pos]; } Value MemRefDescriptorView::stride(unsigned pos) { return elements[kSizePosInMemRefDescriptor + rank + pos]; } /*============================================================================*/ /* UnrankedMemRefDescriptor implementation */ /*============================================================================*/ /// Construct a helper for the given descriptor value. UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) : StructBuilder(descriptor) {} /// Builds IR creating an `undef` value of the descriptor type. UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { Value descriptor = builder.create(loc, descriptorType.cast()); return UnrankedMemRefDescriptor(descriptor); } Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, Value v) { setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); } Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, Location loc, Value v) { setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); } /// Builds IR populating an unranked MemRef descriptor structure from a list /// of individual constituent values in the following order: /// - rank of the memref; /// - pointer to the memref descriptor. Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values) { Type llvmType = converter.convertType(type); auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType); d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]); d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]); return d; } /// Builds IR extracting individual elements that compose an unranked memref /// descriptor and returns them as `results` list. 
void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl &results) { UnrankedMemRefDescriptor d(packed); results.reserve(results.size() + 2); results.push_back(d.rank(builder, loc)); results.push_back(d.memRefDescPtr(builder, loc)); } LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { return *typeConverter.getDialect(); } llvm::LLVMContext &ConvertToLLVMPattern::getContext() const { return typeConverter.getLLVMContext(); } llvm::Module &ConvertToLLVMPattern::getModule() const { return getDialect().getLLVMModule(); } LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const { return typeConverter.getIndexType(); } LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const { return LLVM::LLVMType::getVoidTy(&getDialect()); } LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const { return LLVM::LLVMType::getInt8PtrTy(&getDialect()); } Value ConvertToLLVMPattern::createIndexConstant( ConversionPatternRewriter &builder, Location loc, uint64_t value) const { return createIndexAttrConstant(builder, loc, getIndexType(), value); } Value ConvertToLLVMPattern::linearizeSubscripts( ConversionPatternRewriter &builder, Location loc, ArrayRef indices, ArrayRef allocSizes) const { assert(indices.size() == allocSizes.size() && "mismatching number of indices and allocation sizes"); assert(!indices.empty() && "cannot linearize a 0-dimensional access"); Value linearized = indices.front(); for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) { linearized = builder.create( loc, this->getIndexType(), ArrayRef{linearized, allocSizes[i]}); linearized = builder.create( loc, this->getIndexType(), ArrayRef{linearized, indices[i]}); } return linearized; } Value ConvertToLLVMPattern::getStridedElementPtr( Location loc, Type elementTypePtr, Value descriptor, ArrayRef indices, ArrayRef strides, int64_t offset, ConversionPatternRewriter &rewriter) const { MemRefDescriptor memRefDescriptor(descriptor); Value base = memRefDescriptor.alignedPtr(rewriter, loc); Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset() ? memRefDescriptor.offset(rewriter, loc) : this->createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset() ? memRefDescriptor.stride(rewriter, loc, i) : this->createIndexConstant(rewriter, loc, strides[i]); Value additionalOffset = rewriter.create(loc, indices[i], stride); offsetValue = rewriter.create(loc, offsetValue, additionalOffset); } return rewriter.create(loc, elementTypePtr, base, offsetValue); } Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type, Value memRefDesc, ArrayRef indices, ConversionPatternRewriter &rewriter, llvm::Module &module) const { LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(type, strides, offset); assert(succeeded(successStrides) && "unexpected non-strided memref"); (void)successStrides; return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides, offset, rewriter); } /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. 
static void filterFuncAttributes(ArrayRef attrs, bool filterArgAttrs, SmallVectorImpl &result) { for (const auto &attr : attrs) { if (attr.first == SymbolTable::getSymbolAttrName() || attr.first == impl::getTypeAttrName() || attr.first == "std.varargs" || (filterArgAttrs && impl::isArgAttrName(attr.first.strref()))) continue; result.push_back(attr); } } /// Creates an auxiliary function with pointer-to-memref-descriptor-struct /// arguments instead of unpacked arguments. This function can be called from C /// by passing a pointer to a C struct corresponding to a memref descriptor. /// Internally, the auxiliary function unpacks the descriptor into individual /// components and forwards them to `newFuncOp`. static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, LLVMTypeConverter &typeConverter, FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { auto type = funcOp.getType(); SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes); auto wrapperFuncOp = rewriter.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External, attributes); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock()); SmallVector args; for (auto &en : llvm::enumerate(type.getInputs())) { Value arg = wrapperFuncOp.getArgument(en.index()); if (auto memrefType = en.value().dyn_cast()) { Value loaded = rewriter.create(loc, arg); MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); continue; } if (en.value().isa()) { Value loaded = rewriter.create(loc, arg); UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); continue; } args.push_back(wrapperFuncOp.getArgument(en.index())); } auto call = rewriter.create(loc, newFuncOp, args); rewriter.create(loc, call.getResults()); } /// Creates an auxiliary function with pointer-to-memref-descriptor-struct /// arguments instead of unpacked arguments. Creates a body for the (external) /// `newFuncOp` that allocates a memref descriptor on stack, packs the /// individual arguments into this descriptor and passes a pointer to it into /// the auxiliary function. This auxiliary external function is now compatible /// with functions defined in C using pointers to C structs corresponding to a /// memref descriptor. static void wrapExternalFunction(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { OpBuilder::InsertionGuard guard(builder); LLVM::LLVMType wrapperType = typeConverter.convertFunctionTypeCWrapper(funcOp.getType()); // This conversion can only fail if it could not convert one of the argument // types. But since it has been applies to a non-wrapper function before, it // should have failed earlier and not reach this point at all. assert(wrapperType && "unexpected type conversion failure"); SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes); // Create the auxiliary function. auto wrapperFunc = builder.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperType, LLVM::Linkage::External, attributes); builder.setInsertionPointToStart(newFuncOp.addEntryBlock()); // Get a ValueRange containing arguments. 
FunctionType type = funcOp.getType(); SmallVector args; args.reserve(type.getNumInputs()); ValueRange wrapperArgsRange(newFuncOp.getArguments()); // Iterate over the inputs of the original function and pack values into // memref descriptors if the original type is a memref. for (auto &en : llvm::enumerate(type.getInputs())) { Value arg; int numToDrop = 1; auto memRefType = en.value().dyn_cast(); auto unrankedMemRefType = en.value().dyn_cast(); if (memRefType || unrankedMemRefType) { numToDrop = memRefType ? MemRefDescriptor::getNumUnpackedValues(memRefType) : UnrankedMemRefDescriptor::getNumUnpackedValues(); Value packed = memRefType ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType, wrapperArgsRange.take_front(numToDrop)) : UnrankedMemRefDescriptor::pack( builder, loc, typeConverter, unrankedMemRefType, wrapperArgsRange.take_front(numToDrop)); auto ptrTy = packed.getType().cast().getPointerTo(); Value one = builder.create( loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); Value allocated = builder.create(loc, ptrTy, one, /*alignment=*/0); builder.create(loc, packed, allocated); arg = allocated; } else { arg = wrapperArgsRange[0]; } args.push_back(arg); wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop); } assert(wrapperArgsRange.empty() && "did not map some of the arguments"); auto call = builder.create(loc, wrapperFunc, args); builder.create(loc, call.getResults()); } namespace { struct FuncOpConversionBase : public ConvertOpToLLVMPattern { protected: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using UnsignedTypePair = std::pair; // Gather the positions and types of memref-typed arguments in a given // FunctionType. void getMemRefArgIndicesAndTypes( FunctionType type, SmallVectorImpl &argsInfo) const { argsInfo.reserve(type.getNumInputs()); for (auto en : llvm::enumerate(type.getInputs())) { if (en.value().isa() || en.value().isa()) argsInfo.push_back({en.index(), en.value()}); } } // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided // to this legalization pattern. LLVM::LLVMFuncOp convertFuncOpToLLVMFuncOp(FuncOp funcOp, ConversionPatternRewriter &rewriter) const { // Convert the original function arguments. They are converted using the // LLVMTypeConverter provided to this legalization pattern. auto varargsAttr = funcOp.getAttrOfType("std.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); auto llvmType = typeConverter.convertFunctionSignature( funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); // Propagate argument attributes to all converted arguments obtained after // converting a given original argument. SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/true, attributes); for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { auto attr = impl::getArgAttrDict(funcOp, i); if (!attr) continue; auto mapping = result.getInputMapping(i); assert(mapping.hasValue() && "unexpected deletion of function argument"); SmallString<8> name; for (size_t j = 0; j < mapping->size; ++j) { impl::getArgAttrName(mapping->inputNo + j, name); attributes.push_back(rewriter.getNamedAttr(name, attr)); } } // Create an LLVM function, use external linkage by default until MLIR // functions have linkage. 
auto newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); // Tell the rewriter to convert the region signature. rewriter.applySignatureConversion(&newFuncOp.getBody(), result); return newFuncOp; } }; /// FuncOp legalization pattern that converts MemRef arguments to pointers to /// MemRef descriptors (LLVM struct data types) containing all the MemRef type /// information. static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; struct FuncOpConversion : public FuncOpConversionBase { FuncOpConversion(LLVMTypeConverter &converter, bool emitCWrappers) : FuncOpConversionBase(converter), emitWrappers(emitCWrappers) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (emitWrappers || funcOp.getAttrOfType(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp, newFuncOp); else wrapForExternalCallers(rewriter, op->getLoc(), typeConverter, funcOp, newFuncOp); } rewriter.eraseOp(op); return success(); } private: /// If true, also create the adaptor functions having signatures compatible /// with those produced by clang. const bool emitWrappers; }; /// FuncOp legalization pattern that converts MemRef arguments to bare pointers /// to the MemRef element type. This will impact the calling convention and ABI. struct BarePtrFuncOpConversion : public FuncOpConversionBase { using FuncOpConversionBase::FuncOpConversionBase; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); // Store the positions and type of memref-typed arguments so that we can // promote them to MemRef descriptor structs at the beginning of the // function. SmallVector promotedArgsInfo; getMemRefArgIndicesAndTypes(funcOp.getType(), promotedArgsInfo); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (newFuncOp.getBody().empty()) { rewriter.eraseOp(op); return success(); } // Promote bare pointers from MemRef arguments to a MemRef descriptor struct // at the beginning of the function so that all the MemRefs in the function // have a uniform representation. Block *firstBlock = &newFuncOp.getBody().front(); rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); auto funcLoc = funcOp.getLoc(); for (const auto &argInfo : promotedArgsInfo) { // TODO: Add support for unranked MemRefs. if (auto memrefType = argInfo.second.dyn_cast()) { // Replace argument with a placeholder (undef), promote argument to a // MemRef descriptor and replace placeholder with the last instruction // of the MemRef descriptor. The placeholder is needed to avoid // replacing argument uses in the MemRef descriptor instructions. BlockArgument arg = firstBlock->getArgument(argInfo.first); Value placeHolder = rewriter.create(funcLoc, arg.getType()); rewriter.replaceUsesOfBlockArgument(arg, placeHolder); auto desc = MemRefDescriptor::fromStaticShape( rewriter, funcLoc, typeConverter, memrefType, arg); rewriter.replaceOp(placeHolder.getDefiningOp(), {desc}); } } rewriter.eraseOp(op); return success(); } }; //////////////// Support for Lowering operations on n-D vectors //////////////// // Helper struct to "unroll" operations on n-D vectors in terms of operations on // 1-D LLVM vectors. 
struct NDVectorTypeInfo {
  // LLVM array struct which encodes n-D vectors.
  LLVM::LLVMType llvmArrayTy;
  // LLVM vector type which encodes the inner 1-D vector type.
  LLVM::LLVMType llvmVectorTy;
  // Multiplicity of llvmArrayTy to llvmVectorTy.
  SmallVector arraySizes;
};
} // namespace

// For >1-D vector types, extracts the necessary information to iterate over
// all 1-D subvectors in the underlying LLVM representation of the n-D vector.
// Iterates over the LLVM array type until we hit a non-array type (which is
// asserted to be an LLVM vector type).
static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
                                                LLVMTypeConverter &converter) {
  assert(vectorType.getRank() > 1 && "expected >1D vector type");
  NDVectorTypeInfo info;
  info.llvmArrayTy = converter.convertType(vectorType).dyn_cast();
  if (!info.llvmArrayTy)
    return info;
  info.arraySizes.reserve(vectorType.getRank() - 1);
  auto llvmTy = info.llvmArrayTy;
  while (llvmTy.isArrayTy()) {
    info.arraySizes.push_back(llvmTy.getArrayNumElements());
    llvmTy = llvmTy.getArrayElementType();
  }
  if (!llvmTy.isVectorTy())
    return info;
  info.llvmVectorTy = llvmTy;
  return info;
}

// Express `linearIndex` in terms of coordinates of `basis`.
// Returns the empty vector when linearIndex is out of the range [0, P) where
// P is the product of all the basis coordinates.
//
// Prerequisites:
// Basis is an array of nonnegative integers (signed type inherited from
// vector shape type).
static SmallVector getCoordinates(ArrayRef basis, unsigned linearIndex) {
  SmallVector res;
  res.reserve(basis.size());
  for (unsigned basisElement : llvm::reverse(basis)) {
    res.push_back(linearIndex % basisElement);
    linearIndex = linearIndex / basisElement;
  }
  if (linearIndex > 0)
    return {};
  std::reverse(res.begin(), res.end());
  return res;
}

// Iterate over the linear index, convert it to coordinate space, and invoke
// `fun` at each position (e.g. to insert a splatted 1-D vector at that
// position).
template
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
                     Lambda fun) {
  unsigned ub = 1;
  for (auto s : info.arraySizes)
    ub *= s;
  for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
    auto coords = getCoordinates(info.arraySizes, linearIndex);
    // Linear index is out of bounds, we are done.
    if (coords.empty())
      break;
    assert(coords.size() == info.arraySizes.size());
    auto position = builder.getI64ArrayAttr(coords);
    fun(position);
  }
}
////////////// End Support for Lowering operations on n-D vectors //////////////

/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult LLVM::detail::oneToOneRewrite(
    Operation *op, StringRef targetOp, ValueRange operands,
    LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
  unsigned numResults = op->getNumResults();

  Type packedType;
  if (numResults != 0) {
    packedType = typeConverter.packFunctionResults(op->getResultTypes());
    if (!packedType)
      return failure();
  }

  // Create the operation through state since we don't know its C++ type.
  OperationState state(op->getLoc(), targetOp);
  state.addTypes(packedType);
  state.addOperands(operands);
  state.addAttributes(op->getAttrs());
  Operation *newOp = rewriter.createOperation(state);

  // If the operation produced 0 or 1 result, return them immediately.
  if (numResults == 0)
    return rewriter.eraseOp(op), success();
  if (numResults == 1)
    return rewriter.replaceOp(op, newOp->getResult(0)), success();

  // Otherwise, it had been converted to an operation producing a structure.
  // Extract individual results from the structure and return them as list.
SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = typeConverter.convertType(op->getResult(i).getType()); results.push_back(rewriter.create( op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); } rewriter.replaceOp(op, results); return success(); } static LogicalResult handleMultidimensionalVectors( Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function createOperand, ConversionPatternRewriter &rewriter) { auto vectorType = op->getResult(0).getType().dyn_cast(); if (!vectorType) return failure(); auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter); auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; auto llvmArrayTy = operands[0].getType().cast(); if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy) return failure(); auto loc = op->getLoc(); Value desc = rewriter.create(loc, llvmArrayTy); nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors SmallVector extractedOperands; for (auto operand : operands) extractedOperands.push_back(rewriter.create( loc, llvmVectorTy, operand, position)); Value newVal = createOperand(llvmVectorTy, extractedOperands); desc = rewriter.create(loc, llvmArrayTy, desc, newVal, position); }); rewriter.replaceOp(op, desc); return success(); } LogicalResult LLVM::detail::vectorOneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. if (!llvm::all_of(operands.getTypes(), [](Type t) { return t.isa(); })) return failure(); auto llvmArrayTy = operands[0].getType().cast(); if (!llvmArrayTy.isArrayTy()) return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy, ValueRange operands) { OperationState state(op->getLoc(), targetOp); state.addTypes(llvmVectorTy); state.addOperands(operands); state.addAttributes(op->getAttrs()); return rewriter.createOperation(state)->getResult(0); }; return handleMultidimensionalVectors(op, operands, typeConverter, callback, rewriter); } namespace { // Straightforward lowerings. 
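// Each alias below instantiates a generic one-to-one pattern: the standard op
// is replaced by the corresponding LLVM dialect op on the converted operand
// types (e.g. std.addf becomes llvm.fadd). The "VectorConvert" variants
// additionally unroll n-D vector operands into 1-D pieces through
// handleMultidimensionalVectors above. (Descriptive comment; the exact op
// pairings are given by the template arguments of each alias.)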
using AbsFOpLowering = VectorConvertToLLVMPattern; using AddFOpLowering = VectorConvertToLLVMPattern; using AddIOpLowering = VectorConvertToLLVMPattern; using AndOpLowering = VectorConvertToLLVMPattern; using CeilFOpLowering = VectorConvertToLLVMPattern; using ConstLLVMOpLowering = OneToOneConvertToLLVMPattern; using CopySignOpLowering = VectorConvertToLLVMPattern; using CosOpLowering = VectorConvertToLLVMPattern; using DivFOpLowering = VectorConvertToLLVMPattern; using ExpOpLowering = VectorConvertToLLVMPattern; using Exp2OpLowering = VectorConvertToLLVMPattern; using Log10OpLowering = VectorConvertToLLVMPattern; using Log2OpLowering = VectorConvertToLLVMPattern; using LogOpLowering = VectorConvertToLLVMPattern; using MulFOpLowering = VectorConvertToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; using NegFOpLowering = VectorConvertToLLVMPattern; using OrOpLowering = VectorConvertToLLVMPattern; using RemFOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = OneToOneConvertToLLVMPattern; using ShiftLeftOpLowering = OneToOneConvertToLLVMPattern; using SignedDivIOpLowering = VectorConvertToLLVMPattern; using SignedRemIOpLowering = VectorConvertToLLVMPattern; using SignedShiftRightOpLowering = OneToOneConvertToLLVMPattern; using SqrtOpLowering = VectorConvertToLLVMPattern; using SubFOpLowering = VectorConvertToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; using UnsignedDivIOpLowering = VectorConvertToLLVMPattern; using UnsignedRemIOpLowering = VectorConvertToLLVMPattern; using UnsignedShiftRightOpLowering = OneToOneConvertToLLVMPattern; using XOrOpLowering = VectorConvertToLLVMPattern; // Check if the MemRefType `type` is supported by the lowering. We currently // only support memrefs with identity maps. static bool isSupportedMemRefType(MemRefType type) { return type.getAffineMaps().empty() || llvm::all_of(type.getAffineMaps(), [](AffineMap map) { return map.isIdentity(); }); } /// Lowering for AllocOp and AllocaOp. template struct AllocLikeOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::createIndexConstant; using ConvertOpToLLVMPattern::getIndexType; using ConvertOpToLLVMPattern::typeConverter; using ConvertOpToLLVMPattern::getVoidPtrType; explicit AllocLikeOpLowering(LLVMTypeConverter &converter, bool useAlignedAlloc = false) : ConvertOpToLLVMPattern(converter), useAlignedAlloc(useAlignedAlloc) {} LogicalResult match(Operation *op) const override { MemRefType memRefType = cast(op).getType(); if (isSupportedMemRefType(memRefType)) return success(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(memRefType, strides, offset); if (failed(successStrides)) return failure(); // Dynamic strides are ok if they can be deduced from dynamic sizes (which // is guaranteed when succeeded(successStrides)). Dynamic offset however can // never be alloc'ed. if (offset == MemRefType::getDynamicStrideOrOffset()) return failure(); return success(); } // Returns bump = (alignment - (input % alignment))% alignment, which is the // increment necessary to align `input` to `alignment` boundary. // TODO: this can be made more efficient by just using a single addition // and two bit shifts: (ptr + align - 1)/align, align is always power of 2. 
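  // Worked example (comment only): for input = 42 and alignment = 16,
  //   bump = (16 - (42 % 16)) % 16 = (16 - 10) % 16 = 6,   and 42 + 6 = 48,
  // i.e. 48 is the next 16-byte aligned value. Since the alignment is always
  // a power of two here, the branch-free form mentioned in the TODO would be
  //   aligned = (input + alignment - 1) & ~(alignment - 1).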
Value createBumpToAlign(Location loc, OpBuilder b, Value input, Value alignment) const { Value modAlign = b.create(loc, input, alignment); Value diff = b.create(loc, alignment, modAlign); Value shift = b.create(loc, diff, alignment); return shift; } /// Creates and populates the memref descriptor struct given all its fields. /// This method also performs any post allocation alignment needed for heap /// allocations when `accessAlignment` is non null. This is used with /// allocators that do not support alignment. MemRefDescriptor createMemRefDescriptor( Location loc, ConversionPatternRewriter &rewriter, MemRefType memRefType, Value allocatedTypePtr, Value allocatedBytePtr, Value accessAlignment, uint64_t offset, ArrayRef strides, ArrayRef sizes) const { auto elementPtrType = getElementPtrType(memRefType); auto structType = typeConverter.convertType(memRefType); auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); // Field 1: Allocated pointer, used for malloc/free. memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedTypePtr); // Field 2: Actual aligned pointer to payload. Value alignedBytePtr = allocatedTypePtr; if (accessAlignment) { // offset = (align - (ptr % align))% align Value intVal = rewriter.create( loc, this->getIndexType(), allocatedBytePtr); Value offset = createBumpToAlign(loc, rewriter, intVal, accessAlignment); Value aligned = rewriter.create( loc, allocatedBytePtr.getType(), allocatedBytePtr, offset); alignedBytePtr = rewriter.create( loc, elementPtrType, ArrayRef(aligned)); } memRefDescriptor.setAlignedPtr(rewriter, loc, alignedBytePtr); // Field 3: Offset in aligned pointer. memRefDescriptor.setOffset(rewriter, loc, createIndexConstant(rewriter, loc, offset)); if (memRefType.getRank() == 0) // No size/stride descriptor in memref, return the descriptor value. return memRefDescriptor; // Fields 4 and 5: sizes and strides of the strided MemRef. // Store all sizes in the descriptor. Only dynamic sizes are passed in as // operands to AllocOp. Value runningStride = nullptr; // Iterate strides in reverse order, compute runningStride and strideValues. auto nStrides = strides.size(); SmallVector strideValues(nStrides, nullptr); for (unsigned i = 0; i < nStrides; ++i) { int64_t index = nStrides - 1 - i; if (strides[index] == MemRefType::getDynamicStrideOrOffset()) // Identity layout map is enforced in the match function, so we compute: // `runningStride *= sizes[index + 1]` runningStride = runningStride ? rewriter.create(loc, runningStride, sizes[index + 1]) : createIndexConstant(rewriter, loc, 1); else runningStride = createIndexConstant(rewriter, loc, strides[index]); strideValues[index] = runningStride; } // Fill size and stride descriptors in memref. for (auto indexedSize : llvm::enumerate(sizes)) { int64_t index = indexedSize.index(); memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value()); memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]); } return memRefDescriptor; } /// Determines sizes to be used in the memref descriptor. void getSizes(Location loc, MemRefType memRefType, ArrayRef operands, ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes, Value &cumulativeSize, Value &one) const { sizes.reserve(memRefType.getRank()); unsigned i = 0; for (int64_t s : memRefType.getShape()) sizes.push_back(s == -1 ? operands[i++] : createIndexConstant(rewriter, loc, s)); if (sizes.empty()) sizes.push_back(createIndexConstant(rewriter, loc, 1)); // Compute the total number of memref elements. 
cumulativeSize = sizes.front(); for (unsigned i = 1, e = sizes.size(); i < e; ++i) cumulativeSize = rewriter.create( loc, getIndexType(), ArrayRef{cumulativeSize, sizes[i]}); // Compute the size of an individual element. This emits the MLIR equivalent // of the following sizeof(...) implementation in LLVM IR: // %0 = getelementptr %elementType* null, %indexType 1 // %1 = ptrtoint %elementType* %0 to %indexType // which is a common pattern of getting the size of a type in bytes. auto elementType = memRefType.getElementType(); auto convertedPtrType = typeConverter.convertType(elementType) .template cast() .getPointerTo(); auto nullPtr = rewriter.create(loc, convertedPtrType); one = createIndexConstant(rewriter, loc, 1); auto gep = rewriter.create(loc, convertedPtrType, ArrayRef{nullPtr, one}); auto elementSize = rewriter.create(loc, getIndexType(), gep); cumulativeSize = rewriter.create( loc, getIndexType(), ArrayRef{cumulativeSize, elementSize}); } /// Returns the type of a pointer to an element of the memref. Type getElementPtrType(MemRefType memRefType) const { auto elementType = memRefType.getElementType(); auto structElementType = typeConverter.convertType(elementType); return structElementType.template cast().getPointerTo( memRefType.getMemorySpace()); } /// Returns the memref's element size in bytes. // TODO: there are other places where this is used. Expose publicly? static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { auto elementType = memRefType.getElementType(); unsigned sizeInBits; if (elementType.isIntOrFloat()) { sizeInBits = elementType.getIntOrFloatBitWidth(); } else { auto vectorType = elementType.cast(); sizeInBits = vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); } return llvm::divideCeil(sizeInBits, 8); } /// Returns the alignment to be used for the allocation call itself. /// aligned_alloc requires the allocation size to be a power of two, and the /// allocation size to be a multiple of alignment, Optional getAllocationAlignment(AllocOp allocOp) const { // No alignment can be used for the 'malloc' call itself. if (!useAlignedAlloc) return None; if (allocOp.alignment()) return allocOp.alignment().getValue().getSExtValue(); // Whenever we don't have alignment set, we will use an alignment // consistent with the element type; since the allocation size has to be a // power of two, we will bump to the next power of two if it already isn't. auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType()); return std::max(kMinAlignedAllocAlignment, llvm::PowerOf2Ceil(eltSizeBytes)); } /// Returns true if the memref size in bytes is known to be a multiple of /// factor. static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) { uint64_t sizeDivisor = getMemRefEltSizeInBytes(type); for (unsigned i = 0, e = type.getRank(); i < e; i++) { if (type.isDynamic(type.getDimSize(i))) continue; sizeDivisor = sizeDivisor * type.getDimSize(i); } return sizeDivisor % factor == 0; } /// Allocates the underlying buffer using the right call. `allocatedBytePtr` /// is set to null for stack allocations. `accessAlignment` is set if /// alignment is neeeded post allocation (for eg. in conjunction with malloc). Value allocateBuffer(Location loc, Value cumulativeSize, Operation *op, MemRefType memRefType, Value one, Value &accessAlignment, Value &allocatedBytePtr, ConversionPatternRewriter &rewriter) const { auto elementPtrType = getElementPtrType(memRefType); // With alloca, one gets a pointer to the element type right away. // For stack allocations. 
if (auto allocaOp = dyn_cast(op)) { allocatedBytePtr = nullptr; accessAlignment = nullptr; return rewriter.create( loc, elementPtrType, cumulativeSize, allocaOp.alignment() ? allocaOp.alignment().getValue().getSExtValue() : 0); } // Heap allocations. AllocOp allocOp = cast(op); Optional allocationAlignment = getAllocationAlignment(allocOp); // Whether to use std lib function aligned_alloc that supports alignment. bool useAlignedAlloc = allocationAlignment.hasValue(); // Insert the malloc/aligned_alloc declaration if it is not already present. auto allocFuncName = useAlignedAlloc ? "aligned_alloc" : "malloc"; auto module = allocOp.getParentOfType(); auto allocFunc = module.lookupSymbol(allocFuncName); if (!allocFunc) { OpBuilder moduleBuilder(op->getParentOfType().getBodyRegion()); SmallVector callArgTypes = {getIndexType()}; // aligned_alloc(size_t alignment, size_t size) if (useAlignedAlloc) callArgTypes.push_back(getIndexType()); allocFunc = moduleBuilder.create( rewriter.getUnknownLoc(), allocFuncName, LLVM::LLVMType::getFunctionTy(getVoidPtrType(), callArgTypes, /*isVarArg=*/false)); } // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. SmallVector callArgs; if (useAlignedAlloc) { // Use aligned_alloc. assert(allocationAlignment && "allocation alignment should be present"); auto alignedAllocAlignmentValue = rewriter.create( loc, typeConverter.convertType(rewriter.getIntegerType(64)), rewriter.getI64IntegerAttr(allocationAlignment.getValue())); // aligned_alloc requires size to be a multiple of alignment; we will pad // the size to the next multiple if necessary. if (!isMemRefSizeMultipleOf(memRefType, allocationAlignment.getValue())) { Value bump = createBumpToAlign(loc, rewriter, cumulativeSize, alignedAllocAlignmentValue); cumulativeSize = rewriter.create(loc, cumulativeSize, bump); } callArgs = {alignedAllocAlignmentValue, cumulativeSize}; } else { // Adjust the allocation size to consider alignment. if (allocOp.alignment()) { accessAlignment = createIndexConstant( rewriter, loc, allocOp.alignment().getValue().getSExtValue()); cumulativeSize = rewriter.create( loc, rewriter.create(loc, cumulativeSize, accessAlignment), one); } callArgs.push_back(cumulativeSize); } auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFunc); allocatedBytePtr = rewriter .create(loc, getVoidPtrType(), allocFuncSymbol, callArgs) .getResult(0); // For heap allocations, the allocated pointer is a cast of the byte pointer // to the type pointer. return rewriter.create(loc, elementPtrType, allocatedBytePtr); } // An `alloc` is converted into a definition of a memref descriptor value and // a call to `malloc` to allocate the underlying data buffer. The memref // descriptor is of the LLVM structure type where: // 1. the first element is a pointer to the allocated (typed) data buffer, // 2. the second element is a pointer to the (typed) payload, aligned to the // specified alignment, // 3. the remaining elements serve to store all the sizes and strides of the // memref using LLVM-converted `index` type. // // Alignment is performed by allocating `alignment - 1` more bytes than // requested and shifting the aligned pointer relative to the allocated // memory. If alignment is unspecified, the two pointers are equal. // An `alloca` is converted into a definition of a memref descriptor value and // an llvm.alloca to allocate the underlying data buffer. 
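  // Sketch of the resulting descriptor, written as a C-style struct for
  // illustration only (the actual LLVM struct type is produced by the type
  // converter; the integer fields use the converted `index` type, i64 by
  // default):
  //
  //   struct MemRefDescriptor {          // for element type T and rank N > 0
  //     T *allocated;                    // field 1: pointer from the allocator
  //     T *aligned;                      // field 2: aligned payload pointer
  //     int64_t offset;                  // field 3: offset into `aligned`
  //     int64_t sizes[N];                // field 4: size per dimension
  //     int64_t strides[N];              // field 5: stride per dimension
  //   };                                 // rank-0 memrefs omit sizes/strides
  //
  // With malloc plus an alignment request, `allocated` and `aligned` may
  // differ by up to `alignment - 1` bytes; otherwise they are equal.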
void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { MemRefType memRefType = cast(op).getType(); auto loc = op->getLoc(); // Get actual sizes of the memref as values: static sizes are constant // values and dynamic sizes are passed to 'alloc' as operands. In case of // zero-dimensional memref, assume a scalar (size 1). SmallVector sizes; Value cumulativeSize, one; getSizes(loc, memRefType, operands, rewriter, sizes, cumulativeSize, one); // Allocate the underlying buffer. // Value holding the alignment that has to be performed post allocation // (in conjunction with allocators that do not support alignment, eg. // malloc); nullptr if no such adjustment needs to be performed. Value accessAlignment; // Byte pointer to the allocated buffer. Value allocatedBytePtr; Value allocatedTypePtr = allocateBuffer(loc, cumulativeSize, op, memRefType, one, accessAlignment, allocatedBytePtr, rewriter); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(memRefType, strides, offset); (void)successStrides; assert(succeeded(successStrides) && "unexpected non-strided memref"); assert(offset != MemRefType::getDynamicStrideOrOffset() && "unexpected dynamic offset"); // 0-D memref corner case: they have size 1. assert( ((memRefType.getRank() == 0 && strides.empty() && sizes.size() == 1) || (strides.size() == sizes.size())) && "unexpected number of strides"); // Create the MemRef descriptor. auto memRefDescriptor = createMemRefDescriptor( loc, rewriter, memRefType, allocatedTypePtr, allocatedBytePtr, accessAlignment, offset, strides, sizes); // Return the final value of the descriptor. rewriter.replaceOp(op, {memRefDescriptor}); } protected: /// Use aligned_alloc instead of malloc for all heap allocations. bool useAlignedAlloc; /// The minimum alignment to use with aligned_alloc (has to be a power of 2). uint64_t kMinAlignedAllocAlignment = 16UL; }; struct AllocOpLowering : public AllocLikeOpLowering { explicit AllocOpLowering(LLVMTypeConverter &converter, bool useAlignedAlloc = false) : AllocLikeOpLowering(converter, useAlignedAlloc) {} }; using AllocaOpLowering = AllocLikeOpLowering; // A CallOp automatically promotes MemRefType to a sequence of alloca/store and // passes the pointer to the MemRef across function boundaries. template struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = CallOpInterfaceLowering; using Base = ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { OperandAdaptor transformed(operands); auto callOp = cast(op); // Pack the result types into a struct. Type packedResult; unsigned numResults = callOp.getNumResults(); auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); for (Type resType : resultTypes) { assert(!resType.isa() && "Returning unranked memref is not supported. Pass result as an" "argument instead."); (void)resType; } if (numResults != 0) { if (!(packedResult = this->typeConverter.packFunctionResults(resultTypes))) return failure(); } auto promoted = this->typeConverter.promoteMemRefDescriptors( op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter); auto newOp = rewriter.create(op->getLoc(), packedResult, promoted, op->getAttrs()); // If < 2 results, packing did not do anything and we can just return. 
if (numResults < 2) { rewriter.replaceOp(op, newOp.getResults()); return success(); } // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. // TODO(aminim, ntv, riverriddle, zinenko): this seems like patching around // a particular interaction between MemRefType and CallOp lowering. Find a // way to avoid special casing. SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = this->typeConverter.convertType(op->getResult(i).getType()); results.push_back(rewriter.create( op->getLoc(), type, newOp.getOperation()->getResult(0), rewriter.getI64ArrayAttr(i))); } rewriter.replaceOp(op, results); return success(); } }; struct CallOpLowering : public CallOpInterfaceLowering { using Super::Super; }; struct CallIndirectOpLowering : public CallOpInterfaceLowering { using Super::Super; }; // A `dealloc` is converted into a call to `free` on the underlying data buffer. // The memref descriptor being an SSA value, there is no need to clean it up // in any way. struct DeallocOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; explicit DeallocOpLowering(LLVMTypeConverter &converter) : ConvertOpToLLVMPattern(converter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 1 && "dealloc takes one operand"); OperandAdaptor transformed(operands); // Insert the `free` declaration if it is not already present. auto freeFunc = op->getParentOfType().lookupSymbol("free"); if (!freeFunc) { OpBuilder moduleBuilder(op->getParentOfType().getBodyRegion()); freeFunc = moduleBuilder.create( rewriter.getUnknownLoc(), "free", LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(), /*isVarArg=*/false)); } MemRefDescriptor memref(transformed.memref()); Value casted = rewriter.create( op->getLoc(), getVoidPtrType(), memref.allocatedPtr(rewriter, op->getLoc())); rewriter.replaceOpWithNewOp( op, ArrayRef(), rewriter.getSymbolRefAttr(freeFunc), casted); return success(); } }; // A `rsqrt` is converted into `1 / sqrt`. 
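// For a scalar f32 operand this produces, roughly (illustrative; exact
// printed forms may differ):
//   %one  = llvm.mlir.constant(1.0 : f32)
//   %sqrt = "llvm.intr.sqrt"(%x)
//   %r    = llvm.fdiv %one, %sqrt
// For vector operands the constant is a splat, and n-D vectors are unrolled
// into 1-D pieces via handleMultidimensionalVectors.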
struct RsqrtOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { OperandAdaptor transformed(operands); auto operandType = transformed.operand().getType().dyn_cast(); if (!operandType) return failure(); auto loc = op->getLoc(); auto resultType = *op->result_type_begin(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); if (!operandType.isArrayTy()) { LLVM::ConstantOp one; if (operandType.isVectorTy()) { one = rewriter.create( loc, operandType, SplatElementsAttr::get(resultType.cast(), floatOne)); } else { one = rewriter.create(loc, operandType, floatOne); } auto sqrt = rewriter.create(loc, transformed.operand()); rewriter.replaceOpWithNewOp(op, operandType, one, sqrt); return success(); } auto vectorType = resultType.dyn_cast(); if (!vectorType) return failure(); return handleMultidimensionalVectors( op, operands, typeConverter, [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( - mlir::VectorType::get({(unsigned)cast( - llvmVectorTy.getUnderlyingType()) - ->getNumElements()}, - floatType), + mlir::VectorType::get( + {cast(llvmVectorTy.getUnderlyingType()) + ->getNumElements()}, + floatType), floatOne); auto one = rewriter.create(loc, llvmVectorTy, splatAttr); auto sqrt = rewriter.create(loc, llvmVectorTy, operands[0]); return rewriter.create(loc, llvmVectorTy, one, sqrt); }, rewriter); } }; struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult match(Operation *op) const override { auto memRefCastOp = cast(op); Type srcType = memRefCastOp.getOperand().getType(); Type dstType = memRefCastOp.getType(); if (srcType.isa() && dstType.isa()) { MemRefType sourceType = memRefCastOp.getOperand().getType().cast(); MemRefType targetType = memRefCastOp.getType().cast(); return (isSupportedMemRefType(targetType) && isSupportedMemRefType(sourceType)) ? success() : failure(); } // At least one of the operands is unranked type assert(srcType.isa() || dstType.isa()); // Unranked to unranked cast is disallowed return !(srcType.isa() && dstType.isa()) ? success() : failure(); } void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto memRefCastOp = cast(op); OperandAdaptor transformed(operands); auto srcType = memRefCastOp.getOperand().getType(); auto dstType = memRefCastOp.getType(); auto targetStructType = typeConverter.convertType(memRefCastOp.getType()); auto loc = op->getLoc(); if (srcType.isa() && dstType.isa()) { // memref_cast is defined for source and destination memref types with the // same element type, same mappings, same address space and same rank. // Therefore a simple bitcast suffices. If not it is undefined behavior. 
rewriter.replaceOpWithNewOp(op, targetStructType, transformed.source()); } else if (srcType.isa() && dstType.isa()) { // Casting ranked to unranked memref type // Set the rank in the destination from the memref type // Allocate space on the stack and copy the src memref descriptor // Set the ptr in the destination to the stack space auto srcMemRefType = srcType.cast(); int64_t rank = srcMemRefType.getRank(); // ptr = AllocaOp sizeof(MemRefDescriptor) auto ptr = typeConverter.promoteOneMemRefDescriptor( loc, transformed.source(), rewriter); // voidptr = BitCastOp srcType* to void* auto voidPtr = rewriter.create(loc, getVoidPtrType(), ptr) .getResult(); // rank = ConstantOp srcRank auto rankVal = rewriter.create( loc, typeConverter.convertType(rewriter.getIntegerType(64)), rewriter.getI64IntegerAttr(rank)); // undef = UndefOp UnrankedMemRefDescriptor memRefDesc = UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType); // d1 = InsertValueOp undef, rank, 0 memRefDesc.setRank(rewriter, loc, rankVal); // d2 = InsertValueOp d1, voidptr, 1 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); rewriter.replaceOp(op, (Value)memRefDesc); } else if (srcType.isa() && dstType.isa()) { // Casting from unranked type to ranked. // The operation is assumed to be doing a correct cast. If the destination // type mismatches the unranked the type, it is undefined behavior. UnrankedMemRefDescriptor memRefDesc(transformed.source()); // ptr = ExtractValueOp src, 1 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); // castPtr = BitCastOp i8* to structTy* auto castPtr = rewriter .create( loc, targetStructType.cast().getPointerTo(), ptr) .getResult(); // struct = LoadOp castPtr auto loadOp = rewriter.create(loc, castPtr); rewriter.replaceOp(op, loadOp.getResult()); } else { llvm_unreachable("Unsupported unranked memref to unranked memref cast"); } } }; struct DialectCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto castOp = cast(op); OperandAdaptor transformed(operands); if (transformed.in().getType() != typeConverter.convertType(castOp.getType())) { return failure(); } rewriter.replaceOp(op, transformed.in()); return success(); } }; // A `dim` is converted to a constant for static sizes and to an access to the // size stored in the memref descriptor for dynamic sizes. struct DimOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dimOp = cast(op); OperandAdaptor transformed(operands); MemRefType type = dimOp.getOperand().getType().cast(); int64_t index = dimOp.getIndex(); // Extract dynamic size from the memref descriptor. if (type.isDynamicDim(index)) rewriter.replaceOp(op, {MemRefDescriptor(transformed.memrefOrTensor()) .size(rewriter, op->getLoc(), index)}); else // Use constant for static size. rewriter.replaceOp(op, createIndexConstant(rewriter, op->getLoc(), type.getDimSize(index))); return success(); } }; // Common base for load and store operations on MemRefs. Restricts the match // to supported MemRef types. Provides functionality to emit code accessing a // specific element of the underlying data buffer. 
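// The element address emitted for such accesses follows the usual strided
// formula over the descriptor fields. The helper below is a host-side
// illustration of that computation only (names are hypothetical; the lowering
// itself emits the equivalent LLVM dialect arithmetic and GEP):
//
//   address = alignedPtr + (offset + sum_k indices[k] * strides[k])
//             (pointer arithmetic in units of the element type)
static int64_t linearizeIndicesExample(int64_t offset,
                                       ArrayRef<int64_t> indices,
                                       ArrayRef<int64_t> strides) {
  assert(indices.size() == strides.size() && "rank mismatch");
  int64_t linear = offset;
  for (unsigned k = 0, e = indices.size(); k != e; ++k)
    linear += indices[k] * strides[k];
  return linear;
}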
template struct LoadStoreOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Base = LoadStoreOpLowering; LogicalResult match(Operation *op) const override { MemRefType type = cast(op).getMemRefType(); return isSupportedMemRefType(type) ? success() : failure(); } }; // Load operation is lowered to obtaining a pointer to the indexed element // and loading it. struct LoadOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loadOp = cast(op); OperandAdaptor transformed(operands); auto type = loadOp.getMemRefType(); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), transformed.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp(op, dataPtr); return success(); } }; // Store operation is lowered to obtaining a pointer to the indexed element, // and storing the given value to it. struct StoreOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto type = cast(op).getMemRefType(); OperandAdaptor transformed(operands); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), transformed.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr); return success(); } }; // The prefetch operation is lowered in a way similar to the load operation // except that the llvm.prefetch operation is used for replacement. struct PrefetchOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto prefetchOp = cast(op); OperandAdaptor transformed(operands); auto type = prefetchOp.getMemRefType(); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), transformed.indices(), rewriter, getModule()); // Replace with llvm.prefetch. auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); auto isWrite = rewriter.create( op->getLoc(), llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite())); auto localityHint = rewriter.create( op->getLoc(), llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.localityHint().getZExtValue())); auto isData = rewriter.create( op->getLoc(), llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache())); rewriter.replaceOpWithNewOp(op, dataPtr, isWrite, localityHint, isData); return success(); } }; // The lowering of index_cast becomes an integer conversion since index becomes // an integer. If the bit width of the source and target integer types is the // same, just erase the cast. If the target type is wider, sign-extend the // value, otherwise truncate it. 
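// Concretely (comment only), with the default 64-bit lowering of `index`:
//   index_cast : i32 -> index   becomes llvm.sext   (widening, sign-extend)
//   index_cast : index -> i32   becomes llvm.trunc  (narrowing)
//   index_cast : i64 -> index   is erased           (same bit width)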
struct IndexCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { IndexCastOpOperandAdaptor transformed(operands); auto indexCastOp = cast(op); auto targetType = this->typeConverter.convertType(indexCastOp.getResult().getType()) .cast(); auto sourceType = transformed.in().getType().cast(); unsigned targetBits = targetType.getUnderlyingType()->getIntegerBitWidth(); unsigned sourceBits = sourceType.getUnderlyingType()->getIntegerBitWidth(); if (targetBits == sourceBits) rewriter.replaceOp(op, transformed.in()); else if (targetBits < sourceBits) rewriter.replaceOpWithNewOp(op, targetType, transformed.in()); else rewriter.replaceOpWithNewOp(op, targetType, transformed.in()); return success(); } }; // Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two // enums share the numerical values so just cast. template static LLVMPredType convertCmpPredicate(StdPredType pred) { return static_cast(pred); } struct CmpIOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpiOp = cast(op); CmpIOpOperandAdaptor transformed(operands); rewriter.replaceOpWithNewOp( op, typeConverter.convertType(cmpiOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpiOp.getPredicate()))), transformed.lhs(), transformed.rhs()); return success(); } }; struct CmpFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpfOp = cast(op); CmpFOpOperandAdaptor transformed(operands); rewriter.replaceOpWithNewOp( op, typeConverter.convertType(cmpfOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpfOp.getPredicate()))), transformed.lhs(), transformed.rhs()); return success(); } }; struct SIToFPLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct FPExtLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct FPTruncLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct SignExtendIOpLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct TruncateIOpLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct ZeroExtendIOpLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; // Base class for LLVM IR lowering terminator operations with successors. template struct OneToOneLLVMTerminatorLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = OneToOneLLVMTerminatorLowering; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands, op->getSuccessors(), op->getAttrs()); return success(); } }; // Special lowering pattern for `ReturnOps`. Unlike all other operations, // `ReturnOp` interacts with the function signature and must have as many // operands as the function has return values. Because in LLVM IR, functions // can only return 0 or 1 value, we pack multiple values into a structure type. 
// Emit `UndefOp` followed by `InsertValueOp`s to create such structure if // necessary before returning it struct ReturnOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { unsigned numArguments = op->getNumOperands(); // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { rewriter.replaceOpWithNewOp( op, ArrayRef(), ArrayRef(), op->getAttrs()); return success(); } if (numArguments == 1) { rewriter.replaceOpWithNewOp( op, ArrayRef(), operands.front(), op->getAttrs()); return success(); } // Otherwise, we need to pack the arguments into an LLVM struct type before // returning. auto packedType = typeConverter.packFunctionResults( llvm::to_vector<4>(op->getOperandTypes())); Value packed = rewriter.create(op->getLoc(), packedType); for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create( op->getLoc(), packedType, packed, operands[i], rewriter.getI64ArrayAttr(i)); } rewriter.replaceOpWithNewOp(op, ArrayRef(), packed, op->getAttrs()); return success(); } }; // FIXME: this should be tablegen'ed as well. struct BranchOpLowering : public OneToOneLLVMTerminatorLowering { using Super::Super; }; struct CondBranchOpLowering : public OneToOneLLVMTerminatorLowering { using Super::Super; }; // The Splat operation is lowered to an insertelement + a shufflevector // operation. Splat to only 1-d vector result types are lowered. struct SplatOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast(op); VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() != 1) return failure(); // First insert it into an undef vector so we can shuffle it. auto vectorType = typeConverter.convertType(splatOp.getType()); Value undef = rewriter.create(op->getLoc(), vectorType); auto zero = rewriter.create( op->getLoc(), typeConverter.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); auto v = rewriter.create( op->getLoc(), vectorType, undef, splatOp.getOperand(), zero); int64_t width = splatOp.getType().cast().getDimSize(0); SmallVector zeroValues(width, 0); // Shuffle the value across the desired number of elements. ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); rewriter.replaceOpWithNewOp(op, v, undef, zeroAttrs); return success(); } }; // The Splat operation is lowered to an insertelement + a shufflevector // operation. Splat to only 2+-d vector result types are lowered by the // SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. struct SplatNdOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast(op); OperandAdaptor adaptor(operands); VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() == 1) return failure(); // First insert it into an undef vector so we can shuffle it. 
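    // Splat idiom used here and in SplatOpLowering above (illustrative):
    // insert the scalar into lane 0 of an undef vector with
    // llvm.insertelement, then broadcast it with an llvm.shufflevector whose
    // mask is all zeros (one entry per lane of the target width).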
auto loc = op->getLoc(); auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter); auto llvmArrayTy = vectorTypeInfo.llvmArrayTy; auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; if (!llvmArrayTy || !llvmVectorTy) return failure(); // Construct returned value. Value desc = rewriter.create(loc, llvmArrayTy); // Construct a 1-D vector with the splatted value that we insert in all the // places within the returned descriptor. Value vdesc = rewriter.create(loc, llvmVectorTy); auto zero = rewriter.create( loc, typeConverter.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); Value v = rewriter.create(loc, llvmVectorTy, vdesc, adaptor.input(), zero); // Shuffle the value across the desired number of elements. int64_t width = resultType.getDimSize(resultType.getRank() - 1); SmallVector zeroValues(width, 0); ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); v = rewriter.create(loc, v, v, zeroAttrs); // Iterate of linear index, convert to coords space and insert splatted 1-D // vector in each position. nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { desc = rewriter.create(loc, llvmArrayTy, desc, v, position); }); rewriter.replaceOp(op, desc); return success(); } }; /// Conversion pattern that transforms a subview op into: /// 1. An `llvm.mlir.undef` operation to create a memref descriptor /// 2. Updates to the descriptor to introduce the data ptr, offset, size /// and stride. /// The subview op is replaced by the descriptor. struct SubViewOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast(op); // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support // having multiple variadic operands where each operand can have different // number of entries, clean all of this up. SmallVector dynamicOffsets( std::next(operands.begin()), std::next(operands.begin(), 1 + viewOp.getNumOffsets())); SmallVector dynamicSizes( std::next(operands.begin(), 1 + viewOp.getNumOffsets()), std::next(operands.begin(), 1 + viewOp.getNumOffsets() + viewOp.getNumSizes())); SmallVector dynamicStrides( std::next(operands.begin(), 1 + viewOp.getNumOffsets() + viewOp.getNumSizes()), operands.end()); auto sourceMemRefType = viewOp.source().getType().cast(); auto sourceElementTy = typeConverter.convertType(sourceMemRefType.getElementType()) .dyn_cast_or_null(); auto viewMemRefType = viewOp.getType(); auto targetElementTy = typeConverter.convertType(viewMemRefType.getElementType()) .dyn_cast(); auto targetDescTy = typeConverter.convertType(viewMemRefType) .dyn_cast_or_null(); if (!sourceElementTy || !targetDescTy) return failure(); // Currently, only rank > 0 and full or no operands are supported. Fail to // convert otherwise. unsigned rank = sourceMemRefType.getRank(); if (viewMemRefType.getRank() == 0 || (!dynamicOffsets.empty() && rank != dynamicOffsets.size()) || (!dynamicSizes.empty() && rank != dynamicSizes.size()) || (!dynamicStrides.empty() && rank != dynamicStrides.size())) return failure(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); if (failed(successStrides)) return failure(); // Fail to convert if neither a dynamic nor static offset is available. 
if (dynamicOffsets.empty() && offset == MemRefType::getDynamicStrideOrOffset()) return failure(); // Create the descriptor. if (!operands.front().getType().isa()) return failure(); MemRefDescriptor sourceMemRef(operands.front()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); Value bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); extracted = sourceMemRef.alignedPtr(rewriter, loc); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); // Extract strides needed to compute offset. SmallVector strideValues; strideValues.reserve(viewMemRefType.getRank()); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); // Fill in missing dynamic sizes. auto llvmIndexType = typeConverter.convertType(rewriter.getIndexType()); if (dynamicSizes.empty()) { dynamicSizes.reserve(viewMemRefType.getRank()); auto shape = viewMemRefType.getShape(); for (auto extent : shape) { dynamicSizes.push_back(rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(extent))); } } // Offset. if (dynamicOffsets.empty()) { targetMemRef.setConstantOffset(rewriter, loc, offset); } else { Value baseOffset = sourceMemRef.offset(rewriter, loc); for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { Value min = dynamicOffsets[i]; baseOffset = rewriter.create( loc, baseOffset, rewriter.create(loc, min, strideValues[i])); } targetMemRef.setOffset(rewriter, loc, baseOffset); } // Update sizes and strides. for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]); Value newStride; if (dynamicStrides.empty()) newStride = rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); else newStride = rewriter.create(loc, dynamicStrides[i], strideValues[i]); targetMemRef.setStride(rewriter, loc, i, newStride); } rewriter.replaceOp(op, {targetMemRef}); return success(); } }; /// Conversion pattern that transforms an op into: /// 1. An `llvm.mlir.undef` operation to create a memref descriptor /// 2. Updates to the descriptor to introduce the data ptr, offset, size /// and stride. /// The view op is replaced by the descriptor. struct ViewOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; // Build and return the value for the idx^th shape dimension, either by // returning the constant shape dimension or counting the proper dynamic size. Value getSize(ConversionPatternRewriter &rewriter, Location loc, ArrayRef shape, ArrayRef dynamicSizes, unsigned idx) const { assert(idx < shape.size()); if (!ShapedType::isDynamic(shape[idx])) return createIndexConstant(rewriter, loc, shape[idx]); // Count the number of dynamic dims in range [0, idx] unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) { return ShapedType::isDynamic(v); }); return dynamicSizes[nDynamic]; } // Build and return the idx^th stride, either by returning the constant stride // or by computing the dynamic stride from the current `runningStride` and // `nextSize`. The caller should keep a running stride and update it with the // result returned by this function. 
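  // Worked example (comment only): for a contiguous memref<3x4x5xf32> the
  // strides are built innermost-first as a running product of the sizes:
  //   stride[2] = 1, stride[1] = 1 * 5 = 5, stride[0] = 5 * 4 = 20,
  // so element (i, j, k) sits at linear offset i*20 + j*5 + k.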
Value getStride(ConversionPatternRewriter &rewriter, Location loc, ArrayRef strides, Value nextSize, Value runningStride, unsigned idx) const { assert(idx < strides.size()); if (strides[idx] != MemRefType::getDynamicStrideOrOffset()) return createIndexConstant(rewriter, loc, strides[idx]); if (nextSize) return runningStride ? rewriter.create(loc, runningStride, nextSize) : nextSize; assert(!runningStride); return createIndexConstant(rewriter, loc, 1); } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast(op); ViewOpOperandAdaptor adaptor(operands); auto viewMemRefType = viewOp.getType(); auto targetElementTy = typeConverter.convertType(viewMemRefType.getElementType()) .dyn_cast(); auto targetDescTy = typeConverter.convertType(viewMemRefType).dyn_cast(); if (!targetDescTy) return op->emitWarning("Target descriptor type not converted to LLVM"), failure(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); if (failed(successStrides)) return op->emitWarning("cannot cast to non-strided shape"), failure(); // Create the descriptor. MemRefDescriptor sourceMemRef(adaptor.source()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Field 1: Copy the allocated pointer, used for malloc/free. Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); Value bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); // Field 2: Copy the actual aligned pointer to payload. extracted = sourceMemRef.alignedPtr(rewriter, loc); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(), extracted); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); // Field 3: Copy the offset in aligned pointer. unsigned numDynamicSizes = llvm::size(viewOp.getDynamicSizes()); (void)numDynamicSizes; bool hasDynamicOffset = offset == MemRefType::getDynamicStrideOrOffset(); auto sizeAndOffsetOperands = adaptor.operands(); assert(llvm::size(sizeAndOffsetOperands) == numDynamicSizes + (hasDynamicOffset ? 1 : 0)); Value baseOffset = !hasDynamicOffset ? createIndexConstant(rewriter, loc, offset) // TODO(ntv): better adaptor. : sizeAndOffsetOperands.front(); targetMemRef.setOffset(rewriter, loc, baseOffset); // Early exit for 0-D corner case. if (viewMemRefType.getRank() == 0) return rewriter.replaceOp(op, {targetMemRef}), success(); // Fields 4 and 5: Update sizes and strides. if (strides.back() != 1) return op->emitWarning("cannot cast to non-contiguous shape"), failure(); Value stride = nullptr, nextSize = nullptr; // Drop the dynamic stride from the operand list, if present. ArrayRef sizeOperands(sizeAndOffsetOperands); if (hasDynamicOffset) sizeOperands = sizeOperands.drop_front(); for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. Value size = getSize(rewriter, loc, viewMemRefType.getShape(), sizeOperands, i); targetMemRef.setSize(rewriter, loc, i, size); // Update stride. 
stride = getStride(rewriter, loc, strides, nextSize, stride, i); targetMemRef.setStride(rewriter, loc, i, stride); nextSize = size; } rewriter.replaceOp(op, {targetMemRef}); return success(); } }; struct AssumeAlignmentOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { OperandAdaptor transformed(operands); Value memref = transformed.memref(); unsigned alignment = cast(op).alignment().getZExtValue(); MemRefDescriptor memRefDescriptor(memref); Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc()); // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that // the asserted memref.alignedPtr isn't used anywhere else, as the real // users like load/store/views always re-extract memref.alignedPtr as they // get lowered. // // This relies on LLVM's CSE optimization (potentially after SROA), since // after CSE all memref.alignedPtr instances get de-duplicated into the same // pointer SSA value. Value zero = createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), 0); Value mask = createIndexAttrConstant(rewriter, op->getLoc(), getIndexType(), alignment - 1); Value ptrValue = rewriter.create(op->getLoc(), getIndexType(), ptr); rewriter.create( op->getLoc(), rewriter.create( op->getLoc(), LLVM::ICmpPredicate::eq, rewriter.create(op->getLoc(), ptrValue, mask), zero)); rewriter.eraseOp(op); return success(); } }; } // namespace /// Try to match the kind of a std.atomic_rmw to determine whether to use a /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg. static Optional matchSimpleAtomicOp(AtomicRMWOp atomicOp) { switch (atomicOp.kind()) { case AtomicRMWKind::addf: return LLVM::AtomicBinOp::fadd; case AtomicRMWKind::addi: return LLVM::AtomicBinOp::add; case AtomicRMWKind::assign: return LLVM::AtomicBinOp::xchg; case AtomicRMWKind::maxs: return LLVM::AtomicBinOp::max; case AtomicRMWKind::maxu: return LLVM::AtomicBinOp::umax; case AtomicRMWKind::mins: return LLVM::AtomicBinOp::min; case AtomicRMWKind::minu: return LLVM::AtomicBinOp::umin; default: return llvm::None; } llvm_unreachable("Invalid AtomicRMWKind"); } namespace { struct AtomicRMWOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto atomicOp = cast(op); auto maybeKind = matchSimpleAtomicOp(atomicOp); if (!maybeKind) return failure(); OperandAdaptor adaptor(operands); auto resultType = adaptor.value().getType(); auto memRefType = atomicOp.getMemRefType(); auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(), adaptor.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp( op, resultType, *maybeKind, dataPtr, adaptor.value(), LLVM::AtomicOrdering::acq_rel); return success(); } }; /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be /// retried until it succeeds in atomically storing a new value into memory. 
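/// In C-like pseudocode (illustrative only; this path is used for the float
/// min/max kinds that have no direct llvm.atomicrmw form):
///
///   loaded = *ptr;
///   do {
///     desired      = select(cmp(loaded, value), loaded, value);
///     (loaded, ok) = cmpxchg(ptr, /*expected=*/loaded, desired);
///   } while (!ok);
///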
/// /// +---------------------------------+ /// | | /// | | /// | br loop(%loaded) | /// +---------------------------------+ /// | /// -------| | /// | v v /// | +--------------------------------+ /// | | loop(%loaded): | /// | | | /// | | %pair = cmpxchg | /// | | %ok = %pair[0] | /// | | %new = %pair[1] | /// | | cond_br %ok, end, loop(%new) | /// | +--------------------------------+ /// | | | /// |----------- | /// v /// +--------------------------------+ /// | end: | /// | | /// +--------------------------------+ /// struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto atomicOp = cast(op); auto maybeKind = matchSimpleAtomicOp(atomicOp); if (maybeKind) return failure(); LLVM::FCmpPredicate predicate; switch (atomicOp.kind()) { case AtomicRMWKind::maxf: predicate = LLVM::FCmpPredicate::ogt; break; case AtomicRMWKind::minf: predicate = LLVM::FCmpPredicate::olt; break; default: return failure(); } OperandAdaptor adaptor(operands); auto loc = op->getLoc(); auto valueType = adaptor.value().getType().cast(); // Split the block into initial, loop, and ending parts. auto *initBlock = rewriter.getInsertionBlock(); auto initPosition = rewriter.getInsertionPoint(); auto *loopBlock = rewriter.splitBlock(initBlock, initPosition); auto loopArgument = loopBlock->addArgument(valueType); auto loopPosition = rewriter.getInsertionPoint(); auto *endBlock = rewriter.splitBlock(loopBlock, loopPosition); // Compute the loaded value and branch to the loop block. rewriter.setInsertionPointToEnd(initBlock); auto memRefType = atomicOp.getMemRefType(); auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter, getModule()); Value init = rewriter.create(loc, dataPtr); rewriter.create(loc, init, loopBlock); // Prepare the body of the loop block. rewriter.setInsertionPointToStart(loopBlock); auto predicateI64 = rewriter.getI64IntegerAttr(static_cast(predicate)); auto boolType = LLVM::LLVMType::getInt1Ty(&getDialect()); auto lhs = loopArgument; auto rhs = adaptor.value(); auto cmp = rewriter.create(loc, boolType, predicateI64, lhs, rhs); auto select = rewriter.create(loc, cmp, lhs, rhs); // Prepare the epilog of the loop block. rewriter.setInsertionPointToEnd(loopBlock); // Append the cmpxchg op to the end of the loop block. auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType); auto cmpxchg = rewriter.create( loc, pairType, dataPtr, loopArgument, select, successOrdering, failureOrdering); // Extract the %new_loaded and %ok values from the pair. Value newLoaded = rewriter.create( loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0})); Value ok = rewriter.create( loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1})); // Conditionally branch to the end or back to the loop depending on %ok. rewriter.create(loc, ok, endBlock, ArrayRef(), loopBlock, newLoaded); // The 'result' of the atomic_rmw op is the newly loaded value. rewriter.replaceOp(op, {newLoaded}); return success(); } }; } // namespace /// Collect a set of patterns to convert from the Standard dialect to LLVM. 
void mlir::populateStdToLLVMNonMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed // clang-format off patterns.insert< AbsFOpLowering, AddFOpLowering, AddIOpLowering, AllocaOpLowering, AndOpLowering, AtomicCmpXchgOpLowering, AtomicRMWOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CeilFOpLowering, CmpFOpLowering, CmpIOpLowering, CondBranchOpLowering, CopySignOpLowering, CosOpLowering, ConstLLVMOpLowering, DialectCastOpLowering, DivFOpLowering, ExpOpLowering, Exp2OpLowering, LogOpLowering, Log10OpLowering, Log2OpLowering, FPExtLowering, FPTruncLowering, IndexCastOpLowering, MulFOpLowering, MulIOpLowering, NegFOpLowering, OrOpLowering, PrefetchOpLowering, RemFOpLowering, ReturnOpLowering, RsqrtOpLowering, SIToFPLowering, SelectOpLowering, ShiftLeftOpLowering, SignExtendIOpLowering, SignedDivIOpLowering, SignedRemIOpLowering, SignedShiftRightOpLowering, SplatOpLowering, SplatNdOpLowering, SqrtOpLowering, SubFOpLowering, SubIOpLowering, TruncateIOpLowering, UnsignedDivIOpLowering, UnsignedRemIOpLowering, UnsignedShiftRightOpLowering, XOrOpLowering, ZeroExtendIOpLowering>(converter); // clang-format on } void mlir::populateStdToLLVMMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, bool useAlignedAlloc) { // clang-format off patterns.insert< AssumeAlignmentOpLowering, DeallocOpLowering, DimOpLowering, LoadOpLowering, MemRefCastOpLowering, StoreOpLowering, SubViewOpLowering, ViewOpLowering>(converter); patterns.insert< AllocOpLowering >(converter, useAlignedAlloc); // clang-format on } void mlir::populateStdToLLVMDefaultFuncOpConversionPattern( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, bool emitCWrappers) { patterns.insert(converter, emitCWrappers); } void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, bool emitCWrappers, bool useAlignedAlloc) { populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns, emitCWrappers); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); populateStdToLLVMMemoryConversionPatterns(converter, patterns, useAlignedAlloc); } static void populateStdToLLVMBarePtrFuncOpConversionPattern( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { patterns.insert(converter); } void mlir::populateStdToLLVMBarePtrConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, bool useAlignedAlloc) { populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); populateStdToLLVMMemoryConversionPatterns(converter, patterns, useAlignedAlloc); } // Create an LLVM IR structure type if there is more than one result. Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { assert(!types.empty() && "expected non-empty list of type"); if (types.size() == 1) return convertType(types.front()); SmallVector resultTypes; resultTypes.reserve(types.size()); for (auto t : types) { auto converted = convertType(t).dyn_cast(); if (!converted) return {}; resultTypes.push_back(converted); } return LLVM::LLVMType::getStructTy(llvmDialect, resultTypes); } Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) { auto *context = builder.getContext(); auto int64Ty = LLVM::LLVMType::getInt64Ty(getDialect()); auto indexType = IndexType::get(context); // Alloca with proper alignment. 
We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. auto ptrType = operand.getType().cast().getPointerTo(); Value one = builder.create(loc, int64Ty, IntegerAttr::get(indexType, 1)); Value allocated = builder.create(loc, ptrType, one, /*alignment=*/0); // Store into the alloca'ed descriptor. builder.create(loc, operand, allocated); return allocated; } SmallVector LLVMTypeConverter::promoteMemRefDescriptors(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder) { SmallVector promotedOperands; promotedOperands.reserve(operands.size()); for (auto it : llvm::zip(opOperands, operands)) { auto operand = std::get<0>(it); auto llvmOperand = std::get<1>(it); if (operand.getType().isa()) { UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, promotedOperands); continue; } if (auto memrefType = operand.getType().dyn_cast()) { MemRefDescriptor::unpack(builder, loc, llvmOperand, operand.getType().cast(), promotedOperands); continue; } promotedOperands.push_back(operand); } return promotedOperands; } namespace { /// A pass converting MLIR operations into the LLVM IR dialect. struct LLVMLoweringPass : public ConvertStandardToLLVMBase { LLVMLoweringPass() = default; LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers, unsigned indexBitwidth, bool useAlignedAlloc) { this->useBarePtrCallConv = useBarePtrCallConv; this->emitCWrappers = emitCWrappers; this->indexBitwidth = indexBitwidth; this->useAlignedAlloc = useAlignedAlloc; } /// Run the dialect converter on the module. void runOnOperation() override { if (useBarePtrCallConv && emitCWrappers) { getOperation().emitError() << "incompatible conversion options: bare-pointer calling convention " "and C wrapper emission"; signalPassFailure(); return; } ModuleOp m = getOperation(); LLVMTypeConverterCustomization customs; customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter : structFuncArgTypeConverter; customs.indexBitwidth = indexBitwidth; LLVMTypeConverter typeConverter(&getContext(), customs); OwningRewritePatternList patterns; if (useBarePtrCallConv) populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns, useAlignedAlloc); else populateStdToLLVMConversionPatterns(typeConverter, patterns, emitCWrappers, useAlignedAlloc); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, patterns, &typeConverter))) signalPassFailure(); } }; } // end namespace mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { this->addLegalDialect(); this->addIllegalOp(); this->addIllegalOp(); } std::unique_ptr> mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) { return std::make_unique( options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth, options.useAlignedAlloc); }
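// Minimal usage sketch (comment only, not part of this file): running the
// lowering created above through a pass manager. The option fields mirror
// those read by createLowerToLLVMPass; the surrounding setup is illustrative
// and assumes a module obtained elsewhere.
//
//   mlir::MLIRContext context;
//   mlir::OwningModuleRef module = ...;        // parsed or built beforehand
//   mlir::PassManager pm(&context);
//   LowerToLLVMOptions options;
//   options.emitCWrappers = true;              // also emit C interface wrappers
//   pm.addPass(mlir::createLowerToLLVMPass(options));
//   if (failed(pm.run(*module)))
//     llvm::report_fatal_error("std-to-llvm lowering failed");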