diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -48,6 +48,7 @@ class LoadInst; class LoopAccessInfo; class Loop; +class PredicatedInstruction; class ProfileSummaryInfo; class SCEV; class ScalarEvolution; @@ -1161,6 +1162,40 @@ /// Intrinsics") Use of %evl is discouraged when that is not the case. bool hasActiveVectorLength() const; + struct VPLegalization { + enum VPTransform { + // keep the predicating parameter + Legal = 0, + // where legal, discard the predicate parameter + Discard = 1, + // transform into something else that is also predicating + Convert = 2 + }; + + // How to transform the EVL parameter. + // Legal: keep the EVL parameter as it is. + // Discard: Ignore the EVL parameter where it is safe to do so. + // Convert: Fold the EVL into the mask parameter. + VPTransform EVLParamStrategy; + + // How to transform the operator. + // Legal: The target supports this operator. + // Convert: Convert this to a non-VP operation. + // The 'Discard' strategy is invalid. + VPTransform OpStrategy; + + bool doNothing() const { + return (EVLParamStrategy == Legal) && (OpStrategy == Legal); + } + VPLegalization(VPTransform EVLParamStrategy, VPTransform OpStrategy) + : EVLParamStrategy(EVLParamStrategy), OpStrategy(OpStrategy) {} + }; + + /// This will only be called for vector-predicated instructions. + /// \returns How the target needs this vector-predicated operation to be + /// transformed. + VPLegalization + getVPLegalizationStrategy(const PredicatedInstruction &PI) const; /// @} /// @} @@ -1417,6 +1452,8 @@ virtual bool shouldExpandReduction(const IntrinsicInst *II) const = 0; virtual unsigned getGISelRematGlobalCost() const = 0; virtual bool hasActiveVectorLength() const = 0; + virtual VPLegalization + getVPLegalizationStrategy(const PredicatedInstruction &PI) const = 0; virtual int getInstructionLatency(const Instruction *I) = 0; }; @@ -1900,6 +1937,11 @@ return Impl.hasActiveVectorLength(); } + VPLegalization + getVPLegalizationStrategy(const PredicatedInstruction &PI) const override { + return Impl.getVPLegalizationStrategy(PI); + } + int getInstructionLatency(const Instruction *I) override { return Impl.getInstructionLatency(I); } diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -618,6 +618,13 @@ bool hasActiveVectorLength() const { return false; } + TargetTransformInfo::VPLegalization + getVPLegalizationStrategy(const PredicatedInstruction &PI) const { + return TargetTransformInfo::VPLegalization( + /* EVLParamStrategy */ TargetTransformInfo::VPLegalization::Discard, + /* OperatorStrategy */ TargetTransformInfo::VPLegalization::Convert); + } + protected: // Obtain the minimum required size to hold the value (without the sign) // In case of a vector it returns the min required size for one element. diff --git a/llvm/include/llvm/CodeGen/ExpandVectorPredication.h b/llvm/include/llvm/CodeGen/ExpandVectorPredication.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/CodeGen/ExpandVectorPredication.h @@ -0,0 +1,23 @@ +//===-- ExpandVectorPredication.h - Expand vector predication ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
+#define LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class ExpandVectorPredicationPass
+    : public PassInfoMixin<ExpandVectorPredicationPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+} // end namespace llvm
+
+#endif // LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -451,6 +451,11 @@
   /// shuffles.
   FunctionPass *createExpandReductionsPass();
 
+  /// This pass expands the vector predication intrinsics into unpredicated
+  /// instructions (with selects where necessary) or folds the explicit vector
+  /// length into the predicate mask.
+  FunctionPass *createExpandVectorPredicationPass();
+
   // This pass expands memcmp() to load/stores.
   FunctionPass *createExpandMemCmpPass();
 
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -225,9 +225,11 @@
   /// \return the mask parameter or nullptr.
   Value *getMaskParam() const;
+  void setMaskParam(Value *);
 
   /// \return the vector length parameter or nullptr.
   Value *getVectorLengthParam() const;
+  void setVectorLengthParam(Value *);
 
   /// \return whether the vector length param can be ignored.
   bool canIgnoreVectorLengthParam() const;
diff --git a/llvm/include/llvm/IR/PredicatedInst.h b/llvm/include/llvm/IR/PredicatedInst.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/IR/PredicatedInst.h
@@ -0,0 +1,87 @@
+//===-- PredicatedInst.h - Utility for predicated instructions --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines various classes for working with vector-predicated
+// instructions. Predicated instructions are either regular instructions or
+// calls to Vector Predication (VP) intrinsics that have a mask and an explicit
+// vector length argument.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_PREDICATEDINST_H
+#define LLVM_IR_PREDICATEDINST_H
+
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Support/Casting.h"
+
+#include <cstddef>
+
+namespace llvm {
+
+class BasicBlock;
+
+class PredicatedInstruction : public User {
+public:
+  // The PredicatedInstruction class is intended to be used as a utility view
+  // over existing instructions; it is never itself instantiated.
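+  //
+  // Illustrative use (names are hypothetical), viewing an existing
+  // Instruction through this interface:
+  //   auto &PI = cast<PredicatedInstruction>(Inst);
+  //   if (Value *Mask = PI.getMaskParam())
+  //     ...;  // Inst is a VP intrinsic carrying a mask parameter.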
+  PredicatedInstruction() = delete;
+  ~PredicatedInstruction() = delete;
+
+  void copyIRFlags(const Value *V, bool IncludeWrapFlags) {
+    cast<Instruction>(this)->copyIRFlags(V, IncludeWrapFlags);
+  }
+
+  BasicBlock *getParent() { return cast<Instruction>(this)->getParent(); }
+  const BasicBlock *getParent() const {
+    return cast<Instruction>(this)->getParent();
+  }
+
+  void *operator new(size_t s) = delete;
+
+  /// \returns the Mask parameter of this instruction or nullptr if there is
+  /// none.
+  Value *getMaskParam() const;
+
+  /// \returns the Explicit Vector Length parameter of this instruction or
+  /// nullptr if there is none.
+  Value *getVectorLengthParam() const;
+
+  /// \returns True if the passed vector length value has no predicating effect
+  /// on the operation.
+  bool canIgnoreVectorLengthParam() const;
+
+  /// \returns True if the static operator of this instruction has a mask or
+  /// vector length parameter.
+  bool isVectorPredicatedOp() const;
+
+  /// \returns Whether this is functionally a binary operator.
+  bool isBinaryOp() const;
+
+  /// \returns the effective Opcode of this operation (ignoring the mask and
+  /// vector length param).
+  unsigned getOpcode() const;
+
+  static bool classof(const Instruction *I) { return isa<Instruction>(I); }
+  static bool classof(const ConstantExpr *CE) { return false; }
+  static bool classof(const Value *V) { return isa<Instruction>(V); }
+
+  /// Convenience function for getting all the fast-math flags. This must be
+  /// an operator that supports these flags. See LangRef.html for the meaning
+  /// of these flags.
+  FastMathFlags getFastMathFlags() const;
+};
+
+} // namespace llvm
+
+#endif // LLVM_IR_PREDICATEDINST_H
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -147,6 +147,7 @@
 void initializeExpandMemCmpPassPass(PassRegistry&);
 void initializeExpandPostRAPass(PassRegistry&);
 void initializeExpandReductionsPass(PassRegistry&);
+void initializeExpandVectorPredicationPass(PassRegistry&);
 void initializeMakeGuardsExplicitLegacyPassPass(PassRegistry&);
 void initializeExternalAAWrapperPassPass(PassRegistry&);
 void initializeFEntryInserterPass(PassRegistry&);
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -847,6 +847,12 @@
   return TTIImpl->useReductionIntrinsic(Opcode, Ty, Flags);
 }
 
+TargetTransformInfo::VPLegalization
+TargetTransformInfo::getVPLegalizationStrategy(
+    const PredicatedInstruction &PI) const {
+  return TTIImpl->getVPLegalizationStrategy(PI);
+}
+
 bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const {
   return TTIImpl->shouldExpandReduction(II);
 }
diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -27,6 +27,7 @@
   ExpandMemCmp.cpp
   ExpandPostRAPseudos.cpp
   ExpandReductions.cpp
+  ExpandVectorPredication.cpp
   FaultMaps.cpp
   FEntryInserter.cpp
   FinalizeISel.cpp
diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -0,0 +1,398 @@
+//===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements IR expansion for vector predication intrinsics,
+// allowing targets to enable vector predication until just before codegen.
+//
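+// For instance (illustrative only), with the default legalization strategy a
+// call such as
+//   %r = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %a, <8 x i32> %b,
+//                                          <8 x i1> %m, i32 %n)
+// first has %n folded into the mask and is then lowered to a plain binary
+// operator. Since VP semantics leave masked-off result lanes undefined,
+//   %r = add <8 x i32> %a, %b
+// is a correct expansion here; only division-like operators additionally need
+// a safe divisor blended into masked-off lanes.
+//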
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ExpandVectorPredication.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PredicatedInst.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace llvm;
+
+using VPLegalization = TargetTransformInfo::VPLegalization;
+
+#define DEBUG_TYPE "expand-vec-pred"
+
+STATISTIC(NumFoldedVL, "Number of folded vector length params");
+STATISTIC(NumLoweredVPOps, "Number of lowered vector predication operations");
+
+/// \returns Whether the vector mask \p MaskVal has all lane bits set.
+static bool IsAllTrueMask(Value *MaskVal) {
+  auto *ConstVec = dyn_cast<ConstantVector>(MaskVal);
+  if (!ConstVec)
+    return false;
+  return ConstVec->isAllOnesValue();
+}
+
+namespace {
+
+/// \returns A vector with ascending integer indices (<0, 1, ..., NumElems-1>).
+Value *createStepVector(IRBuilder<> &Builder, int32_t ElemBits,
+                        int32_t NumElems) {
+  // TODO add caching
+  SmallVector<Constant *, 16> ConstElems;
+
+  Type *LaneTy = Builder.getIntNTy(ElemBits);
+
+  for (int32_t Idx = 0; Idx < NumElems; ++Idx) {
+    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));
+  }
+
+  return ConstantVector::get(ConstElems);
+}
+
+/// Computes the smallest integer bit width to hold the step vector <0, ..,
+/// NumVectorElements - 1>.
+static unsigned getLeastLaneBitsForStepVector(unsigned NumVectorElements) {
+  // countLeadingZeros on 'unsigned' counts within 32 bits.
+  unsigned LeadingZeros =
+      llvm::countLeadingZeros(NumVectorElements, ZB_Undefined);
+  return std::max<unsigned>(IntegerType::MIN_INT_BITS, 32 - LeadingZeros);
+}
+
+static unsigned getLaneBitsForEVLCompare(LLVMContext &Ctx, unsigned StaticVL,
+                                         const TargetTransformInfo &TTI) {
+  // The smallest integer to hold <0, .., ElemCount.Min - 1>.
+  // We cannot choose fewer bits than this or the expansion will be invalid.
+  unsigned MinLaneBits = getLeastLaneBitsForStepVector(StaticVL);
+  LLVM_DEBUG(dbgs() << "Least lane bits for " << StaticVL << " is "
+                    << MinLaneBits << "\n";);
+
+  // If this vector operation will be lowered to scalar code, choose the
+  // smallest integer type.
+  if (TTI.getRegisterBitWidth(true) == 0)
+    return MinLaneBits;
+
+  // Otherwise, the generated vector operation will likely map to vector
+  // instructions. The largest bit width that fits the EVL expansion in one
+  // vector register:
+  unsigned MaxLaneBits = std::min<unsigned>(
+      IntegerType::MAX_INT_BITS, TTI.getRegisterBitWidth(true) / StaticVL);
+
+  // Many SIMD instructions are restricted in their supported lane bit widths.
+  // We choose the bit width that gives us the cheapest vector compare.
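+  // E.g., with a (hypothetical) 128-bit vector register and StaticVL == 8,
+  // candidate lane types range from i<MinLaneBits> up to i16; the scan below
+  // keeps whichever lane width the cost model rates cheapest for the icmp.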
+  int Cheapest = std::numeric_limits<int>::max();
+  unsigned CheapestLaneBits = MinLaneBits;
+  for (auto LaneBits = MinLaneBits; LaneBits < MaxLaneBits; ++LaneBits) {
+    int VecCmpCost = TTI.getCmpSelInstrCost(
+        Instruction::ICmp, VectorType::get(Type::getIntNTy(Ctx, LaneBits),
+                                           StaticVL, /* Scalable */ false));
+    if (VecCmpCost < Cheapest) {
+      Cheapest = VecCmpCost;
+      CheapestLaneBits = LaneBits;
+    }
+  }
+
+  return CheapestLaneBits;
+}
+
+/// \returns A bitmask that is true where the lane position is less-than \p
+/// EVLParam.
+///
+/// \p Builder
+///    Used for instruction creation.
+/// \p EVLParam
+///    The explicit vector length parameter to test against the lane
+///    positions.
+/// \p ElemCount
+///    Static (potentially scalable) number of vector elements.
+/// \p TTI
+///    Used to pick a lane type for the compare.
+Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
+                        ElementCount ElemCount,
+                        const TargetTransformInfo &TTI) {
+  // TODO EVL to mask conversion using the expansion below requires a way to
+  // create a scalable step vector.
+  assert(!ElemCount.Scalable && "TODO ConvertEVLToMask for scalable VP ops");
+  unsigned NumElems = ElemCount.Min;
+
+  unsigned ElemBits =
+      getLaneBitsForEVLCompare(Builder.getContext(), NumElems, TTI);
+
+  Type *LaneTy = Builder.getIntNTy(ElemBits);
+
+  auto ExtVLParam = Builder.CreateZExtOrTrunc(EVLParam, LaneTy);
+  auto VLSplat = Builder.CreateVectorSplat(NumElems, ExtVLParam);
+
+  auto IdxVec = createStepVector(Builder, ElemBits, NumElems);
+
+  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
+}
+
+/// \returns A non-excepting divisor constant for this type.
+Constant *getSafeDivisor(Type *DivTy) {
+  assert(DivTy->isIntOrIntVectorTy());
+  // A divisor of 1 is safe for all of udiv/sdiv/urem/srem; an all-ones
+  // divisor would still overflow for 'sdiv INT_MIN, -1'.
+  return ConstantInt::get(DivTy, 1u, false);
+}
+
+/// Transfer operation properties from \p OldPI to \p NewVal.
+void transferDecorations(Value &NewVal, PredicatedInstruction &OldPI) {
+  auto *NewInst = dyn_cast<Instruction>(&NewVal);
+  if (!NewInst || !isa<FPMathOperator>(NewVal))
+    return;
+
+  auto *OldFMOp = dyn_cast<FPMathOperator>(&OldPI);
+  if (!OldFMOp)
+    return;
+
+  NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
+}
+
+/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
+/// \p OldOp gets erased.
+void replaceOperation(Value &NewOp, PredicatedInstruction &OldOp) {
+  transferDecorations(NewOp, OldOp);
+  OldOp.replaceAllUsesWith(&NewOp);
+  cast<Instruction>(OldOp).eraseFromParent();
+}
+
+//// Legalizers {
+
+/// \brief Lower this VP binary operator to a non-VP binary operator.
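+///
+/// For example (illustrative), a masked VP division
+///   %q = call <4 x i32> @llvm.vp.udiv.v4i32(<4 x i32> %a, <4 x i32> %b,
+///                                           <4 x i1> %m, i32 4)
+/// becomes a select of a safe divisor on masked-off lanes followed by a
+/// plain division:
+///   %safe = select <4 x i1> %m, <4 x i32> %b,
+///                  <4 x i32> <i32 1, i32 1, i32 1, i32 1>
+///   %q = udiv <4 x i32> %a, %safe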
+Value *lowerPredicatedBinaryOperator(IRBuilder<> &Builder,
+                                     PredicatedInstruction &PI) {
+  assert(PI.canIgnoreVectorLengthParam());
+  assert(PI.isBinaryOp());
+
+  auto FirstOp = PI.getOperand(0);
+  auto SndOp = PI.getOperand(1);
+
+  auto Mask = PI.getMaskParam();
+
+  // Blend in safe operands.
+  if (Mask && !IsAllTrueMask(Mask)) {
+    switch (PI.getOpcode()) {
+    default:
+      // Can safely ignore the predicate.
+      break;
+
+    // Division operators need a safe divisor (1) on masked-off lanes.
+    case Instruction::UDiv:
+    case Instruction::SDiv:
+    case Instruction::URem:
+    case Instruction::SRem: {
+      // The 2nd operand must not be zero.
+      auto SafeDivisor = getSafeDivisor(PI.getType());
+      SndOp = Builder.CreateSelect(Mask, SndOp, SafeDivisor);
+      break;
+    }
+    }
+  }
+
+  auto NewBinOp =
+      Builder.CreateBinOp(static_cast<Instruction::BinaryOps>(PI.getOpcode()),
+                          FirstOp, SndOp, PI.getName(), nullptr);
+
+  replaceOperation(*NewBinOp, PI);
+  return NewBinOp;
+}
+
+void discardEVLParameter(PredicatedInstruction &PI) {
+  LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << PI << "\n");
+
+  if (PI.canIgnoreVectorLengthParam())
+    return;
+
+  Value *EVLParam = PI.getVectorLengthParam();
+  if (!EVLParam)
+    return;
+
+  VPIntrinsic &VPI = cast<VPIntrinsic>(PI);
+  ElementCount StaticElemCount = VPI.getStaticVectorLength();
+  assert(!StaticElemCount.Scalable && "TODO # elements query");
+
+  auto *MaxEVL = ConstantInt::get(Type::getInt32Ty(PI.getContext()),
+                                  StaticElemCount.Min, false);
+  VPI.setVectorLengthParam(MaxEVL);
+}
+
+Value *foldEVLIntoMask(PredicatedInstruction &OldPI,
+                       const TargetTransformInfo &TTI) {
+  LLVM_DEBUG(dbgs() << "Folding vlen for " << OldPI << '\n');
+
+  IRBuilder<> Builder(cast<Instruction>(&OldPI));
+
+  // No %evl parameter, so there is nothing to do here.
+  if (OldPI.canIgnoreVectorLengthParam()) {
+    return &OldPI;
+  }
+
+  // Only VP intrinsics can have an %evl parameter.
+  VPIntrinsic &VPI = cast<VPIntrinsic>(OldPI);
+  Value *OldMaskParam = VPI.getMaskParam();
+  Value *OldEVLParam = VPI.getVectorLengthParam();
+  assert(OldMaskParam && "no mask param to fold the vl param into");
+  assert(OldEVLParam && "no EVL param to fold away");
+
+  LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
+  LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');
+
+  // Lower the EVL to a mask.
+  ElementCount ElemCount = VPI.getStaticVectorLength();
+  auto *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount, TTI);
+  auto NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
+  VPI.setMaskParam(NewMaskParam);
+
+  // Disable the EVL.
+  auto FullVL = Builder.getInt32(ElemCount.Min);
+  VPI.setVectorLengthParam(FullVL);
+  assert(VPI.canIgnoreVectorLengthParam() &&
+         "transformation did not render the evl param ineffective!");
+
+  LLVM_DEBUG(dbgs() << "NEW vlen: " << *FullVL << "\n"
+                    << "NEW mask: " << *NewMaskParam << "\n");
+
+  // Reassess the modified instruction.
+  return &VPI;
+}
+
+Value *lowerToUnpredicatedOperation(PredicatedInstruction &OldPI) {
+  LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << OldPI << '\n');
+
+  IRBuilder<> Builder(cast<Instruction>(&OldPI));
+
+  // Try lowering to an LLVM instruction first.
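+  // Instruction.def expands FIRST_BINARY_INST/LAST_BINARY_INST to the bounds
+  // of the binary opcode range, so the single range check below covers every
+  // IR binary operator without enumerating the opcodes.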
+  unsigned OC = OldPI.getOpcode();
+#define FIRST_BINARY_INST(X) unsigned FirstBinOp = X;
+#define LAST_BINARY_INST(X) unsigned LastBinOp = X;
+#include "llvm/IR/Instruction.def"
+
+  if (FirstBinOp <= OC && OC <= LastBinOp) {
+    return lowerPredicatedBinaryOperator(Builder, OldPI);
+  }
+
+  return &OldPI;
+}
+
+//// } Legalizers
+
+struct TransformJob {
+  PredicatedInstruction *PI;
+  TargetTransformInfo::VPLegalization Strategy;
+  TransformJob(PredicatedInstruction *PI,
+               TargetTransformInfo::VPLegalization InitStrat)
+      : PI(PI), Strategy(InitStrat) {}
+
+  bool isDone() const { return Strategy.doNothing(); }
+};
+
+void sanitizeStrategy(Instruction &I, VPLegalization &LegalizeStrat) {
+  if (!I.mayHaveSideEffects())
+    return;
+
+  // Do not discard the EVL parameter where this may invoke spurious side
+  // effects.
+  if (LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) {
+    LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
+  }
+}
+
+/// \brief Expand llvm.vp.* intrinsics as requested by \p TTI.
+bool expandVectorPredication(Function &F, const TargetTransformInfo &TTI) {
+  // Holds all vector-predicated ops that still need to be transformed,
+  // together with the legalization strategy for each.
+  SmallVector<TransformJob, 16> Worklist;
+
+  for (auto &I : instructions(F)) {
+    auto &PI = cast<PredicatedInstruction>(I);
+    auto VPStrat = TTI.getVPLegalizationStrategy(PI);
+    sanitizeStrategy(I, VPStrat);
+    if (!VPStrat.doNothing()) {
+      Worklist.emplace_back(&PI, VPStrat);
+    }
+  }
+  if (Worklist.empty())
+    return false;
+
+  LLVM_DEBUG(dbgs() << "\n:::: Transforming instructions. ::::\n");
+  for (TransformJob Job : Worklist) {
+    // Transform the EVL parameter.
+    switch (Job.Strategy.EVLParamStrategy) {
+    case VPLegalization::Legal:
+      break;
+    case VPLegalization::Discard: {
+      discardEVLParameter(*Job.PI);
+    } break;
+    case VPLegalization::Convert: {
+      if (foldEVLIntoMask(*Job.PI, TTI)) {
+        ++NumFoldedVL;
+      }
+    } break;
+    }
+    Job.Strategy.EVLParamStrategy = VPLegalization::Legal;
+
+    // Replace the operator.
+    switch (Job.Strategy.OpStrategy) {
+    case VPLegalization::Legal:
+      break;
+    case VPLegalization::Discard:
+      llvm_unreachable("Invalid strategy for operators.");
+    case VPLegalization::Convert: {
+      lowerToUnpredicatedOperation(*Job.PI);
+      ++NumLoweredVPOps;
+    } break;
+    }
+    Job.Strategy.OpStrategy = VPLegalization::Legal;
+
+    assert(Job.isDone() && "incomplete transformation");
+  }
+
+  return true;
+}
+
+class ExpandVectorPredication : public FunctionPass {
+public:
+  static char ID;
+  ExpandVectorPredication() : FunctionPass(ID) {
+    initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+    return expandVectorPredication(F, *TTI);
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.setPreservesCFG();
+  }
+};
+} // namespace
+
+char ExpandVectorPredication::ID = 0;
+INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expand-vec-pred",
+                      "Expand vector predication intrinsics", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(ExpandVectorPredication, "expand-vec-pred",
+                    "Expand vector predication intrinsics", false, false)
+
+FunctionPass *llvm::createExpandVectorPredicationPass() {
+  return new ExpandVectorPredication();
+}
+
+PreservedAnalyses
+ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
+  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+  if (!expandVectorPredication(F, TTI))
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -702,6 +702,11 @@
   // Instrument function entry and exit, e.g. with calls to mcount().
   addPass(createPostInlineEntryExitInstrumenterPass());
 
+  // Expand vector predication intrinsics into standard IR instructions.
+  // This pass has to run before ScalarizeMaskedMemIntrin and ExpandReduction
+  // passes since it emits those kinds of intrinsics.
+  addPass(createExpandVectorPredicationPass());
+
   // Add scalarization of target's unsupported masked memory intrinsics pass.
   // the unsupported intrinsic will be replaced with a chain of basic blocks,
   // that stores/loads element one-by-one if the appropriate mask bit is set.
diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -44,6 +44,7 @@
   PassManager.cpp
   PassRegistry.cpp
   PassTimingInfo.cpp
+  PredicatedInst.cpp
   SafepointIRVerifier.cpp
   ProfileSummary.cpp
   Statepoint.cpp
diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp
--- a/llvm/lib/IR/IntrinsicInst.cpp
+++ b/llvm/lib/IR/IntrinsicInst.cpp
@@ -196,6 +196,12 @@
   return nullptr;
 }
 
+void VPIntrinsic::setMaskParam(Value *NewMask) {
+  auto MaskPos = GetMaskParamPos(getIntrinsicID());
+  assert(MaskPos.hasValue());
+  setArgOperand(MaskPos.getValue(), NewMask);
+}
+
 Value *VPIntrinsic::getVectorLengthParam() const {
   auto vlenPos = GetVectorLengthParamPos(getIntrinsicID());
   if (vlenPos)
@@ -203,6 +209,12 @@
   return nullptr;
 }
 
+void VPIntrinsic::setVectorLengthParam(Value *NewEVL) {
+  auto EVLPos = GetVectorLengthParamPos(getIntrinsicID());
+  assert(EVLPos.hasValue());
+  setArgOperand(EVLPos.getValue(), NewEVL);
+}
+
 Optional<int> VPIntrinsic::GetMaskParamPos(Intrinsic::ID IntrinsicID) {
   switch (IntrinsicID) {
   default:
diff --git a/llvm/lib/IR/PredicatedInst.cpp b/llvm/lib/IR/PredicatedInst.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/IR/PredicatedInst.cpp
@@ -0,0 +1,72 @@
+//===--- llvm/PredicatedInst.cpp - Vector-predication utility subclass ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the PredicatedInstruction class.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/PredicatedInst.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/Support/Casting.h"
+
+namespace {
+#define FIRST_BINARY_INST(X) const unsigned FirstBinOp = X;
+#define LAST_BINARY_INST(X) const unsigned LastBinOp = X;
+#include "llvm/IR/Instruction.def"
+} // namespace
+
+namespace llvm {
+
+bool PredicatedInstruction::canIgnoreVectorLengthParam() const {
+  auto *VPI = dyn_cast<VPIntrinsic>(this);
+  if (!VPI)
+    return true;
+
+  return VPI->canIgnoreVectorLengthParam();
+}
+
+FastMathFlags PredicatedInstruction::getFastMathFlags() const {
+  return cast<Instruction>(this)->getFastMathFlags();
+}
+
+bool PredicatedInstruction::isBinaryOp() const {
+  unsigned FuncOC = getOpcode();
+
+  return FirstBinOp <= FuncOC && FuncOC <= LastBinOp;
+}
+
+Value *PredicatedInstruction::getMaskParam() const {
+  auto *ThisVP = dyn_cast<VPIntrinsic>(this);
+  if (!ThisVP)
+    return nullptr;
+  return ThisVP->getMaskParam();
+}
+
+Value *PredicatedInstruction::getVectorLengthParam() const {
+  auto *ThisVP = dyn_cast<VPIntrinsic>(this);
+  if (!ThisVP)
+    return nullptr;
+  return ThisVP->getVectorLengthParam();
+}
+
+bool PredicatedInstruction::isVectorPredicatedOp() const {
+  return isa<VPIntrinsic>(this);
+}
+
+unsigned PredicatedInstruction::getOpcode() const {
+  auto *VPInst = dyn_cast<VPIntrinsic>(this);
+
+  if (!VPInst)
+    return cast<Instruction>(this)->getOpcode();
+
+  return VPInst->getFunctionalOpcode();
+}
+
+} // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/O0-pipeline.ll b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
--- a/llvm/test/CodeGen/AArch64/O0-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
@@ -26,6 +26,7 @@
 ; CHECK-NEXT: Lower constant intrinsics
 ; CHECK-NEXT: Remove unreachable blocks from the CFG
 ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT: Expand vector predication intrinsics
 ; CHECK-NEXT: Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT: Expand reduction intrinsics
 ; CHECK-NEXT: AArch64 Stack Tagging
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -55,6 +55,7 @@
 ; CHECK-NEXT: Constant Hoisting
 ; CHECK-NEXT: Partially inline calls to library functions
 ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT: Expand vector predication intrinsics
 ; CHECK-NEXT: Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT: Expand reduction intrinsics
 ; CHECK-NEXT: Dominator Tree Construction
diff --git a/llvm/test/CodeGen/ARM/O3-pipeline.ll b/llvm/test/CodeGen/ARM/O3-pipeline.ll
--- a/llvm/test/CodeGen/ARM/O3-pipeline.ll
+++ b/llvm/test/CodeGen/ARM/O3-pipeline.ll
@@ -35,6 +35,7 @@
 ; CHECK-NEXT: Constant Hoisting
 ; CHECK-NEXT: Partially inline calls to library functions
 ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT: Expand vector predication intrinsics
 ; CHECK-NEXT: Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT: Expand reduction intrinsics
 ; CHECK-NEXT: Dominator Tree Construction
diff --git a/llvm/test/CodeGen/Generic/expand-vp.ll b/llvm/test/CodeGen/Generic/expand-vp.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/Generic/expand-vp.ll
@@ -0,0 +1,47 @@
+; RUN: opt --expand-vec-pred -S < %s | FileCheck %s
+
+define void @test_vp_int(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) {
+; CHECK-NOT: {{call.* @llvm.vp.add}}
+; CHECK-NOT: {{call.* @llvm.vp.sub}}
+; CHECK-NOT: {{call.* @llvm.vp.mul}}
+; CHECK-NOT: {{call.* @llvm.vp.sdiv}}
+; CHECK-NOT: {{call.* @llvm.vp.srem}}
+; CHECK-NOT: {{call.* @llvm.vp.udiv}}
+; CHECK-NOT: {{call.* @llvm.vp.urem}}
+; CHECK-NOT: {{call.* @llvm.vp.and}}
+; CHECK-NOT: {{call.* @llvm.vp.or}}
+; CHECK-NOT: {{call.* @llvm.vp.xor}}
+; CHECK-NOT: {{call.* @llvm.vp.ashr}}
+; CHECK-NOT: {{call.* @llvm.vp.lshr}}
+; CHECK-NOT: {{call.* @llvm.vp.shl}}
+  %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r7 = call <8 x i32> @llvm.vp.and.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r8 = call <8 x i32> @llvm.vp.or.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r9 = call <8 x i32> @llvm.vp.xor.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %rA = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %rB = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %rC = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  ret void
+}
+
+; integer arith
+declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+; bit arith
+declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
diff --git a/llvm/test/CodeGen/X86/O0-pipeline.ll b/llvm/test/CodeGen/X86/O0-pipeline.ll
--- a/llvm/test/CodeGen/X86/O0-pipeline.ll
+++ b/llvm/test/CodeGen/X86/O0-pipeline.ll
@@ -28,6 +28,7 @@
 ; CHECK-NEXT: Lower constant intrinsics
 ; CHECK-NEXT: Remove unreachable blocks from the CFG
 ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT: Expand vector predication intrinsics
 ; CHECK-NEXT: Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT: Expand reduction intrinsics
 ; CHECK-NEXT: Expand indirectbr instructions
diff --git a/llvm/test/CodeGen/X86/O3-pipeline.ll b/llvm/test/CodeGen/X86/O3-pipeline.ll
--- a/llvm/test/CodeGen/X86/O3-pipeline.ll
+++ b/llvm/test/CodeGen/X86/O3-pipeline.ll
@@ -47,6 +47,7 @@
 ; CHECK-NEXT: Constant Hoisting
 ; CHECK-NEXT: Partially inline calls to library functions
 ; CHECK-NEXT: Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT: Expand vector predication intrinsics
 ; CHECK-NEXT: Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT: Expand reduction intrinsics
 ; CHECK-NEXT: Dominator Tree Construction
diff --git a/llvm/tools/llc/llc.cpp b/llvm/tools/llc/llc.cpp
--- a/llvm/tools/llc/llc.cpp
+++ b/llvm/tools/llc/llc.cpp
@@ -318,6 +318,7 @@
   initializeVectorization(*Registry);
   initializeScalarizeMaskedMemIntrinPass(*Registry);
   initializeExpandReductionsPass(*Registry);
+  initializeExpandVectorPredicationPass(*Registry);
   initializeHardwareLoopsPass(*Registry);
   initializeTransformUtils(*Registry);
diff --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp
--- a/llvm/tools/opt/opt.cpp
+++ b/llvm/tools/opt/opt.cpp
@@ -580,6 +580,7 @@
   initializePostInlineEntryExitInstrumenterPass(Registry);
   initializeUnreachableBlockElimLegacyPassPass(Registry);
   initializeExpandReductionsPass(Registry);
+  initializeExpandVectorPredicationPass(Registry);
   initializeWasmEHPreparePass(Registry);
   initializeWriteBitcodePassPass(Registry);
   initializeHardwareLoopsPass(Registry);