diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationPlanner.h
@@ -304,6 +304,12 @@
   /// reduction chain.
   void adjustRecipesForInLoopReductions(VPlanPtr &Plan,
                                         VPRecipeBuilder &RecipeBuilder);
+
+  /// Rewrite recognized recipe patterns in \p Plan into more specialized
+  /// recipes (currently trunc(lshr(mul(ext, ext), BW)) into Mulhs/Mulhu).
+  /// Returns true if \p Plan was changed.
+  bool adjustVPlanForVectorPatterns(VPlanPtr &Plan,
+                                    VPRecipeBuilder &RecipeBuilder);
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -58,6 +58,7 @@
 #include "VPRecipeBuilder.h"
 #include "VPlan.h"
 #include "VPlanHCFGBuilder.h"
+#include "VPlanPatternMatch.h"
 #include "VPlanPredicator.h"
 #include "VPlanTransforms.h"
 #include "llvm/ADT/APInt.h"
@@ -8000,6 +8001,79 @@
   }
 }
 
+/// Try to replace the widened recipe at \p It, if it has the form
+///   trunc(lshr(mul(ext(A), ext(B)), BW)) to a BW-bit type,
+/// with a single Mulhs (both extends signed) or Mulhu (both extends unsigned)
+/// VPInstruction on A and B. On success the trunc, and any recipes feeding it
+/// that are left without users, are erased; \p It is moved onto the new
+/// recipe and true is returned. \p Plan and \p RecipeBuilder are unused.
+static bool tryCombineToMulh(VPBasicBlock::reverse_iterator &It, VPlanPtr &Plan,
+                             VPRecipeBuilder &RecipeBuilder) {
+  VPWidenRecipe *RV = dyn_cast<VPWidenRecipe>(&*It);
+  if (!RV)
+    return false;
+
+  // trunc(lshr(mul(ext, ext), BW)) -> mulhs/mulhu
+  Type *Ty = RV->getType();
+  unsigned BW = Ty->getScalarSizeInBits();
+
+  VPValue *Mul;
+  // TODO: Start at shift. create ext(trunc())
+  // TODO: One use checks?
+  if (!match(RV, vp_Trunc(vp_LShr(vp_Value(Mul), vp_SpecificInt(BW)))))
+    return false;
+
+  VPValue *A, *B;
+  unsigned Opcode = 0;
+  if (match(Mul, vp_Mul(vp_SExt(vp_Value(A)), vp_SExt(vp_Value(B)))))
+    Opcode = VPInstruction::Mulhs;
+  else if (match(Mul, vp_Mul(vp_ZExt(vp_Value(A)), vp_ZExt(vp_Value(B)))))
+    Opcode = VPInstruction::Mulhu;
+  else
+    return false;
+
+  // Bits [BW, 2*BW) of the multiply are the high half of the full product
+  // only if the multiply is at least 2*BW bits wide. A narrower multiply
+  // wraps and loses the top of the product, so it is not a mulh. A wider one
+  // is fine: the extended operands' product always fits in 2*BW bits.
+  unsigned MulBW = Mul->getType()->getScalarSizeInBits();
+  if (MulBW < BW * 2)
+    return false;
+
+  if (A->getType() != Ty || B->getType() != Ty) // TODO: Generalize
+    return false;
+
+  auto *Mulh = new VPInstruction(Opcode, {A, B});
+  RV->getParent()->insert(Mulh, RV->getIterator());
+  RV->replaceAllUsesWith(Mulh);
+  RV->recursivelyDeleteUnusedRecipes();
+  // Point the caller's reverse walk at the new recipe; the caller's ++It then
+  // moves on to the recipe preceding it.
+  It = Mulh->getReverseIterator();
+  return true;
+}
+
+bool LoopVectorizationPlanner::adjustVPlanForVectorPatterns(
+    VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder) {
+  LLVM_DEBUG(dbgs() << "Current plan: " << *Plan << "\n");
+  bool Changed = false;
+  for (VPBlockBase *Block : depth_first(Plan->getEntry())) {
+    auto *BB = dyn_cast<VPBasicBlock>(Block);
+    if (!BB)
+      continue;
+    // Visit recipes bottom-up so that the root of a pattern (the trunc) is
+    // seen before the recipes that feed it.
+    for (VPBasicBlock::reverse_iterator It = BB->rbegin(), E = BB->rend();
+         It != E; ++It) {
+      // TODO: Only combine if the target supports the pattern
+      // (placeholder: TTI->hasVectorPattern(Mulh)).
+      Changed |= tryCombineToMulh(It, Plan, RecipeBuilder);
+    }
+  }
+
+  return Changed;
+}
+
 Value* LoopVectorizationPlanner::VPCallbackILV::
 getOrCreateVectorValues(Value *V, unsigned Part) {
   return ILV.getOrCreateVectorValue(V, Part);
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -684,6 +684,11 @@
       return cast_or_null<Instruction>(VPV->getUnderlyingValue());
     return nullptr;
   }
+
+  /// If none of the values defined by this recipe have users, unlink it from
+  /// its operands, erase it, and repeat for the recipes defining those
+  /// operands. Does not check whether the recipe has side effects.
+  void recursivelyDeleteUnusedRecipes();
 };
 
 inline bool VPUser::classof(const VPDef *Recipe) {
@@ -714,6 +719,10 @@
     SLPLoad,
     SLPStore,
     ActiveLaneMask,
+    // NOTE(review): no lowering of Mulhs/Mulhu in generateInstruction is
+    // visible in this patch; confirm codegen exists before enabling them.
+    Mulhs, // High half of a signed widening multiply.
+    Mulhu, // High half of an unsigned widening multiply.
   };
 
 private:
@@ -830,6 +839,8 @@
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
              VPSlotTracker &SlotTracker) const override;
+
+  unsigned getOpcode() const { return getUnderlyingInstr()->getOpcode(); }
 };
 
 /// A recipe for widening Call instructions.
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -440,6 +440,26 @@
   insertAfter(InsertPos);
 }
 
+/// Erase this recipe if none of the values it defines have users, then
+/// repeat for the recipes that defined its operands.
+void VPRecipeBase::recursivelyDeleteUnusedRecipes() {
+  // TODO: Also require isSafeToRemove() once that can be queried; recipes
+  // with side effects are currently erased too if their results are unused.
+  if (!all_of(defined_values(),
+              [](VPValue *Def) { return Def->getNumUsers() == 0; }))
+    return;
+
+  for (auto *Op : operands()) {
+    Op->removeUser(*this);
+    // Some operands (presumably live-ins such as constants) have no
+    // defining recipe and getDef() returns null for them; dyn_cast would
+    // assert on null, so use dyn_cast_or_null.
+    if (auto *R = dyn_cast_or_null<VPRecipeBase>(Op->getDef()))
+      R->recursivelyDeleteUnusedRecipes();
+  }
+  eraseFromParent();
+}
+
 void VPInstruction::generateInstruction(VPTransformState &State,
                                         unsigned Part) {
   IRBuilder<> &Builder = State.Builder;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h
@@ -0,0 +1,120 @@
+//===- VPlanPatternMatch.h - Pattern matchers for vplan recipes -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This file contains a number of VPlan pattern matchers, similar to and
+/// working on top of the generic IR pattern matchers from IR/PatternMatch.h.
+/// So far only VPWidenRecipes are matched.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLANPATTERNMATCH_H
+#define LLVM_TRANSFORMS_VECTORIZE_VPLANPATTERNMATCH_H
+
+#include "VPlan.h"
+#include "VPlanValue.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/PatternMatch.h"
+#include <utility>
+
+// NOTE(review): using-directives in a header leak into every includer. They
+// are kept because callers rely on the unqualified match() they bring in.
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+/// Matches a VPWidenRecipe with opcode \p Opcode whose operands 0 and 1 match
+/// L and R respectively (operands are not commuted).
+template <typename LHS_t, typename RHS_t, unsigned Opcode>
+struct VPBinaryOp_match {
+  LHS_t L;
+  RHS_t R;
+
+  VPBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    const VPWidenRecipe *VW = dyn_cast<VPWidenRecipe>(V);
+    if (VW && VW->getOpcode() == Opcode)
+      return L.match(VW->getOperand(0)) && R.match(VW->getOperand(1));
+    // TODO: Also match VPInstructions?
+    return false;
+  }
+};
+
+/// Matches a widened 'mul L, R'.
+template <typename LHS, typename RHS>
+inline VPBinaryOp_match<LHS, RHS, Instruction::Mul> vp_Mul(const LHS &L,
+                                                           const RHS &R) {
+  return VPBinaryOp_match<LHS, RHS, Instruction::Mul>(L, R);
+}
+
+/// Matches a widened 'lshr L, R'.
+template <typename LHS, typename RHS>
+inline VPBinaryOp_match<LHS, RHS, Instruction::LShr> vp_LShr(const LHS &L,
+                                                             const RHS &R) {
+  return VPBinaryOp_match<LHS, RHS, Instruction::LShr>(L, R);
+}
+
+/// Matches a VPWidenRecipe with opcode \p Opcode whose operand 0 matches L.
+/// Any further operands are ignored.
+template <typename LHS_t, unsigned Opcode> struct VPUnaryOp_match {
+  LHS_t L;
+
+  explicit VPUnaryOp_match(const LHS_t &LHS) : L(LHS) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    const VPWidenRecipe *VW = dyn_cast<VPWidenRecipe>(V);
+    if (VW && VW->getOpcode() == Opcode)
+      return L.match(VW->getOperand(0));
+    return false;
+  }
+};
+
+/// Matches a widened 'trunc L'.
+template <typename LHS>
+inline VPUnaryOp_match<LHS, Instruction::Trunc> vp_Trunc(const LHS &L) {
+  return VPUnaryOp_match<LHS, Instruction::Trunc>(L);
+}
+
+/// Matches a widened 'sext L'.
+template <typename LHS>
+inline VPUnaryOp_match<LHS, Instruction::SExt> vp_SExt(const LHS &L) {
+  return VPUnaryOp_match<LHS, Instruction::SExt>(L);
+}
+
+/// Matches a widened 'zext L'.
+template <typename LHS>
+inline VPUnaryOp_match<LHS, Instruction::ZExt> vp_ZExt(const LHS &L) {
+  return VPUnaryOp_match<LHS, Instruction::ZExt>(L);
+}
+
+/// Matches any VPValue and binds it to \p V.
+inline bind_ty<VPValue> vp_Value(VPValue *&V) { return V; }
+inline bind_ty<const VPValue> vp_Value(const VPValue *&V) { return V; }
+
+/// Matches a VPValue whose getConstantValue() is a ConstantInt equal to Val.
+/// The bit widths may differ; APInt::isSameValue does the comparison.
+struct VPspecific_intval {
+  APInt Val;
+
+  explicit VPspecific_intval(APInt V) : Val(std::move(V)) {}
+
+  template <typename ITy> bool match(ITy *V) {
+    if (const auto *CI = dyn_cast_or_null<ConstantInt>(V->getConstantValue()))
+      return APInt::isSameValue(CI->getValue(), Val);
+    return false;
+  }
+};
+
+/// Matches an integer constant equal to \p V.
+inline VPspecific_intval vp_SpecificInt(uint64_t V) {
+  return VPspecific_intval(APInt(64, V));
+}
+
+#endif // LLVM_TRANSFORMS_VECTORIZE_VPLANPATTERNMATCH_H
diff --git a/llvm/lib/Transforms/Vectorize/VPlanValue.h b/llvm/lib/Transforms/Vectorize/VPlanValue.h
--- a/llvm/lib/Transforms/Vectorize/VPlanValue.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanValue.h
@@ -161,6 +161,23 @@
   void replaceAllUsesWith(VPValue *New);
 
   VPDef *getDef() { return Def; }
+
+  /// Returns the type of the underlying IR value, or nullptr if there is none.
+  // FIXME: We need to do something about types.
+  Type *getType() const {
+    if (UnderlyingVal)
+      return UnderlyingVal->getType();
+    return nullptr;
+  }
+
+  /// Returns the underlying IR value as a Constant if this is a plain VPValue
+  /// (not a recipe or other subclass) wrapping a Constant, or nullptr.
+  const Constant *getConstantValue() const {
+    // Plain VPValues may lack an underlying value, so tolerate null here.
+    if (SubclassID == VPValueSC)
+      return dyn_cast_or_null<Constant>(getUnderlyingValue());
+    return nullptr;
+  }
 };
 
 typedef DenseMap<Value *, VPValue *> Value2VPValueTy;