Index: llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
===================================================================
--- llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -315,6 +315,11 @@
   /// reduction chain.
   void adjustRecipesForInLoopReductions(VPlanPtr &Plan,
                                         VPRecipeBuilder &RecipeBuilder);
+  /// Replace known multi-recipe idioms in \p Plan with dedicated recipes;
+  /// currently trunc(lshr(mul(ext(A), ext(B)), BW)) -> mulhs/mulhu(A, B).
+  /// Returns true if the plan was changed.
+  bool adjustVPlanForVectorPatterns(VPlanPtr &Plan,
+                                    VPRecipeBuilder &RecipeBuilder);
 };
 
 } // namespace llvm
Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -58,6 +58,7 @@
 #include "VPRecipeBuilder.h"
 #include "VPlan.h"
 #include "VPlanHCFGBuilder.h"
+#include "VPlanPatternMatch.h"
 #include "VPlanPredicator.h"
 #include "VPlanTransforms.h"
 #include "llvm/ADT/APInt.h"
@@ -7911,8 +7912,12 @@
   }
 
   // Adjust the recipes for any inloop reductions.
-  if (Range.Start > 1)
+  if (Range.Start > 1) {
     adjustRecipesForInLoopReductions(Plan, RecipeBuilder);
+    // Pattern combines emit vector-only VPInstructions (their execute() casts
+    // operands to VectorType), so only run them for plans with VF > 1.
+    adjustVPlanForVectorPatterns(Plan, RecipeBuilder);
+  }
 
   // Finally, if tail is folded by masking, introduce selects between the phi
   // and the live-out instruction of each reduction, at the end of the latch.
@@ -8043,6 +8048,77 @@
   }
 }
 
+/// Try to rewrite the recipe at \p It as a high-half multiply:
+///   trunc(lshr(mul(ext(A), ext(B)), BW)) -> mulhs/mulhu(A, B)
+/// where A, B and the result all have BW-bit elements. On success the new
+/// VPInstruction is inserted before the trunc, the trunc and any operands
+/// left without users are erased, and \p It is moved to the new recipe so
+/// the caller's ++It continues with the recipe preceding it.
+static bool tryCombineToMulh(VPBasicBlock::reverse_iterator &It) {
+  auto *RV = dyn_cast<VPWidenRecipe>(&*It);
+  if (!RV)
+    return false;
+
+  // The trunc's type gives the element width BW of the final result.
+  Type *Ty = RV->getType();
+  unsigned BW = Ty->getScalarSizeInBits();
+
+  // TODO: Start at the shift and create ext(trunc()) to handle more cases.
+  // TODO: One-use checks? Other users keep the wide mul alive.
+  VPValue *Mul;
+  if (!PatternMatch::match(
+          RV, vp_Trunc(vp_LShr(vp_Value(Mul), vp_SpecificInt(BW)))))
+    return false;
+
+  VPValue *A, *B;
+  unsigned Opcode = 0;
+  if (PatternMatch::match(
+          Mul, vp_Mul(vp_SExt(vp_Value(A)), vp_SExt(vp_Value(B)))))
+    Opcode = VPInstruction::Mulhs;
+  else if (PatternMatch::match(
+               Mul, vp_Mul(vp_ZExt(vp_Value(A)), vp_ZExt(vp_Value(B)))))
+    Opcode = VPInstruction::Mulhu;
+  else
+    return false;
+
+  // TODO: Generalize to operands whose type differs from the result type.
+  if (A->getType() != Ty || B->getType() != Ty)
+    return false;
+
+  // The exact product of two extended BW-bit values needs 2*BW bits. A
+  // narrower mul can wrap, so its bits [BW, 2*BW) are not the high half; a
+  // wider mul still holds them exactly. Reject only too-narrow muls.
+  unsigned MulBW = Mul->getType()->getScalarSizeInBits();
+  if (MulBW < BW * 2)
+    return false;
+
+  auto *Mulh = new VPInstruction(Opcode, Ty, {A, B});
+  RV->getParent()->insert(Mulh, RV->getIterator());
+  RV->replaceAllUsesWith(Mulh);
+  RV->recursivelyDeleteUnusedRecipes();
+  It = Mulh->getReverseIterator();
+  return true;
+}
+
+bool LoopVectorizationPlanner::adjustVPlanForVectorPatterns(
+    VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder) {
+  LLVM_DEBUG(dbgs() << "Current plan: " << *Plan << "\n");
+  bool Changed = false;
+  for (VPBlockBase *Block : depth_first(Plan->getEntry())) {
+    // Only plain VPBasicBlocks are combined; region blocks are skipped.
+    auto *VPBB = dyn_cast<VPBasicBlock>(Block);
+    if (!VPBB)
+      continue;
+    // Visit recipes bottom-up so each pattern root (the trunc) is seen
+    // before the operand chain that a successful combine may erase.
+    for (auto It = VPBB->rbegin(), E = VPBB->rend(); It != E; ++It) {
+      // TODO: Only combine when TTI reports a native mulh for the target.
+      Changed |= tryCombineToMulh(It);
+    }
+  }
+  return Changed;
+}
+
 Value* LoopVectorizationPlanner::VPCallbackILV::
 getOrCreateVectorValues(Value *V, unsigned Part) {
       return ILV.getOrCreateVectorValue(V, Part);
@@ -8479,6 +8555,11 @@
                           CostAttrs, TargetTransformInfo::TCK_RecipThroughput),
             false};
   }
+  case VPInstruction::Mulhs:
+  case VPInstruction::Mulhu: {
+    // TODO: Query TTI for the cost of a native multiply-high; 1 is a guess.
+    return {1, false};
+  }
   default:
     llvm_unreachable("Unsupported opcode for VPInstruction::cost");
   }
Index: llvm/lib/Transforms/Vectorize/VPlan.h
===================================================================
--- llvm/lib/Transforms/Vectorize/VPlan.h
+++ llvm/lib/Transforms/Vectorize/VPlan.h
@@ -697,6 +697,15 @@
   const Instruction *getUnderlyingInstr() const {
     return cast_or_null<Instruction>(getUnderlyingValue());
   }
+
+  /// If this recipe has no users, remove it as a user of its operands,
+  /// recursively delete operand recipes left without users, and erase it.
+  void recursivelyDeleteUnusedRecipes();
+
+  /// Recipes are the VPValues whose ID is at or above VPBlendSC.
+  static inline bool classof(const VPValue *V) {
+    return V->getVPValueID() >= VPBlendSC;
+  }
 };
 
 /// This is a concrete Recipe that models a single VPlan-level instruction.
@@ -714,6 +723,8 @@
     SLPLoad,
     SLPStore,
     ActiveLaneMask,
+    Mulhs, // Signed multiply-high: trunc((sext(A) * sext(B)) >> BW).
+    Mulhu, // Unsigned multiply-high: trunc((zext(A) * zext(B)) >> BW).
   };
 
 private:
@@ -819,6 +830,9 @@
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
              VPSlotTracker &SlotTracker) const override;
+
+  /// Return the opcode of the underlying IR instruction.
+  unsigned getOpcode() const { return getUnderlyingInstr()->getOpcode(); }
 };
 
 /// A recipe for widening Call instructions.
Index: llvm/lib/Transforms/Vectorize/VPlan.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -383,6 +383,21 @@
   return getParent()->getRecipeList().erase(getIterator());
 }
 
+void VPRecipeBase::recursivelyDeleteUnusedRecipes() {
+  // TODO: Check the recipe is safe to remove (e.g. has no side effects).
+  if (getNumUsers() != 0)
+    return;
+  for (VPValue *Op : operands()) {
+    Op->removeUser(*this);
+    if (auto *R = dyn_cast<VPRecipeBase>(Op))
+      R->recursivelyDeleteUnusedRecipes();
+  }
+  // The operands were detached one at a time above; clear the list so they
+  // are not detached again (NOTE(review): presumably by ~VPUser; confirm).
+  clearOperands();
+  eraseFromParent();
+}
+
 void VPRecipeBase::moveAfter(VPRecipeBase *InsertPos) {
   removeFromParent();
   insertAfter(InsertPos);
@@ -436,6 +451,31 @@
     State.set(this, Call, Part);
     break;
   }
+  case VPInstruction::Mulhs:
+  case VPInstruction::Mulhu: {
+    // Expand to trunc(lshr(mul(ext(A), ext(B)), BW)) on 2*BW-bit elements,
+    // using sext for Mulhs and zext for Mulhu; the trunc keeps the high half.
+    Value *A = State.get(getOperand(0), Part);
+    Value *B = State.get(getOperand(1), Part);
+
+    auto *Ty = cast<VectorType>(A->getType());
+    unsigned BW = Ty->getScalarSizeInBits();
+    auto *WidenTy = VectorType::getExtendedElementVectorType(Ty);
+
+    Instruction::CastOps ExtOp = getOpcode() == VPInstruction::Mulhs
+                                     ? Instruction::SExt
+                                     : Instruction::ZExt;
+    Value *ExtA = Builder.CreateCast(ExtOp, A, WidenTy);
+    Value *ExtB = Builder.CreateCast(ExtOp, B, WidenTy);
+    Value *Mul = Builder.CreateMul(ExtA, ExtB);
+    // Passing BW as an integer lets IRBuilder splat it to WidenTy, which also
+    // works for scalable vectors (unlike a fixed-length splat of VF lanes).
+    Value *Shift = Builder.CreateLShr(Mul, BW);
+    Value *Trunc = Builder.CreateTrunc(Shift, Ty);
+
+    State.set(this, Trunc, Part);
+    break;
+  }
   default:
     llvm_unreachable("Unsupported opcode for instruction");
   }
@@ -480,6 +520,12 @@
   case VPInstruction::ActiveLaneMask:
     O << "active lane mask";
     break;
+  case VPInstruction::Mulhs:
+    O << "mulhs";
+    break;
+  case VPInstruction::Mulhu:
+    O << "mulhu";
+    break;
 
   default:
     O << Instruction::getOpcodeName(getOpcode());
Index: llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
===================================================================
--- /dev/null
+++ llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -0,0 +1,121 @@
+//===- VPlanPatternMatch.h - Pattern matchers for vplan recipes -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file contains a number of VPlan pattern matchers, similar to and
+/// working on top of the generic IR pattern matchers from IR/PatternMatch.h.
+/// Use them with llvm::PatternMatch::match, e.g.
+///   PatternMatch::match(R, vp_Trunc(vp_LShr(vp_Value(X), vp_SpecificInt(8))))
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLANPATTERNMATCH_H
+#define LLVM_TRANSFORMS_VECTORIZE_VPLANPATTERNMATCH_H
+
+#include "VPlan.h"
+#include "VPlanValue.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/PatternMatch.h"
+#include <utility>
+
+namespace llvm {
+
+/// Matches a VPWidenRecipe with IR opcode \p Opcode whose operands 0 and 1
+/// match \p L and \p R respectively (the operands are not commuted).
+template <typename LHS_t, typename RHS_t, unsigned Opcode>
+struct VPBinaryOp_match {
+  LHS_t L;
+  RHS_t R;
+
+  VPBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    // TODO: Also match VPInstructions.
+    const auto *VW = dyn_cast<VPWidenRecipe>(V);
+    if (!VW || VW->getOpcode() != Opcode)
+      return false;
+    return L.match(VW->getOperand(0)) && R.match(VW->getOperand(1));
+  }
+};
+
+/// Matches mul(L, R) on a VPWidenRecipe.
+template <typename LHS, typename RHS>
+inline VPBinaryOp_match<LHS, RHS, Instruction::Mul> vp_Mul(const LHS &L,
+                                                           const RHS &R) {
+  return VPBinaryOp_match<LHS, RHS, Instruction::Mul>(L, R);
+}
+
+/// Matches lshr(L, R) on a VPWidenRecipe.
+template <typename LHS, typename RHS>
+inline VPBinaryOp_match<LHS, RHS, Instruction::LShr> vp_LShr(const LHS &L,
+                                                             const RHS &R) {
+  return VPBinaryOp_match<LHS, RHS, Instruction::LShr>(L, R);
+}
+
+/// Matches a VPWidenRecipe with IR opcode \p Opcode whose operand 0
+/// matches \p L.
+template <typename LHS_t, unsigned Opcode> struct VPUnaryOp_match {
+  LHS_t L;
+
+  explicit VPUnaryOp_match(const LHS_t &LHS) : L(LHS) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    const auto *VW = dyn_cast<VPWidenRecipe>(V);
+    if (!VW || VW->getOpcode() != Opcode)
+      return false;
+    return L.match(VW->getOperand(0));
+  }
+};
+
+/// Matches trunc(L) on a VPWidenRecipe.
+template <typename LHS>
+inline VPUnaryOp_match<LHS, Instruction::Trunc> vp_Trunc(const LHS &L) {
+  return VPUnaryOp_match<LHS, Instruction::Trunc>(L);
+}
+
+/// Matches sext(L) on a VPWidenRecipe.
+template <typename LHS>
+inline VPUnaryOp_match<LHS, Instruction::SExt> vp_SExt(const LHS &L) {
+  return VPUnaryOp_match<LHS, Instruction::SExt>(L);
+}
+
+/// Matches zext(L) on a VPWidenRecipe.
+template <typename LHS>
+inline VPUnaryOp_match<LHS, Instruction::ZExt> vp_ZExt(const LHS &L) {
+  return VPUnaryOp_match<LHS, Instruction::ZExt>(L);
+}
+
+/// Matches any VPValue and binds it to \p V.
+inline PatternMatch::bind_ty<VPValue> vp_Value(VPValue *&V) { return V; }
+inline PatternMatch::bind_ty<const VPValue> vp_Value(const VPValue *&V) {
+  return V;
+}
+
+/// Matches a VPValue whose getConstantValue() is a ConstantInt equal to
+/// \p Val (compared with APInt::isSameValue, so bit widths may differ).
+struct VPspecific_intval {
+  APInt Val;
+
+  explicit VPspecific_intval(APInt V) : Val(std::move(V)) {}
+
+  template <typename ITy> bool match(ITy *V) {
+    const auto *CI = dyn_cast_or_null<ConstantInt>(V->getConstantValue());
+    return CI && APInt::isSameValue(CI->getValue(), Val);
+  }
+};
+
+/// Matches a VPValue wrapping the integer constant \p V.
+inline VPspecific_intval vp_SpecificInt(uint64_t V) {
+  return VPspecific_intval(APInt(64, V));
+}
+
+} // namespace llvm
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANPATTERNMATCH_H
Index: llvm/lib/Transforms/Vectorize/VPlanValue.h
===================================================================
--- llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -105,6 +105,12 @@
   /// for any other purpose, as the values may change as LLVM evolves.
   unsigned getVPValueID() const { return SubclassID; }
 
+  /// Return the IR type of the underlying value, or nullptr if there is none.
+  // FIXME: We need to do something about types.
+  Type *getType() const {
+    return UnderlyingVal ? UnderlyingVal->getType() : nullptr;
+  }
+
   void printAsOperand(raw_ostream &OS, VPSlotTracker &Tracker) const;
   void print(raw_ostream &OS, VPSlotTracker &Tracker) const;
 
@@ -157,6 +163,15 @@
   }
 
   void replaceAllUsesWith(VPValue *New);
+
+  /// If this is a plain VPValue (VPValueSC) wrapping an IR Constant, return
+  /// that constant; otherwise return nullptr.
+  const Constant *getConstantValue() const {
+    if (SubclassID != VPValueSC)
+      return nullptr;
+    // A VPValue may have no underlying IR value, hence dyn_cast_or_null.
+    return dyn_cast_or_null<Constant>(getUnderlyingValue());
+  }
 };
 
 typedef DenseMap<Value *, VPValue *> Value2VPValueTy;
@@ -169,6 +184,11 @@
 class VPUser : public VPValue {
   SmallVector<VPValue *, 2> Operands;
 
+protected:
+  /// Drop all operands without updating their user lists; callers must have
+  /// already removed this user from each operand.
+  void clearOperands() { Operands.clear(); }
+
 public:
   VPUser(const unsigned char SC, Value *UV = nullptr) : VPValue(SC, UV) {}
   VPUser(const unsigned char SC, Value *UV, ArrayRef<VPValue *> Operands)