diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -46,6 +46,7 @@
 class LoopAccessInfo;
 class Loop;
 class LoopInfo;
+class PredicatedInstruction;
 class ProfileSummaryInfo;
 class SCEV;
 class ScalarEvolution;
@@ -1195,6 +1196,40 @@
   /// Intrinsics") Use of %evl is discouraged when that is not the case.
   bool hasActiveVectorLength() const;
 
+  /// Describes how a vector-predicated operation has to be legalized for the
+  /// target: one strategy for the %evl parameter and one for the operator.
+  struct VPLegalization {
+    enum VPTransform {
+      // keep the predicating parameter
+      Legal = 0,
+      // where legal, discard the predicate parameter
+      Discard = 1,
+      // transform into something else that is also predicating
+      Convert = 2
+    };
+
+    // How to transform the EVL parameter.
+    // Legal:   keep the EVL parameter as it is.
+    // Discard: Ignore the EVL parameter where it is safe to do so.
+    // Convert: Fold the EVL into the mask parameter.
+    VPTransform EVLParamStrategy;
+
+    // How to transform the operator.
+    // Legal:   The target supports this operator.
+    // Convert: Convert this to a non-VP operation.
+    // The 'Discard' strategy is invalid.
+    VPTransform OpStrategy;
+
+    /// \returns true if neither the %evl parameter nor the operator needs to
+    /// be transformed.
+    bool doNothing() const {
+      return (EVLParamStrategy == Legal) && (OpStrategy == Legal);
+    }
+    VPLegalization(VPTransform EVLParamStrategy, VPTransform OpStrategy)
+        : EVLParamStrategy(EVLParamStrategy), OpStrategy(OpStrategy) {}
+  };
+
+  /// This will only be called for vector-predicated instructions.
+  /// \returns How the target needs this vector-predicated operation to be
+  /// transformed.
+  VPLegalization
+  getVPLegalizationStrategy(const PredicatedInstruction &PI) const;
   /// @}
 
   /// @}
@@ -1462,6 +1497,8 @@
   virtual bool shouldExpandReduction(const IntrinsicInst *II) const = 0;
   virtual unsigned getGISelRematGlobalCost() const = 0;
   virtual bool hasActiveVectorLength() const = 0;
+  virtual VPLegalization
+  getVPLegalizationStrategy(const PredicatedInstruction &PI) const = 0;
   virtual int getInstructionLatency(const Instruction *I) = 0;
 };
 
@@ -1946,6 +1983,11 @@
     return Impl.hasActiveVectorLength();
   }
 
+  VPLegalization
+  getVPLegalizationStrategy(const PredicatedInstruction &PI) const override {
+    return Impl.getVPLegalizationStrategy(PI);
+  }
+
   int getInstructionLatency(const Instruction *I) override {
     return Impl.getInstructionLatency(I);
   }
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -623,6 +623,13 @@
 
   bool hasActiveVectorLength() const { return false; }
 
+  // Default for targets without native VP support: drop %evl where that is
+  // safe and lower the operation to unpredicated IR.
+  TargetTransformInfo::VPLegalization
+  getVPLegalizationStrategy(const PredicatedInstruction &PI) const {
+    return TargetTransformInfo::VPLegalization(
+        /* EVLParamStrategy */ TargetTransformInfo::VPLegalization::Discard,
+        /* OpStrategy */ TargetTransformInfo::VPLegalization::Convert);
+  }
+
 protected:
   // Obtain the minimum required size to hold the value (without the sign)
   // In case of a vector it returns the min required size for one element.
diff --git a/llvm/include/llvm/CodeGen/ExpandVectorPredication.h b/llvm/include/llvm/CodeGen/ExpandVectorPredication.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/ExpandVectorPredication.h
@@ -0,0 +1,23 @@
+//===-- ExpandVectorPredication.h - Expand vector predication ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
+#define LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+/// New pass manager entry point: expands vector predication intrinsics in a
+/// function according to the target's TargetTransformInfo.
+class ExpandVectorPredicationPass
+    : public PassInfoMixin {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+} // end namespace llvm
+
+#endif // LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -452,6 +452,11 @@
   /// shuffles.
   FunctionPass *createExpandReductionsPass();
 
+  /// This pass expands vector predication intrinsics into unpredicated
+  /// instructions (blending in safe operands with selects where needed), or
+  /// folds just the explicit vector length into the predicate mask.
+  FunctionPass *createExpandVectorPredicationPass();
+
   // This pass expands memcmp() to load/stores.
   FunctionPass *createExpandMemCmpPass();
 
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -225,9 +225,11 @@
 
   /// \return the mask parameter or nullptr.
   Value *getMaskParam() const;
+  /// Replace the mask parameter; this intrinsic must have one.
+  void setMaskParam(Value *);
 
   /// \return the vector length parameter or nullptr.
   Value *getVectorLengthParam() const;
+  /// Replace the vector length parameter; this intrinsic must have one.
+  void setVectorLengthParam(Value *);
 
   /// \return whether the vector length param can be ignored.
   bool canIgnoreVectorLengthParam() const;
diff --git a/llvm/include/llvm/IR/PredicatedInst.h b/llvm/include/llvm/IR/PredicatedInst.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/IR/PredicatedInst.h
@@ -0,0 +1,102 @@
+//===-- PredicatedInst.h - Utility for predicated instructions --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines various classes for working with vector-predicated
+// instructions. Predicated instructions are either regular instructions or
+// calls to Vector Predication (VP) intrinsics that have a mask and an explicit
+// vector length argument.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_PREDICATEDINST_H
+#define LLVM_IR_PREDICATEDINST_H
+
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Casting.h"
+
+#include
+
+namespace llvm {
+
+class BasicBlock;
+
+/// The PredicatedInstruction class makes VPIntrinsics appear as regular
+/// instructions with an additional mask and evl parameter. Regular instructions
+/// appear as if they had their predication parameters hard-wired to "all lanes
+/// enabled". `PredicatedInstruction` is a conceptual superclass of
+/// `Instruction`.
+///
+/// For example, for the intrinsic call:
+///
+///   %x = call @llvm.vp.add.v256i32(%y,%z,%mask,%evl)
+///
+/// `PredicatedInstruction` reports the Opcode as Instruction::Add, just as
+/// for:
+///
+///   %x = add <256 x i32> %y, %z
+///
+class PredicatedInstruction : public User {
+public:
+  // The PredicatedInstruction class is intended to be used as a utility, and is
+  // never itself instantiated.
+  PredicatedInstruction() = delete;
+  ~PredicatedInstruction() = delete;
+
+  void *operator new(size_t s) = delete;
+
+  /// Forwards to copyIRFlags of the underlying instruction.
+  void copyIRFlags(const Value *V, bool IncludeWrapFlags) {
+    cast(this)->copyIRFlags(V, IncludeWrapFlags);
+  }
+
+  /// \returns the basic block of the underlying instruction.
+  BasicBlock *getParent() { return cast(this)->getParent(); }
+  const BasicBlock *getParent() const {
+    return cast(this)->getParent();
+  }
+
+  /// \returns the Mask parameter of this instruction or nullptr if there is
+  /// none.
+  Value *getMaskParam() const;
+
+  /// \returns the Explicit Vector Length parameter of this instruction or
+  /// nullptr if there is none.
+  Value *getVectorLengthParam() const;
+
+  /// \returns True if the passed vector length value has no predicating effect
+  /// on the operation.
+  bool canIgnoreVectorLengthParam() const;
+
+  /// \returns True if the static operator of this instruction has a mask or
+  /// vector length parameter.
+  bool isVectorPredicatedOp() const;
+
+  /// \returns Whether this is functionally a binary operator.
+  bool isBinaryOp() const;
+
+  /// \returns the effective Opcode of this operation (ignoring the mask and
+  /// vector length param).
+  unsigned getOpcode() const;
+
+  static bool classof(const Instruction *I) { return isa(I); }
+  static bool classof(const ConstantExpr *CE) { return false; }
+  static bool classof(const Value *V) { return isa(V); }
+
+  /// Convenience function for getting all the fast-math flags, which must be an
+  /// operator which supports these flags. See LangRef.html for the meaning of
+  /// these flags.
+  FastMathFlags getFastMathFlags() const;
+};
+
+} // namespace llvm
+
+#endif // LLVM_IR_PREDICATEDINST_H
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -149,6 +149,7 @@
 void initializeExpandMemCmpPassPass(PassRegistry&);
 void initializeExpandPostRAPass(PassRegistry&);
 void initializeExpandReductionsPass(PassRegistry&);
+void initializeExpandVectorPredicationPass(PassRegistry &);
 void initializeMakeGuardsExplicitLegacyPassPass(PassRegistry&);
 void initializeExternalAAWrapperPassPass(PassRegistry&);
 void initializeFEntryInserterPass(PassRegistry&);
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -937,6 +937,12 @@
   return TTIImpl->useReductionIntrinsic(Opcode, Ty, Flags);
 }
 
+TargetTransformInfo::VPLegalization
+TargetTransformInfo::getVPLegalizationStrategy(
+    const PredicatedInstruction &PI) const {
+  return TTIImpl->getVPLegalizationStrategy(PI);
+}
+
 bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const {
   return TTIImpl->shouldExpandReduction(II);
 }
diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -27,6 +27,7 @@
   ExpandMemCmp.cpp
   ExpandPostRAPseudos.cpp
   ExpandReductions.cpp
+  ExpandVectorPredication.cpp
   FaultMaps.cpp
   FEntryInserter.cpp
   FinalizeISel.cpp
diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -0,0 +1,459 @@
+//===--- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements IR expansion for vector predication intrinsics, allowing
+// targets to enable vector predication until just before codegen.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ExpandVectorPredication.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PredicatedInst.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include <algorithm>
+#include <limits>
+#include <map>
+
+using namespace llvm;
+
+using VPLegalization = TargetTransformInfo::VPLegalization;
+
+#define DEBUG_TYPE "expand-vec-pred"
+
+STATISTIC(NumFoldedVL, "Number of folded vector length params");
+STATISTIC(NumLoweredVPOps, "Number of folded vector predication operations");
+
+///// Helpers {
+
+/// \returns Whether the vector mask \p MaskVal has all lane bits set.
+static bool isAllTrueMask(Value *MaskVal) {
+  // Only constant masks can be proven all-true; any other mask is treated as
+  // potentially partial.
+  auto *ConstMask = dyn_cast<Constant>(MaskVal);
+  return ConstMask && ConstMask->isAllOnesValue();
+}
+
+/// Computes the smallest integer bit width that holds every index of the step
+/// vector <0, .., NumVectorElements - 1> as well as \p NumVectorElements
+/// itself, which is the %evl value that enables all lanes and must survive
+/// the zext/trunc into the lane type.
+static unsigned getLeastLaneBitsForStepVector(unsigned NumVectorElements) {
+  // countLeadingZeros counts within the width of the argument type (unsigned),
+  // so the bit width is that width minus the count. Subtracting from 64 would
+  // overestimate the required width by 32 bits.
+  unsigned BitWidth = std::numeric_limits<unsigned>::digits -
+                      llvm::countLeadingZeros(NumVectorElements);
+  return std::max<unsigned>(IntegerType::MIN_INT_BITS, BitWidth);
+}
+
+/// \returns A non-excepting divisor constant for this type.
+static Constant *getSafeDivisor(Type *DivTy) {
+  assert(DivTy->isIntOrIntVectorTy());
+  return ConstantInt::get(DivTy, 1u, false);
+}
+
+/// Transfer operation properties (fast-math flags) from \p OldPI to \p NewVal.
+static void transferDecorations(Value &NewVal, PredicatedInstruction &OldPI) {
+  auto *NewInst = dyn_cast<Instruction>(&NewVal);
+  if (!NewInst || !isa<FPMathOperator>(NewVal))
+    return;
+
+  auto *OldFMOp = dyn_cast<FPMathOperator>(&OldPI);
+  if (!OldFMOp)
+    return;
+
+  NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
+}
+
+/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
+/// \p OldOp gets erased.
+static void replaceOperation(Value &NewOp, PredicatedInstruction &OldOp) {
+  transferDecorations(NewOp, OldOp);
+  OldOp.replaceAllUsesWith(&NewOp);
+  cast<Instruction>(OldOp).eraseFromParent();
+}
+
+//// } Helpers
+
+namespace {
+
+// Expansion pass state at function scope.
+struct CachingVPExpander {
+  Function &F;
+  const TargetTransformInfo &TTI;
+
+  /// \returns A (fixed length) vector with ascending integer indices
+  /// (<0, 1, ..., NumElems-1>).
+  Value *createStepVector(IRBuilder<> &Builder, int32_t ElemBits,
+                          int32_t NumElems);
+
+  /// \returns A bitmask that is true where the lane position is less-than \p
+  /// EVLParam
+  ///
+  /// \p Builder
+  ///    Used for instruction creation.
+  /// \p EVLParam
+  ///    The explicit vector length parameter to test against the lane
+  ///    positions.
+  /// \p ElemCount
+  ///    Static (potentially scalable) number of vector elements
+  Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
+                          ElementCount ElemCount);
+
+  /// Fold the %evl parameter of \p OldPI into its mask parameter, then make
+  /// %evl ineffective. \returns \p OldPI (modified in place).
+  Value *foldEVLIntoMask(PredicatedInstruction &OldPI);
+
+  /// "Remove" the %evl parameter of \p PI by setting it to the static vector
+  /// length of the operation.
+  void discardEVLParameter(PredicatedInstruction &PI);
+
+  /// \brief Lower this VP binary operator to a non-VP binary operator.
+  Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
+                                           PredicatedInstruction &PI);
+
+  /// \brief Lower \p PI to an unpredicated instruction where supported
+  /// (currently binary operators). \returns the replacement value, or \p PI
+  /// itself if it was left unchanged.
+  Value *expandPredication(PredicatedInstruction &PI);
+
+  /// Lane bit width chosen per static vector length.
+  std::map<unsigned, unsigned> StaticVLToBitsCache; // TODO 'SmallMap' class
+
+  /// \brief return a good (for fast icmp) integer bit width to expand
+  /// the EVL comparison against the stepvector in.
+  unsigned getLaneBitsForEVLCompare(unsigned StaticVL);
+
+public:
+  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
+      : F(F), TTI(TTI) {}
+
+  // expand VP ops in \p F according to \p TTI.
+  bool expandVectorPredication();
+};
+
+//// CachingVPExpander {
+
+unsigned CachingVPExpander::getLaneBitsForEVLCompare(unsigned StaticVL) {
+  auto ItCached = StaticVLToBitsCache.find(StaticVL);
+  if (ItCached != StaticVLToBitsCache.end())
+    return ItCached->second;
+
+  // The smallest integer to hold <0, .., StaticVL - 1> and StaticVL itself.
+  // Cannot choose less bits than this or the expansion will be invalid.
+  unsigned MinLaneBits = getLeastLaneBitsForStepVector(StaticVL);
+  LLVM_DEBUG(dbgs() << "Least lane bits for " << StaticVL << " is "
+                    << MinLaneBits << "\n";);
+
+  // If the EVL compare will be expanded into scalar code, choose the
+  // smallest integer type.
+  unsigned VectorRegBits = TTI.getRegisterBitWidth(/* Vector */ true);
+  if (VectorRegBits == 0) {
+    StaticVLToBitsCache[StaticVL] = MinLaneBits;
+    return MinLaneBits;
+  }
+
+  // Otherwise, the generated vector operation will likely map to vector
+  // instructions. This is the largest bit width (inclusive) that fits the EVL
+  // expansion in one vector register.
+  unsigned MaxLaneBits =
+      std::min<unsigned>(IntegerType::MAX_INT_BITS, VectorRegBits / StaticVL);
+
+  // Many SIMD instruction are restricted in their supported lane bit widths.
+  // We choose the bit width that gives us the cheapest vector compare; on a
+  // tie the narrower width wins.
+  int Cheapest = std::numeric_limits<int>::max();
+  auto &Ctx = F.getContext();
+  unsigned CheapestLaneBits = MinLaneBits;
+  for (unsigned LaneBits = MinLaneBits; LaneBits <= MaxLaneBits; ++LaneBits) {
+    int VecCmpCost = TTI.getCmpSelInstrCost(
+        Instruction::ICmp, VectorType::get(Type::getIntNTy(Ctx, LaneBits),
+                                           StaticVL, /* Scalable */ false));
+    if (VecCmpCost < Cheapest) {
+      Cheapest = VecCmpCost;
+      CheapestLaneBits = LaneBits;
+    }
+  }
+
+  StaticVLToBitsCache[StaticVL] = CheapestLaneBits;
+  return CheapestLaneBits;
+}
+
+Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder,
+                                           int32_t ElemBits, int32_t NumElems) {
+  // TODO add caching
+  SmallVector<Constant *, 16> ConstElems;
+  ConstElems.reserve(NumElems);
+
+  Type *LaneTy = Builder.getIntNTy(ElemBits);
+
+  for (int32_t Idx = 0; Idx < NumElems; ++Idx)
+    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));
+
+  return ConstantVector::get(ConstElems);
+}
+
+Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
+                                           Value *EVLParam,
+                                           ElementCount ElemCount) {
+  // TODO add caching
+  // Scalable vectors have no static lane count, so no constant step vector
+  // can be built; use `get_active_lane_mask` instead.
+  if (ElemCount.Scalable) {
+    auto *M = Builder.GetInsertBlock()->getModule();
+    auto *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
+    auto *ActiveMaskFunc = Intrinsic::getDeclaration(
+        M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
+    // `get_active_lane_mask` performs an implicit less-equal-than comparison.
+    // Offset the lane index by one accordingly. The base operand must have the
+    // same type as %evl (the intrinsic is overloaded on that type).
+    // NOTE(review): the offset of 1 is only correct for the `ule` definition
+    // of the intrinsic; under an `ult` definition the base must be 0. Confirm
+    // against the LangRef of the LLVM version this targets.
+    auto *LaneOffset = ConstantInt::get(EVLParam->getType(), 1);
+    return Builder.CreateCall(ActiveMaskFunc, {LaneOffset, EVLParam});
+  }
+
+  assert(!ElemCount.Scalable && "code path not applicable to scalable types");
+  unsigned NumElems = ElemCount.Min;
+
+  unsigned ElemBits = getLaneBitsForEVLCompare(NumElems);
+
+  Type *LaneTy = Builder.getIntNTy(ElemBits);
+
+  auto *ExtVLParam = Builder.CreateZExtOrTrunc(EVLParam, LaneTy);
+  auto *VLSplat = Builder.CreateVectorSplat(NumElems, ExtVLParam);
+
+  auto *IdxVec = createStepVector(Builder, ElemBits, NumElems);
+
+  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
+}
+
+Value *CachingVPExpander::expandPredicationInBinaryOperator(
+    IRBuilder<> &Builder, PredicatedInstruction &PI) {
+  assert(PI.canIgnoreVectorLengthParam());
+  assert(PI.isBinaryOp());
+
+  Value *FirstOp = PI.getOperand(0);
+  Value *SndOp = PI.getOperand(1);
+
+  Value *Mask = PI.getMaskParam();
+
+  // Blend in safe operands.
+  if (Mask && !isAllTrueMask(Mask)) {
+    switch (PI.getOpcode()) {
+    default:
+      // Can safely ignore the predicate.
+      break;
+
+    // Division operators need a safe divisor on masked-off lanes (1).
+    case Instruction::UDiv:
+    case Instruction::SDiv:
+    case Instruction::URem:
+    case Instruction::SRem: {
+      // 2nd operand must not be zero.
+      Constant *SafeDivisor = getSafeDivisor(PI.getType());
+      SndOp = Builder.CreateSelect(Mask, SndOp, SafeDivisor);
+      break;
+    }
+    }
+  }
+
+  auto NewOC = static_cast<Instruction::BinaryOps>(PI.getOpcode());
+  Value *NewBinOp = Builder.CreateBinOp(NewOC, FirstOp, SndOp, PI.getName());
+
+  replaceOperation(*NewBinOp, PI);
+  return NewBinOp;
+}
+
+void CachingVPExpander::discardEVLParameter(PredicatedInstruction &PI) {
+  LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << PI << "\n");
+
+  if (PI.canIgnoreVectorLengthParam())
+    return;
+
+  Value *EVLParam = PI.getVectorLengthParam();
+  if (!EVLParam)
+    return;
+
+  auto &VPI = cast<VPIntrinsic>(PI);
+  ElementCount StaticElemCount = VPI.getStaticVectorLength();
+  Value *MaxEVL = nullptr;
+  Type *Int32Ty = Type::getInt32Ty(PI.getContext());
+  if (StaticElemCount.Scalable) {
+    // TODO add caching
+    // The static length is vscale * Min; materialize it right before VPI.
+    auto *M = VPI.getModule();
+    auto *VScaleFunc = Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
+    IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
+    Value *FactorConst = Builder.getInt32(StaticElemCount.Min);
+    Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
+    MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
+                               /*NUW*/ true, /*NSW*/ false);
+  } else {
+    MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.Min, false);
+  }
+  VPI.setVectorLengthParam(MaxEVL);
+}
+
+Value *CachingVPExpander::foldEVLIntoMask(PredicatedInstruction &OldPI) {
+  LLVM_DEBUG(dbgs() << "Folding vlen for " << OldPI << '\n');
+
+  // No effective %evl parameter and so nothing to do here.
+  if (OldPI.canIgnoreVectorLengthParam())
+    return &OldPI;
+
+  IRBuilder<> Builder(cast<Instruction>(&OldPI));
+
+  // Only VP intrinsics can have a %evl parameter.
+  auto &VPI = cast<VPIntrinsic>(OldPI);
+  Value *OldMaskParam = VPI.getMaskParam();
+  Value *OldEVLParam = VPI.getVectorLengthParam();
+  assert(OldMaskParam && "no mask param to fold the vl param into");
+  assert(OldEVLParam && "no EVL param to fold away");
+
+  LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
+  LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');
+
+  // Convert the %evl predication into vector mask predication.
+  ElementCount ElemCount = VPI.getStaticVectorLength();
+  Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
+  Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
+  VPI.setMaskParam(NewMaskParam);
+
+  // Drop the EVL parameter.
+  discardEVLParameter(OldPI);
+  assert(VPI.canIgnoreVectorLengthParam() &&
+         "transformation did not render the evl param ineffective!");
+
+  // Re-assess the modified instruction.
+  return &VPI;
+}
+
+Value *CachingVPExpander::expandPredication(PredicatedInstruction &OldPI) {
+  LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << OldPI << '\n');
+
+  // Try lowering to a LLVM instruction first. isBinaryOp() tests the
+  // functional opcode against the Instruction.def binary-operator range.
+  if (OldPI.isBinaryOp()) {
+    IRBuilder<> Builder(cast<Instruction>(&OldPI));
+    return expandPredicationInBinaryOperator(Builder, OldPI);
+  }
+
+  return &OldPI;
+}
+
+//// } CachingVPExpander
+
+struct TransformJob {
+  PredicatedInstruction *PI;
+  TargetTransformInfo::VPLegalization Strategy;
+  TransformJob(PredicatedInstruction *PI,
+               TargetTransformInfo::VPLegalization InitStrat)
+      : PI(PI), Strategy(InitStrat) {}
+
+  bool isDone() const { return Strategy.doNothing(); }
+};
+
+/// Adjust \p LegalizeStrat so that applying it to \p I preserves semantics.
+void sanitizeStrategy(Instruction &I, VPLegalization &LegalizeStrat) {
+  // Speculatable instructions do not strictly need predication.
+  if (isSafeToSpeculativelyExecute(&I)) {
+    // Converting the operator drops %mask and %evl anyway, so there is no
+    // point in folding %evl into the mask first. Discarding %evl also makes
+    // it ignorable, which the operator expansion requires.
+    if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
+      LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
+    return;
+  }
+
+  // The predicating effect of %evl must be preserved for this operation:
+  // 1) never discard %evl, and
+  // 2) if the operator is converted to a non-VP instruction, %evl must be
+  //    folded into the mask first (the expansion only honors the mask).
+  if (LegalizeStrat.EVLParamStrategy == VPLegalization::Discard ||
+      LegalizeStrat.OpStrategy == VPLegalization::Convert)
+    LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
+}
+
+/// \brief Expand llvm.vp.* intrinsics as requested by \p TTI.
+bool CachingVPExpander::expandVectorPredication() {
+  // Holds all vector-predicated ops that need at least one transformation.
+  SmallVector<TransformJob, 16> Worklist;
+
+  // Collect the VP intrinsics that need expansion together with their
+  // (sanitized) legalization strategy. TTI::getVPLegalizationStrategy is only
+  // defined for vector-predicated instructions; querying it for regular
+  // instructions would make the default "Convert" strategy re-create plain
+  // binary operators, dropping their wrap/exact flags and metadata.
+  for (auto &I : instructions(F)) {
+    if (!isa<VPIntrinsic>(I))
+      continue;
+    auto &PI = cast<PredicatedInstruction>(I);
+    auto VPStrat = TTI.getVPLegalizationStrategy(PI);
+    sanitizeStrategy(I, VPStrat);
+    if (!VPStrat.doNothing())
+      Worklist.emplace_back(&PI, VPStrat);
+  }
+  if (Worklist.empty())
+    return false;
+
+  LLVM_DEBUG(dbgs() << "\n:::: Transforming instructions. ::::\n");
+  for (TransformJob Job : Worklist) {
+    // Transform the EVL parameter.
+    switch (Job.Strategy.EVLParamStrategy) {
+    case VPLegalization::Legal:
+      break;
+    case VPLegalization::Discard:
+      discardEVLParameter(*Job.PI);
+      break;
+    case VPLegalization::Convert:
+      // Only count a fold when there was an effective %evl to fold.
+      if (!Job.PI->canIgnoreVectorLengthParam()) {
+        foldEVLIntoMask(*Job.PI);
+        ++NumFoldedVL;
+      }
+      break;
+    }
+    Job.Strategy.EVLParamStrategy = VPLegalization::Legal;
+
+    // Replace the operator.
+    switch (Job.Strategy.OpStrategy) {
+    case VPLegalization::Legal:
+      break;
+    case VPLegalization::Discard:
+      llvm_unreachable("Invalid strategy for operators.");
+    case VPLegalization::Convert:
+      expandPredication(*Job.PI);
+      ++NumLoweredVPOps;
+      break;
+    }
+    Job.Strategy.OpStrategy = VPLegalization::Legal;
+
+    assert(Job.isDone() && "incomplete transformation");
+  }
+
+  return true;
+}
+
+class ExpandVectorPredication : public FunctionPass {
+public:
+  static char ID;
+  ExpandVectorPredication() : FunctionPass(ID) {
+    initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+    CachingVPExpander VPExpander(F, *TTI);
+    return VPExpander.expandVectorPredication();
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.setPreservesCFG();
+  }
+};
+} // namespace
+
+char ExpandVectorPredication::ID;
+INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expand-vec-pred",
+                      "Expand vector predication intrinsics", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(ExpandVectorPredication, "expand-vec-pred",
+                    "Expand vector predication intrinsics", false, false)
+
+FunctionPass *llvm::createExpandVectorPredicationPass() {
+  return new ExpandVectorPredication();
+}
+
+PreservedAnalyses
+ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
+  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+  CachingVPExpander VPExpander(F, TTI);
+  if (!VPExpander.expandVectorPredication())
+    return PreservedAnalyses::all();
+  // The expansion only rewrites instructions in place; the CFG is untouched.
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -704,6 +704,11 @@
   // Instrument function entry and exit, e.g. with calls to mcount().
   addPass(createPostInlineEntryExitInstrumenterPass());
 
+  // Expand vector predication intrinsics into standard IR instructions.
+  // This pass has to run before ScalarizeMaskedMemIntrin and ExpandReduction
+  // passes since it emits those kinds of intrinsics.
+  addPass(createExpandVectorPredicationPass());
+
   // Add scalarization of target's unsupported masked memory intrinsics pass.
   // the unsupported intrinsic will be replaced with a chain of basic blocks,
   // that stores/loads element one-by-one if the appropriate mask bit is set.
diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -44,6 +44,7 @@
   PassManager.cpp
   PassRegistry.cpp
   PassTimingInfo.cpp
+  PredicatedInst.cpp
   SafepointIRVerifier.cpp
   ProfileSummary.cpp
   Statepoint.cpp
diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp
--- a/llvm/lib/IR/IntrinsicInst.cpp
+++ b/llvm/lib/IR/IntrinsicInst.cpp
@@ -196,6 +196,12 @@
   return nullptr;
 }
 
+// Replace the mask operand in place; asserts that this intrinsic has one.
+void VPIntrinsic::setMaskParam(Value *NewMask) {
+  auto MaskPos = GetMaskParamPos(getIntrinsicID());
+  assert(MaskPos.hasValue());
+  setArgOperand(MaskPos.getValue(), NewMask);
+}
+
 Value *VPIntrinsic::getVectorLengthParam() const {
   auto vlenPos = GetVectorLengthParamPos(getIntrinsicID());
   if (vlenPos)
@@ -203,6 +209,12 @@
   return nullptr;
 }
 
+// Replace the %evl operand in place; asserts that this intrinsic has one.
+void VPIntrinsic::setVectorLengthParam(Value *NewEVL) {
+  auto EVLPos = GetVectorLengthParamPos(getIntrinsicID());
+  assert(EVLPos.hasValue());
+  setArgOperand(EVLPos.getValue(), NewEVL);
+}
+
 Optional VPIntrinsic::GetMaskParamPos(Intrinsic::ID IntrinsicID) {
   switch (IntrinsicID) {
   default:
diff --git a/llvm/lib/IR/PredicatedInst.cpp b/llvm/lib/IR/PredicatedInst.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/IR/PredicatedInst.cpp
@@ -0,0 +1,72 @@
+//===--- llvm/PredicatedInst.cpp - Vector-predication utility subclass ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the PredicatedInstruction class.
+//
+//===----------------------------------------------------------------------===//
+
+#include
+#include
+#include
+#include
+#include
+
+// Bounds of the binary-operator opcode range, taken from Instruction.def.
+namespace {
+#define FIRST_BINARY_INST(X) const unsigned FirstBinOp = X;
+#define LAST_BINARY_INST(X) const unsigned LastBinOp = X;
+#include "llvm/IR/Instruction.def"
+} // namespace
+
+namespace llvm {
+
+bool PredicatedInstruction::canIgnoreVectorLengthParam() const {
+  auto *VPI = dyn_cast(this);
+  // Not a VP intrinsic: there is no %evl that could restrict the lanes.
+  if (!VPI)
+    return true;
+
+  return VPI->canIgnoreVectorLengthParam();
+}
+
+// Requires an operation that supports fast-math flags; cast asserts otherwise.
+FastMathFlags PredicatedInstruction::getFastMathFlags() const {
+  return cast(this)->getFastMathFlags();
+}
+
+// Tests the functional opcode (see getOpcode) against the binary-op range.
+bool PredicatedInstruction::isBinaryOp() const {
+  unsigned FuncOC = getOpcode();
+
+  return FirstBinOp <= FuncOC && FuncOC <= LastBinOp;
+}
+
+Value *PredicatedInstruction::getMaskParam() const {
+  auto *thisVP = dyn_cast(this);
+  if (!thisVP)
+    return nullptr;
+  return thisVP->getMaskParam();
+}
+
+Value *PredicatedInstruction::getVectorLengthParam() const {
+  auto *thisVP = dyn_cast(this);
+  if (!thisVP)
+    return nullptr;
+  return thisVP->getVectorLengthParam();
+}
+
+bool PredicatedInstruction::isVectorPredicatedOp() const {
+  return isa(this);
+}
+
+// Regular instructions report their own opcode; VP intrinsics report the
+// opcode of their unpredicated counterpart (getFunctionalOpcode).
+unsigned PredicatedInstruction::getOpcode() const {
+  auto *VPInst = dyn_cast(this);
+
+  if (!VPInst)
+    return cast(this)->getOpcode();
+
+  return VPInst->getFunctionalOpcode();
+}
+
+} // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/O0-pipeline.ll b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
--- a/llvm/test/CodeGen/AArch64/O0-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
@@ -22,6 +22,7 @@
 ; CHECK-NEXT: Lower constant intrinsics
 ; CHECK-NEXT: Remove unreachable blocks from the CFG
 ; CHECK-NEXT: Instrument function entry/exit with calls to e.g.
mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: AArch64 Stack Tagging diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll --- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll +++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll @@ -57,6 +57,7 @@ ; CHECK-NEXT: Constant Hoisting ; CHECK-NEXT: Partially inline calls to library functions ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Stack Safety Analysis diff --git a/llvm/test/CodeGen/ARM/O3-pipeline.ll b/llvm/test/CodeGen/ARM/O3-pipeline.ll --- a/llvm/test/CodeGen/ARM/O3-pipeline.ll +++ b/llvm/test/CodeGen/ARM/O3-pipeline.ll @@ -37,6 +37,7 @@ ; CHECK-NEXT: Constant Hoisting ; CHECK-NEXT: Partially inline calls to library functions ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. 
mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Dominator Tree Construction diff --git a/llvm/test/CodeGen/Generic/expand-vp.ll b/llvm/test/CodeGen/Generic/expand-vp.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/Generic/expand-vp.ll @@ -0,0 +1,84 @@ +; RUN: opt --expand-vec-pred -S < %s | FileCheck %s + +; All VP intrinsics have to be lowered into non-VP ops +; CHECK-NOT: {{call.* @llvm.vp.add}} +; CHECK-NOT: {{call.* @llvm.vp.sub}} +; CHECK-NOT: {{call.* @llvm.vp.mul}} +; CHECK-NOT: {{call.* @llvm.vp.sdiv}} +; CHECK-NOT: {{call.* @llvm.vp.srem}} +; CHECK-NOT: {{call.* @llvm.vp.udiv}} +; CHECK-NOT: {{call.* @llvm.vp.urem}} +; CHECK-NOT: {{call.* @llvm.vp.and}} +; CHECK-NOT: {{call.* @llvm.vp.or}} +; CHECK-NOT: {{call.* @llvm.vp.xor}} +; CHECK-NOT: {{call.* @llvm.vp.ashr}} +; CHECK-NOT: {{call.* @llvm.vp.lshr}} +; CHECK-NOT: {{call.* @llvm.vp.shl}} + +define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) { + %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r7 = call <8 x i32> @llvm.vp.and.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r8 = call <8 x i32> @llvm.vp.or.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r9 = call <8 x i32> 
@llvm.vp.xor.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rA = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rB = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rC = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + ret void +} + +; fixed-width vectors +; integer arith +declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +; bit arith +declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) + +define void @test_vp_int_vscale( %i0, %i1, %i2, %f3, %m, i32 %n) { + %r0 = call @llvm.vp.add.nxv4i32( %i0, %i1, %m, i32 %n) + %r1 = call @llvm.vp.sub.nxv4i32( %i0, %i1, %m, i32 %n) + %r2 = call @llvm.vp.mul.nxv4i32( %i0, %i1, %m, i32 %n) + %r3 = call @llvm.vp.sdiv.nxv4i32( %i0, %i1, %m, i32 %n) + %r4 = call @llvm.vp.srem.nxv4i32( %i0, %i1, %m, i32 %n) + %r5 = call @llvm.vp.udiv.nxv4i32( %i0, %i1, %m, i32 %n) + %r6 = call @llvm.vp.urem.nxv4i32( %i0, %i1, %m, i32 %n) + %r7 = call @llvm.vp.and.nxv4i32( %i0, %i1, %m, i32 %n) + %r8 = call @llvm.vp.or.nxv4i32( %i0, %i1, %m, i32 %n) + %r9 = call 
@llvm.vp.xor.nxv4i32( %i0, %i1, %m, i32 %n) + %rA = call @llvm.vp.ashr.nxv4i32( %i0, %i1, %m, i32 %n) + %rB = call @llvm.vp.lshr.nxv4i32( %i0, %i1, %m, i32 %n) + %rC = call @llvm.vp.shl.nxv4i32( %i0, %i1, %m, i32 %n) + ret void +} + +; scalable-width vectors +; integer arith +declare @llvm.vp.add.nxv4i32(, , , i32) +declare @llvm.vp.sub.nxv4i32(, , , i32) +declare @llvm.vp.mul.nxv4i32(, , , i32) +declare @llvm.vp.sdiv.nxv4i32(, , , i32) +declare @llvm.vp.srem.nxv4i32(, , , i32) +declare @llvm.vp.udiv.nxv4i32(, , , i32) +declare @llvm.vp.urem.nxv4i32(, , , i32) +; bit arith +declare @llvm.vp.and.nxv4i32(, , , i32) +declare @llvm.vp.xor.nxv4i32(, , , i32) +declare @llvm.vp.or.nxv4i32(, , , i32) +declare @llvm.vp.ashr.nxv4i32(, , , i32) +declare @llvm.vp.lshr.nxv4i32(, , , i32) +declare @llvm.vp.shl.nxv4i32(, , , i32) diff --git a/llvm/test/CodeGen/X86/O0-pipeline.ll b/llvm/test/CodeGen/X86/O0-pipeline.ll --- a/llvm/test/CodeGen/X86/O0-pipeline.ll +++ b/llvm/test/CodeGen/X86/O0-pipeline.ll @@ -24,6 +24,7 @@ ; CHECK-NEXT: Lower constant intrinsics ; CHECK-NEXT: Remove unreachable blocks from the CFG ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Expand indirectbr instructions diff --git a/llvm/test/CodeGen/X86/O3-pipeline.ll b/llvm/test/CodeGen/X86/O3-pipeline.ll --- a/llvm/test/CodeGen/X86/O3-pipeline.ll +++ b/llvm/test/CodeGen/X86/O3-pipeline.ll @@ -49,6 +49,7 @@ ; CHECK-NEXT: Constant Hoisting ; CHECK-NEXT: Partially inline calls to library functions ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. 
mcount() (post inlining) +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Dominator Tree Construction diff --git a/llvm/tools/llc/llc.cpp b/llvm/tools/llc/llc.cpp --- a/llvm/tools/llc/llc.cpp +++ b/llvm/tools/llc/llc.cpp @@ -318,6 +318,7 @@ initializeVectorization(*Registry); initializeScalarizeMaskedMemIntrinPass(*Registry); initializeExpandReductionsPass(*Registry); + initializeExpandVectorPredicationPass(*Registry); initializeHardwareLoopsPass(*Registry); initializeTransformUtils(*Registry); diff --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp --- a/llvm/tools/opt/opt.cpp +++ b/llvm/tools/opt/opt.cpp @@ -585,6 +585,7 @@ initializePostInlineEntryExitInstrumenterPass(Registry); initializeUnreachableBlockElimLegacyPassPass(Registry); initializeExpandReductionsPass(Registry); + initializeExpandVectorPredicationPass(Registry); initializeWasmEHPreparePass(Registry); initializeWriteBitcodePassPass(Registry); initializeHardwareLoopsPass(Registry);