diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -61,6 +61,7 @@
 class Type;
 class User;
 class Value;
+class VPIntrinsic;
 struct KnownBits;
 template <typename T> class Optional;
 
@@ -1379,6 +1380,38 @@
   /// Intrinsics") Use of %evl is discouraged when that is not the case.
   bool hasActiveVectorLength() const;
 
+  struct VPLegalization {
+    enum VPTransform {
+      // keep the predicating parameter
+      Legal = 0,
+      // where legal, discard the predicate parameter
+      Discard = 1,
+      // transform into something else that is also predicating
+      Convert = 2
+    };
+
+    // How to transform the EVL parameter.
+    // Legal: keep the EVL parameter as it is.
+    // Discard: Ignore the EVL parameter where it is safe to do so.
+    // Convert: Fold the EVL into the mask parameter.
+    VPTransform EVLParamStrategy;
+
+    // How to transform the operator.
+    // Legal: The target supports this operator.
+    // Convert: Convert this to a non-VP operation.
+    // The 'Discard' strategy is invalid.
+    VPTransform OpStrategy;
+
+    bool shouldDoNothing() const {
+      return (EVLParamStrategy == Legal) && (OpStrategy == Legal);
+    }
+    VPLegalization(VPTransform EVLParamStrategy, VPTransform OpStrategy)
+        : EVLParamStrategy(EVLParamStrategy), OpStrategy(OpStrategy) {}
+  };
+
+  /// \returns How the target needs this vector-predicated operation to be
+  /// transformed.
+  VPLegalization getVPLegalizationStrategy(const VPIntrinsic &PI) const;
   /// @}
 
   /// @}
@@ -1688,6 +1721,8 @@
   virtual bool supportsScalableVectors() const = 0;
   virtual bool hasActiveVectorLength() const = 0;
   virtual InstructionCost getInstructionLatency(const Instruction *I) = 0;
+  virtual VPLegalization
+  getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
 };
 
 template <typename T>
@@ -2259,6 +2294,11 @@
   InstructionCost getInstructionLatency(const Instruction *I) override {
     return Impl.getInstructionLatency(I);
   }
+
+  VPLegalization
+  getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
+    return Impl.getVPLegalizationStrategy(PI);
+  }
 };
 
 template <typename T>
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -750,6 +750,13 @@
 
   bool hasActiveVectorLength() const { return false; }
 
+  TargetTransformInfo::VPLegalization
+  getVPLegalizationStrategy(const VPIntrinsic &PI) const {
+    return TargetTransformInfo::VPLegalization(
+        /* EVLParamStrategy */ TargetTransformInfo::VPLegalization::Discard,
+        /* OpStrategy */ TargetTransformInfo::VPLegalization::Convert);
+  }
+
 protected:
   // Obtain the minimum required size to hold the value (without the sign)
   // In case of a vector it returns the min required size for one element.
diff --git a/llvm/include/llvm/CodeGen/ExpandVectorPredication.h b/llvm/include/llvm/CodeGen/ExpandVectorPredication.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/CodeGen/ExpandVectorPredication.h
@@ -0,0 +1,23 @@
+//===-- ExpandVectorPredication.h - Expand vector predication ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
+#define LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class ExpandVectorPredicationPass
+    : public PassInfoMixin<ExpandVectorPredicationPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+} // end namespace llvm
+
+#endif // LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
diff --git a/llvm/include/llvm/CodeGen/MachinePassRegistry.def b/llvm/include/llvm/CodeGen/MachinePassRegistry.def
--- a/llvm/include/llvm/CodeGen/MachinePassRegistry.def
+++ b/llvm/include/llvm/CodeGen/MachinePassRegistry.def
@@ -103,6 +103,7 @@
 #define DUMMY_FUNCTION_PASS(NAME, PASS_NAME, CONSTRUCTOR)
 #endif
 DUMMY_FUNCTION_PASS("expandmemcmp", ExpandMemCmpPass, ())
+DUMMY_FUNCTION_PASS("expandvp", ExpandVectorPredicationPass, ())
 DUMMY_FUNCTION_PASS("gc-lowering", GCLoweringPass, ())
 DUMMY_FUNCTION_PASS("shadow-stack-gc-lowering", ShadowStackGCLoweringPass, ())
 DUMMY_FUNCTION_PASS("sjljehprepare", SjLjEHPreparePass, ())
diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -453,6 +453,11 @@
   // the corresponding function in a vector library (e.g., SVML, libmvec).
   FunctionPass *createReplaceWithVeclibLegacyPass();
 
+  /// This pass expands the vector predication intrinsics into unpredicated
+  /// instructions with selects, or folds just the explicit vector length into
+  /// the predicate mask.
+  FunctionPass *createExpandVectorPredicationPass();
+
   // This pass expands memcmp() to load/stores.
   FunctionPass *createExpandMemCmpPass();
 
diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h
--- a/llvm/include/llvm/IR/IntrinsicInst.h
+++ b/llvm/include/llvm/IR/IntrinsicInst.h
@@ -400,9 +400,11 @@
 
   /// \return the mask parameter or nullptr.
   Value *getMaskParam() const;
+  void setMaskParam(Value *);
 
   /// \return the vector length parameter or nullptr.
   Value *getVectorLengthParam() const;
+  void setVectorLengthParam(Value *);
 
   /// \return whether the vector length param can be ignored.
   bool canIgnoreVectorLengthParam() const;
diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h
--- a/llvm/include/llvm/InitializePasses.h
+++ b/llvm/include/llvm/InitializePasses.h
@@ -154,6 +154,7 @@
 void initializeExpandMemCmpPassPass(PassRegistry&);
 void initializeExpandPostRAPass(PassRegistry&);
 void initializeExpandReductionsPass(PassRegistry&);
+void initializeExpandVectorPredicationPass(PassRegistry &);
 void initializeMakeGuardsExplicitLegacyPassPass(PassRegistry&);
 void initializeExternalAAWrapperPassPass(PassRegistry&);
 void initializeFEntryInserterPass(PassRegistry&);
diff --git a/llvm/include/llvm/LinkAllPasses.h b/llvm/include/llvm/LinkAllPasses.h
--- a/llvm/include/llvm/LinkAllPasses.h
+++ b/llvm/include/llvm/LinkAllPasses.h
@@ -197,6 +197,7 @@
       (void) llvm::createMergeFunctionsPass();
       (void) llvm::createMergeICmpsLegacyPass();
       (void) llvm::createExpandMemCmpPass();
+      (void) llvm::createExpandVectorPredicationPass();
       std::string buf;
       llvm::raw_string_ostream os(buf);
       (void) llvm::createPrintModulePass(os);
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1026,6 +1026,11 @@
   return TTIImpl->preferPredicatedReductionSelect(Opcode, Ty, Flags);
 }
 
+TargetTransformInfo::VPLegalization
+TargetTransformInfo::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
+  return TTIImpl->getVPLegalizationStrategy(VPI);
+}
+
 bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const {
   return TTIImpl->shouldExpandReduction(II);
 }
diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt
--- a/llvm/lib/CodeGen/CMakeLists.txt
+++ b/llvm/lib/CodeGen/CMakeLists.txt
@@ -29,6 +29,7 @@
   ExpandMemCmp.cpp
   ExpandPostRAPseudos.cpp
   ExpandReductions.cpp
+  ExpandVectorPredication.cpp
   FaultMaps.cpp
   FEntryInserter.cpp
   FinalizeISel.cpp
diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -0,0 +1,469 @@
+//===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements IR expansion for vector predication intrinsics, allowing
+// targets to enable vector predication until just before codegen.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ExpandVectorPredication.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace llvm;
+
+using VPLegalization = TargetTransformInfo::VPLegalization;
+using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;
+
+// Keep this in sync with TargetTransformInfo::VPLegalization.
+#define VPINTERNAL_VPLEGAL_CASES                                               \
+  VPINTERNAL_CASE(Legal)                                                       \
+  VPINTERNAL_CASE(Discard)                                                     \
+  VPINTERNAL_CASE(Convert)
+
+#define VPINTERNAL_CASE(X) "|" #X
+
+// Override options.
+static cl::opt<std::string> EVLTransformOverride(
+    "expandvp-override-evl-transform", cl::init(""), cl::Hidden,
+    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
+             ". If non-empty, ignore "
+             "TargetTransformInfo and "
+             "always use this transformation for the %evl parameter (Used in "
+             "testing)."));
+
+static cl::opt<std::string> MaskTransformOverride(
+    "expandvp-override-mask-transform", cl::init(""), cl::Hidden,
+    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
+             ". If non-empty, Ignore "
+             "TargetTransformInfo and "
+             "always use this transformation for the %mask parameter (Used in "
+             "testing)."));
+
+#undef VPINTERNAL_CASE
+#define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)
+
+static VPTransform parseOverrideOption(const std::string &TextOpt) {
+  return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;
+}
+
+#undef VPINTERNAL_VPLEGAL_CASES
+
+// Whether any override options are set.
+static bool anyExpandVPOverridesSet() {
+  return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
+}
+
+#define DEBUG_TYPE "expandvp"
+
+STATISTIC(NumFoldedVL, "Number of folded vector length params");
+STATISTIC(NumLoweredVPOps, "Number of lowered vector predication operations");
+
+///// Helpers {
+
+/// \returns Whether \p MaskVal is a constant with all lane bits set.
+static bool isAllTrueMask(Value *MaskVal) {
+  auto *ConstValue = dyn_cast<Constant>(MaskVal);
+  return ConstValue && ConstValue->isAllOnesValue();
+}
+
+/// \returns A non-excepting divisor constant for this type.
+static Constant *getSafeDivisor(Type *DivTy) {
+  assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");
+  return ConstantInt::get(DivTy, 1u, false);
+}
+
+/// Transfer operation properties (fast-math flags) from \p VPI to \p NewVal.
+static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
+  auto *NewInst = dyn_cast<Instruction>(&NewVal);
+  if (!NewInst || !isa<FPMathOperator>(NewVal))
+    return;
+
+  auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
+  if (!OldFMOp)
+    return;
+
+  NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
+}
+
+/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
+/// \p OldOp gets erased.
+static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
+  transferDecorations(NewOp, OldOp);
+  OldOp.replaceAllUsesWith(&NewOp);
+  OldOp.eraseFromParent();
+}
+
+//// } Helpers
+
+namespace {
+
+// Expansion pass state at function scope.
+struct CachingVPExpander {
+  Function &F;
+  const TargetTransformInfo &TTI;
+
+  /// \returns A (fixed length) vector with ascending integer indices
+  /// (<0, 1, ..., NumElems-1>).
+  /// \p Builder
+  ///    Used for instruction creation.
+  /// \p LaneTy
+  ///    Integer element type of the result vector.
+  /// \p NumElems
+  ///    Number of vector elements.
+  Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,
+                          unsigned NumElems);
+
+  /// \returns A bitmask that is true where the lane position is less-than \p
+  /// EVLParam
+  ///
+  /// \p Builder
+  ///    Used for instruction creation.
+  /// \p EVLParam
+  ///    The explicit vector length parameter to test against the lane
+  ///    positions.
+  /// \p ElemCount
+  ///    Static (potentially scalable) number of vector elements.
+  Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
+                          ElementCount ElemCount);
+
+  Value *foldEVLIntoMask(VPIntrinsic &VPI);
+
+  /// "Remove" the %evl parameter of \p PI by setting it to the static vector
+  /// length of the operation.
+  void discardEVLParameter(VPIntrinsic &PI);
+
+  /// \brief Lower this VP binary operator to an unpredicated binary operator.
+  Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
+                                           VPIntrinsic &PI);
+
+  /// \brief Query TTI and expand the vector predication in \p PI accordingly.
+  Value *expandPredication(VPIntrinsic &PI);
+
+  /// \brief Determine how and whether the VPIntrinsic \p VPI shall be
+  /// expanded. This overrides TTI with the cl::opts listed at the top of this
+  /// file.
+  VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
+  bool UsingTTIOverrides;
+
+public:
+  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
+      : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}
+
+  bool expandVectorPredication();
+};
+
+//// CachingVPExpander {
+
+Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,
+                                           unsigned NumElems) {
+  // TODO add caching
+  SmallVector<Constant *, 16> ConstElems;
+
+  for (unsigned Idx = 0; Idx < NumElems; ++Idx)
+    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));
+
+  return ConstantVector::get(ConstElems);
+}
+
+Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
+                                           Value *EVLParam,
+                                           ElementCount ElemCount) {
+  // TODO add caching
+  // Scalable vector %evl conversion.
+  if (ElemCount.isScalable()) {
+    auto *M = Builder.GetInsertBlock()->getModule();
+    Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
+    Function *ActiveMaskFunc = Intrinsic::getDeclaration(
+        M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
+    // `get_active_lane_mask` performs an implicit less-than comparison.
+    Value *ConstZero = ConstantInt::get(EVLParam->getType(), 0);
+    return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
+  }
+
+  // Fixed vector %evl conversion.
+  Type *LaneTy = EVLParam->getType();
+  unsigned NumElems = ElemCount.getFixedValue();
+  Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);
+  Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);
+  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
+}
+
+Value *
+CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
+                                                     VPIntrinsic &VPI) {
+  assert((isSafeToSpeculativelyExecute(&VPI) ||
+          VPI.canIgnoreVectorLengthParam()) &&
+         "Implicitly dropping %evl in non-speculatable operator!");
+
+  auto OC = static_cast<Instruction::BinaryOps>(VPI.getFunctionalOpcode());
+  assert(Instruction::isBinaryOp(OC));
+
+  Value *Op0 = VPI.getOperand(0);
+  Value *Op1 = VPI.getOperand(1);
+  Value *Mask = VPI.getMaskParam();
+
+  // Blend in safe operands.
+  if (Mask && !isAllTrueMask(Mask)) {
+    switch (OC) {
+    default:
+      // Can safely ignore the predicate.
+      break;
+
+    // Division operators need a safe divisor on masked-off lanes (1).
+    case Instruction::UDiv:
+    case Instruction::SDiv:
+    case Instruction::URem:
+    case Instruction::SRem:
+      // 2nd operand must not be zero.
+      Value *SafeDivisor = getSafeDivisor(VPI.getType());
+      Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);
+    }
+  }
+
+  Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());
+
+  replaceOperation(*NewBinOp, VPI);
+  return NewBinOp;
+}
+
+void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
+  LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");
+
+  if (VPI.canIgnoreVectorLengthParam())
+    return;
+
+  Value *EVLParam = VPI.getVectorLengthParam();
+  if (!EVLParam)
+    return;
+
+  ElementCount StaticElemCount = VPI.getStaticVectorLength();
+  Value *MaxEVL = nullptr;
+  Type *Int32Ty = Type::getInt32Ty(VPI.getContext());
+  if (StaticElemCount.isScalable()) {
+    // TODO add caching
+    auto *M = VPI.getModule();
+    Function *VScaleFunc =
+        Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
+    IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
+    Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
+    Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
+    MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
+                               /*NUW*/ true, /*NSW*/ false);
+  } else {
+    MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
+  }
+  VPI.setVectorLengthParam(MaxEVL);
+}
+
+Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
+  LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');
+
+  IRBuilder<> Builder(&VPI);
+
+  // Ineffective %evl parameter and so nothing to do here.
+  if (VPI.canIgnoreVectorLengthParam())
+    return &VPI;
+
+  // Only VP intrinsics can have an %evl parameter.
+  Value *OldMaskParam = VPI.getMaskParam();
+  Value *OldEVLParam = VPI.getVectorLengthParam();
+  assert(OldMaskParam && "no mask param to fold the vl param into");
+  assert(OldEVLParam && "no EVL param to fold away");
+
+  LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
+  LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');
+
+  // Convert the %evl predication into vector mask predication.
+  ElementCount ElemCount = VPI.getStaticVectorLength();
+  Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
+  Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
+  VPI.setMaskParam(NewMaskParam);
+
+  // Drop the %evl parameter.
+  discardEVLParameter(VPI);
+  assert(VPI.canIgnoreVectorLengthParam() &&
+         "transformation did not render the evl param ineffective!");
+
+  // Reassess the modified instruction.
+  return &VPI;
+}
+
+Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
+  LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');
+
+  IRBuilder<> Builder(&VPI);
+
+  // Try lowering to a LLVM instruction first.
+  unsigned OC = VPI.getFunctionalOpcode();
+
+  if (Instruction::isBinaryOp(OC))
+    return expandPredicationInBinaryOperator(Builder, VPI);
+
+  return &VPI;
+}
+
+//// } CachingVPExpander
+
+struct TransformJob {
+  VPIntrinsic *PI;
+  TargetTransformInfo::VPLegalization Strategy;
+  TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
+      : PI(PI), Strategy(InitStrat) {}
+
+  bool isDone() const { return Strategy.shouldDoNothing(); }
+};
+
+void sanitizeStrategy(Instruction &I, VPLegalization &LegalizeStrat) {
+  // Speculatable instructions do not strictly need predication.
+  if (isSafeToSpeculativelyExecute(&I)) {
+    // Converting a speculatable VP intrinsic means dropping %mask and %evl.
+    // No need to expand %evl into the %mask only to ignore that code.
+    if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
+      LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
+    return;
+  }
+
+  // We have to preserve the predicating effect of %evl for this
+  // non-speculatable VP intrinsic.
+  // 1) Never discard %evl.
+  // 2) If this VP intrinsic will be expanded to non-VP code, make sure that
+  //    %evl gets folded into %mask.
+  if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||
+      (LegalizeStrat.OpStrategy == VPLegalization::Convert)) {
+    LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
+  }
+}
+
+VPLegalization
+CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
+  auto VPStrat = TTI.getVPLegalizationStrategy(VPI);
+  if (LLVM_LIKELY(!UsingTTIOverrides)) {
+    // No overrides - we are in production.
+    return VPStrat;
+  }
+
+  // Testing: apply only the overrides that are set, else keep TTI's choice.
+  if (!EVLTransformOverride.empty())
+    VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
+  if (!MaskTransformOverride.empty())
+    VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
+  return VPStrat;
+}
+
+/// \brief Expand llvm.vp.* intrinsics as requested by \p TTI.
+bool CachingVPExpander::expandVectorPredication() {
+  SmallVector<TransformJob, 16> Worklist;
+
+  // Collect all VPIntrinsics that need expansion and determine their expansion
+  // strategy.
+  for (auto &I : instructions(F)) {
+    auto *VPI = dyn_cast<VPIntrinsic>(&I);
+    if (!VPI)
+      continue;
+    auto VPStrat = getVPLegalizationStrategy(*VPI);
+    sanitizeStrategy(I, VPStrat);
+    if (!VPStrat.shouldDoNothing())
+      Worklist.emplace_back(VPI, VPStrat);
+  }
+  if (Worklist.empty())
+    return false;
+
+  // Transform all VPIntrinsics on the worklist.
+  LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()
+                    << " instructions ::::\n");
+  for (TransformJob Job : Worklist) {
+    // Transform the EVL parameter.
+    switch (Job.Strategy.EVLParamStrategy) {
+    case VPLegalization::Legal:
+      break;
+    case VPLegalization::Discard:
+      discardEVLParameter(*Job.PI);
+      break;
+    case VPLegalization::Convert:
+      if (foldEVLIntoMask(*Job.PI))
+        ++NumFoldedVL;
+      break;
+    }
+    Job.Strategy.EVLParamStrategy = VPLegalization::Legal;
+
+    // Replace with a non-predicated operation.
+    switch (Job.Strategy.OpStrategy) {
+    case VPLegalization::Legal:
+      break;
+    case VPLegalization::Discard:
+      llvm_unreachable("Invalid strategy for operators.");
+    case VPLegalization::Convert:
+      expandPredication(*Job.PI);
+      ++NumLoweredVPOps;
+      break;
+    }
+    Job.Strategy.OpStrategy = VPLegalization::Legal;
+
+    assert(Job.isDone() && "incomplete transformation");
+  }
+
+  return true;
+}
+class ExpandVectorPredication : public FunctionPass {
+public:
+  static char ID;
+  ExpandVectorPredication() : FunctionPass(ID) {
+    initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+    CachingVPExpander VPExpander(F, *TTI);
+    return VPExpander.expandVectorPredication();
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.setPreservesCFG();
+  }
+};
+} // namespace
+
+char ExpandVectorPredication::ID;
+INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",
+                      "Expand vector predication intrinsics", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",
+                    "Expand vector predication intrinsics", false, false)
+
+FunctionPass *llvm::createExpandVectorPredicationPass() {
+  return new ExpandVectorPredication();
+}
+
+PreservedAnalyses
+ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
+  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+  CachingVPExpander VPExpander(F, TTI);
+  if (!VPExpander.expandVectorPredication())
+    return PreservedAnalyses::all();
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -864,6 +864,11 @@
   if (getOptLevel() != CodeGenOpt::None && !DisablePartialLibcallInlining)
     addPass(createPartiallyInlineLibCallsPass());
 
+  // Expand vector predication intrinsics into standard IR instructions.
+  // This pass has to run before ScalarizeMaskedMemIntrin and ExpandReduction
+  // passes since it emits those kinds of intrinsics.
+  addPass(createExpandVectorPredicationPass());
+
   // Add scalarization of target's unsupported masked memory intrinsics pass.
   // the unsupported intrinsic will be replaced with a chain of basic blocks,
   // that stores/loads element one-by-one if the appropriate mask bit is set.
diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp
--- a/llvm/lib/IR/IntrinsicInst.cpp
+++ b/llvm/lib/IR/IntrinsicInst.cpp
@@ -279,6 +279,11 @@
   return nullptr;
 }
 
+void VPIntrinsic::setMaskParam(Value *NewMask) {
+  auto MaskPos = GetMaskParamPos(getIntrinsicID());
+  setArgOperand(*MaskPos, NewMask);
+}
+
 Value *VPIntrinsic::getVectorLengthParam() const {
   auto vlenPos = GetVectorLengthParamPos(getIntrinsicID());
   if (vlenPos)
@@ -286,6 +291,11 @@
   return nullptr;
 }
 
+void VPIntrinsic::setVectorLengthParam(Value *NewEVL) {
+  auto EVLPos = GetVectorLengthParamPos(getIntrinsicID());
+  setArgOperand(*EVLPos, NewEVL);
+}
+
 Optional<int> VPIntrinsic::GetMaskParamPos(Intrinsic::ID IntrinsicID) {
   switch (IntrinsicID) {
   default:
diff --git a/llvm/test/CodeGen/AArch64/O0-pipeline.ll b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
--- a/llvm/test/CodeGen/AArch64/O0-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O0-pipeline.ll
@@ -21,6 +21,7 @@
 ; CHECK-NEXT: Shadow Stack GC Lowering
 ; CHECK-NEXT: Lower constant intrinsics
 ; CHECK-NEXT: Remove unreachable blocks from the CFG
+; CHECK-NEXT: Expand vector predication intrinsics
 ; CHECK-NEXT: Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT: Expand reduction intrinsics
 ; CHECK-NEXT: AArch64 Stack Tagging
diff --git a/llvm/test/CodeGen/AArch64/O3-pipeline.ll b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
--- a/llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ b/llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -56,6 +56,7 @@
 ; CHECK-NEXT: Constant Hoisting
 ; CHECK-NEXT: Replace intrinsics with calls to vector library
 ; CHECK-NEXT: Partially inline calls to library functions
+; CHECK-NEXT: Expand vector predication intrinsics
 ; CHECK-NEXT: Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT: Expand reduction intrinsics
 ; CHECK-NEXT: Stack Safety Analysis
diff --git a/llvm/test/CodeGen/ARM/O3-pipeline.ll b/llvm/test/CodeGen/ARM/O3-pipeline.ll
--- a/llvm/test/CodeGen/ARM/O3-pipeline.ll
+++ b/llvm/test/CodeGen/ARM/O3-pipeline.ll
@@ -37,6 +37,7 @@
 ; CHECK-NEXT: Constant Hoisting
 ; CHECK-NEXT: Replace intrinsics with calls to vector library
 ; CHECK-NEXT: Partially inline calls to library functions
+; CHECK-NEXT: Expand vector predication intrinsics
 ; CHECK-NEXT: Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT: Expand reduction intrinsics
 ; CHECK-NEXT: Natural Loop Information
diff --git a/llvm/test/CodeGen/Generic/expand-vp.ll b/llvm/test/CodeGen/Generic/expand-vp.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/Generic/expand-vp.ll
@@ -0,0 +1,245 @@
+; Partial expansion cases (still VP with parameter expansions).
+; RUN: opt --expandvp --expandvp-override-evl-transform=Legal --expandvp-override-mask-transform=Legal -S < %s | FileCheck %s --check-prefix=LEGAL_LEGAL
+; RUN: opt --expandvp --expandvp-override-evl-transform=Discard --expandvp-override-mask-transform=Legal -S < %s | FileCheck %s --check-prefix=DISCARD_LEGAL
+; RUN: opt --expandvp --expandvp-override-evl-transform=Convert --expandvp-override-mask-transform=Legal -S < %s | FileCheck %s --check-prefix=CONVERT_LEGAL
+; Full expansion cases (all expanded to non-VP).
+; RUN: opt --expandvp --expandvp-override-evl-transform=Discard --expandvp-override-mask-transform=Convert -S < %s | FileCheck %s --check-prefix=ALL-CONVERT +; RUN: opt --expandvp -S < %s | FileCheck %s --check-prefix=ALL-CONVERT +; RUN: opt --expandvp --expandvp-override-evl-transform=Legal --expandvp-override-mask-transform=Convert -S < %s | FileCheck %s --check-prefix=ALL-CONVERT +; RUN: opt --expandvp --expandvp-override-evl-transform=Convert --expandvp-override-mask-transform=Convert -S < %s | FileCheck %s --check-prefix=ALL-CONVERT + + +; Fixed-width vectors +; Integer arith +declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +; Bit arith +declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) +declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) + +; Fixed vector test function. 
+define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) { + %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r7 = call <8 x i32> @llvm.vp.and.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r8 = call <8 x i32> @llvm.vp.or.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r9 = call <8 x i32> @llvm.vp.xor.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rA = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rB = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %rC = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + ret void +} + +; Scalable-width vectors +; Integer arith +declare @llvm.vp.add.nxv4i32(, , , i32) +declare @llvm.vp.sub.nxv4i32(, , , i32) +declare @llvm.vp.mul.nxv4i32(, , , i32) +declare @llvm.vp.sdiv.nxv4i32(, , , i32) +declare @llvm.vp.srem.nxv4i32(, , , i32) +declare @llvm.vp.udiv.nxv4i32(, , , i32) +declare @llvm.vp.urem.nxv4i32(, , , i32) +; Bit arith +declare @llvm.vp.and.nxv4i32(, , , i32) +declare @llvm.vp.xor.nxv4i32(, , , i32) +declare @llvm.vp.or.nxv4i32(, , , i32) +declare @llvm.vp.ashr.nxv4i32(, , , i32) +declare @llvm.vp.lshr.nxv4i32(, , , i32) +declare @llvm.vp.shl.nxv4i32(, , , i32) + +; Scalable vector test function. 
+define void @test_vp_int_vscale( %i0, %i1, %i2, %f3, %m, i32 %n) { + %r0 = call @llvm.vp.add.nxv4i32( %i0, %i1, %m, i32 %n) + %r1 = call @llvm.vp.sub.nxv4i32( %i0, %i1, %m, i32 %n) + %r2 = call @llvm.vp.mul.nxv4i32( %i0, %i1, %m, i32 %n) + %r3 = call @llvm.vp.sdiv.nxv4i32( %i0, %i1, %m, i32 %n) + %r4 = call @llvm.vp.srem.nxv4i32( %i0, %i1, %m, i32 %n) + %r5 = call @llvm.vp.udiv.nxv4i32( %i0, %i1, %m, i32 %n) + %r6 = call @llvm.vp.urem.nxv4i32( %i0, %i1, %m, i32 %n) + %r7 = call @llvm.vp.and.nxv4i32( %i0, %i1, %m, i32 %n) + %r8 = call @llvm.vp.or.nxv4i32( %i0, %i1, %m, i32 %n) + %r9 = call @llvm.vp.xor.nxv4i32( %i0, %i1, %m, i32 %n) + %rA = call @llvm.vp.ashr.nxv4i32( %i0, %i1, %m, i32 %n) + %rB = call @llvm.vp.lshr.nxv4i32( %i0, %i1, %m, i32 %n) + %rC = call @llvm.vp.shl.nxv4i32( %i0, %i1, %m, i32 %n) + ret void +} +; All VP intrinsics have to be lowered into non-VP ops +; Convert %evl into %mask for non-speculatable VP intrinsics and emit the +; instruction+select idiom with a non-VP SIMD instruction. 
+; +; ALL-CONVERT-NOT: {{call.* @llvm.vp.add}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.sub}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.mul}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.sdiv}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.srem}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.udiv}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.urem}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.and}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.or}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.xor}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.ashr}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.lshr}} +; ALL-CONVERT-NOT: {{call.* @llvm.vp.shl}} +; +; ALL-CONVERT: define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) { +; ALL-CONVERT-NEXT: %{{.*}} = add <8 x i32> %i0, %i1 +; ALL-CONVERT-NEXT: %{{.*}} = sub <8 x i32> %i0, %i1 +; ALL-CONVERT-NEXT: %{{.*}} = mul <8 x i32> %i0, %i1 +; ALL-CONVERT-NEXT: [[NINS:%.+]] = insertelement <8 x i32> poison, i32 %n, i32 0 +; ALL-CONVERT-NEXT: [[NSPLAT:%.+]] = shufflevector <8 x i32> [[NINS]], <8 x i32> poison, <8 x i32> zeroinitializer +; ALL-CONVERT-NEXT: [[EVLM:%.+]] = icmp ult <8 x i32> , [[NSPLAT]] +; ALL-CONVERT-NEXT: [[NEWM:%.+]] = and <8 x i1> [[EVLM]], %m +; ALL-CONVERT-NEXT: [[SELONE:%.+]] = select <8 x i1> [[NEWM]], <8 x i32> %i1, <8 x i32> +; ALL-CONVERT-NEXT: %{{.+}} = sdiv <8 x i32> %i0, [[SELONE]] +; ALL-CONVERT-NOT: %{{.+}} = srem <8 x i32> %i0, %i1 +; ALL-CONVERT: %{{.+}} = srem <8 x i32> %i0, %{{.+}} +; ALL-CONVERT-NOT: %{{.+}} = udiv <8 x i32> %i0, %i1 +; ALL-CONVERT: %{{.+}} = udiv <8 x i32> %i0, %{{.+}} +; ALL-CONVERT-NOT: %{{.+}} = urem <8 x i32> %i0, %i1 +; ALL-CONVERT: %{{.+}} = urem <8 x i32> %i0, %{{.+}} +; ALL-CONVERT-NEXT: %{{.+}} = and <8 x i32> %i0, %i1 +; ALL-CONVERT-NEXT: %{{.+}} = or <8 x i32> %i0, %i1 +; ALL-CONVERT-NEXT: %{{.+}} = xor <8 x i32> %i0, %i1 +; ALL-CONVERT-NEXT: %{{.+}} = ashr <8 x i32> %i0, %i1 +; ALL-CONVERT-NEXT: %{{.+}} = lshr <8 x i32> %i0, %i1 +; ALL-CONVERT-NEXT: %{{.+}} = shl <8 x i32> %i0, %i1 
+; ALL-CONVERT: ret void + + + + +; All legal - don't transform anything. + +; LEGAL_LEGAL: define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) { +; LEGAL_LEGAL-NEXT: %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r7 = call <8 x i32> @llvm.vp.and.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r8 = call <8 x i32> @llvm.vp.or.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r9 = call <8 x i32> @llvm.vp.xor.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %rA = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %rB = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %rC = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: ret void + +; LEGAL_LEGAL: define void @test_vp_int_vscale(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i32> %i2, <vscale x 4 x i32> %f3, <vscale x 4 x i1> %m, i32 %n) { +; LEGAL_LEGAL-NEXT: %r0 = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, 
<vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r1 = call <vscale x 4 x i32> @llvm.vp.sub.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r2 = call <vscale x 4 x i32> @llvm.vp.mul.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, 
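i32 %n) +; LEGAL_LEGAL-NEXT: %r3 = call <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r4 = call <vscale x 4 x i32> @llvm.vp.srem.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r5 = call <vscale x 4 x i32> @llvm.vp.udiv.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r6 = call <vscale x 4 x i32> @llvm.vp.urem.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r7 = call <vscale x 4 x i32> @llvm.vp.and.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r8 = call <vscale x 4 x i32> @llvm.vp.or.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %r9 = call <vscale x 4 x i32> @llvm.vp.xor.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %rA = call <vscale x 4 x i32> @llvm.vp.ashr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %rB = call <vscale x 4 x i32> @llvm.vp.lshr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: %rC = call <vscale x 4 x i32> @llvm.vp.shl.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n) +; LEGAL_LEGAL-NEXT: ret void + + +; Drop %evl where possible, else fold %evl into %mask (%evl Discard, %mask Legal) +; +; There is no caching yet in the ExpandVectorPredication pass and the %evl +; expansion code is emitted for every non-speculatable intrinsic again. Hence, +; only check that: +; (1) The %evl folding code and %mask are correct for the first +; non-speculatable VP intrinsic. +; (2) All other non-speculatable VP intrinsics have a modified mask argument. 
+; (3) All speculatable VP intrinsics keep their original %mask (their %evl is discarded). +; (4) All VP intrinsics have an ineffective %evl parameter. 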
+ +; DISCARD_LEGAL: define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) { +; DISCARD_LEGAL-NEXT: %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL-NEXT: %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL-NEXT: %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL-NEXT: [[NSPLATINS:%.+]] = insertelement <8 x i32> poison, i32 %n, i32 0 +; DISCARD_LEGAL-NEXT: [[NSPLAT:%.+]] = shufflevector <8 x i32> [[NSPLATINS]], <8 x i32> poison, <8 x i32> zeroinitializer +; DISCARD_LEGAL-NEXT: [[EVLMASK:%.+]] = icmp ult <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, [[NSPLAT]] +; DISCARD_LEGAL-NEXT: [[NEWMASK:%.+]] = and <8 x i1> [[EVLMASK]], %m +; DISCARD_LEGAL-NEXT: %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> [[NEWMASK]], i32 8) +; DISCARD_LEGAL-NOT: %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL-NOT: %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL-NOT: %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL: %r7 = call <8 x i32> @llvm.vp.and.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL-NEXT: %r8 = call <8 x i32> @llvm.vp.or.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL-NEXT: %r9 = call <8 x i32> @llvm.vp.xor.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL-NEXT: %rA = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL-NEXT: %rB = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL-NEXT: %rC = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; DISCARD_LEGAL-NEXT: ret void + +; 
TODO: compute vscale only once and use caching. +; In the meantime, we only check for the correct vscale code for the first VP +; intrinsic and skip over it for all others. + +; DISCARD_LEGAL: define void @test_vp_int_vscale(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i32> %i2, <vscale x 4 x i32> %f3, <vscale x 4 x i1> %m, i32 %n) { +; DISCARD_LEGAL-NEXT: %vscale = call i32 @llvm.vscale.i32() +; DISCARD_LEGAL-NEXT: %scalable_size = mul nuw i32 %vscale, 4 +; DISCARD_LEGAL-NEXT: %r0 = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %scalable_size) +; DISCARD_LEGAL: %r1 = call <vscale x 4 x i32> @llvm.vp.sub.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %scalable_size{{.*}}) +; DISCARD_LEGAL: %r2 = call <vscale x 4 x i32> @llvm.vp.mul.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %scalable_size{{.*}}) +; DISCARD_LEGAL: [[EVLM:%.+]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 %n) +; DISCARD_LEGAL: [[NEWM:%.+]] = and <vscale x 4 x i1> [[EVLM]], %m +; DISCARD_LEGAL: %r3 = call <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> [[NEWM]], i32 %scalable_size{{.*}}) +; DISCARD_LEGAL-NOT: %{{.+}} = call <vscale x 4 x i32> @llvm.vp.{{.*}}, i32 %n) +; DISCARD_LEGAL: ret void + + +; Convert %evl into %mask everywhere (%evl Convert, %mask Legal) +; +; For the same reasons as in the (%evl Discard, %mask Legal) case, only check that: +; (1) The %evl folding code and %mask are correct for the first VP intrinsic. +; (2) All other VP intrinsics have a modified mask argument. +; (3) All VP intrinsics have an ineffective %evl parameter. 
+; +; CONVERT_LEGAL: define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) { +; CONVERT_LEGAL-NEXT: [[NINS:%.+]] = insertelement <8 x i32> poison, i32 %n, i32 0 +; CONVERT_LEGAL-NEXT: [[NSPLAT:%.+]] = shufflevector <8 x i32> [[NINS]], <8 x i32> poison, <8 x i32> zeroinitializer +; CONVERT_LEGAL-NEXT: [[EVLM:%.+]] = icmp ult <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, [[NSPLAT]] +; CONVERT_LEGAL-NEXT: [[NEWM:%.+]] = and <8 x i1> [[EVLM]], %m +; CONVERT_LEGAL-NEXT: %{{.+}} = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> [[NEWM]], i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.and.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.or.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.xor.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 8) +; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, 
i32 8) +; 
CONVERT_LEGAL: ret void + +; Similar to the (%evl Discard, %mask Legal) case, but make sure the first VP intrinsic has a legal expansion. +; CONVERT_LEGAL: define void @test_vp_int_vscale(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i32> %i2, <vscale x 4 x i32> %f3, <vscale x 4 x i1> %m, i32 %n) { +; CONVERT_LEGAL-NEXT: [[EVLM:%.+]] = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 %n) +; CONVERT_LEGAL-NEXT: [[NEWM:%.+]] = and <vscale x 4 x i1> [[EVLM]], %m +; CONVERT_LEGAL-NEXT: %vscale = call i32 @llvm.vscale.i32() +; CONVERT_LEGAL-NEXT: %scalable_size = mul nuw i32 %vscale, 4 +; CONVERT_LEGAL-NEXT: %r0 = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> [[NEWM]], i32 %scalable_size) +; CONVERT_LEGAL-NOT: %{{.*}} = call <vscale x 4 x i32> @llvm.vp.{{.*}}, i32 %n) +; CONVERT_LEGAL: ret void + diff --git a/llvm/test/CodeGen/X86/O0-pipeline.ll b/llvm/test/CodeGen/X86/O0-pipeline.ll --- a/llvm/test/CodeGen/X86/O0-pipeline.ll +++ b/llvm/test/CodeGen/X86/O0-pipeline.ll @@ -25,6 +25,7 @@ ; CHECK-NEXT: Shadow Stack GC Lowering ; CHECK-NEXT: Lower constant intrinsics ; CHECK-NEXT: Remove unreachable blocks from the CFG +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Expand indirectbr instructions diff --git a/llvm/test/CodeGen/X86/opt-pipeline.ll b/llvm/test/CodeGen/X86/opt-pipeline.ll --- a/llvm/test/CodeGen/X86/opt-pipeline.ll +++ b/llvm/test/CodeGen/X86/opt-pipeline.ll @@ -54,6 +54,7 @@ ; CHECK-NEXT: Constant Hoisting ; CHECK-NEXT: Replace intrinsics with calls to vector library ; CHECK-NEXT: Partially inline calls to library functions +; CHECK-NEXT: Expand vector predication intrinsics ; CHECK-NEXT: Scalarize Masked Memory Intrinsics ; CHECK-NEXT: Expand reduction intrinsics ; CHECK-NEXT: Interleaved Access Pass diff --git a/llvm/tools/llc/llc.cpp b/llvm/tools/llc/llc.cpp --- a/llvm/tools/llc/llc.cpp +++ 
b/llvm/tools/llc/llc.cpp @@ -352,6 +352,7 @@ initializeVectorization(*Registry); initializeScalarizeMaskedMemIntrinLegacyPassPass(*Registry); initializeExpandReductionsPass(*Registry); + 
initializeExpandVectorPredicationPass(*Registry); initializeHardwareLoopsPass(*Registry); initializeTransformUtils(*Registry); initializeReplaceWithVeclibLegacyPass(*Registry); diff --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp --- a/llvm/tools/opt/opt.cpp +++ b/llvm/tools/opt/opt.cpp @@ -513,7 +513,7 @@ "safe-stack", "cost-model", "codegenprepare", "interleaved-load-combine", "unreachableblockelim", "verify-safepoint-ir", - "atomic-expand", + "atomic-expand", "expandvp", "hardware-loops", "type-promotion", "mve-tail-predication", "interleaved-access", "global-merge", "pre-isel-intrinsic-lowering", @@ -591,6 +591,7 @@ initializePostInlineEntryExitInstrumenterPass(Registry); initializeUnreachableBlockElimLegacyPassPass(Registry); initializeExpandReductionsPass(Registry); + initializeExpandVectorPredicationPass(Registry); initializeWasmEHPreparePass(Registry); initializeWriteBitcodePassPass(Registry); initializeHardwareLoopsPass(Registry);