diff --git a/llvm/include/llvm/CodeGen/MachinePassRegistry.def b/llvm/include/llvm/CodeGen/MachinePassRegistry.def --- a/llvm/include/llvm/CodeGen/MachinePassRegistry.def +++ b/llvm/include/llvm/CodeGen/MachinePassRegistry.def @@ -45,6 +45,7 @@ FUNCTION_PASS("post-inline-ee-instrument", EntryExitInstrumenterPass, (true)) FUNCTION_PASS("expand-large-div-rem", ExpandLargeDivRemPass, ()) FUNCTION_PASS("expand-large-fp-convert", ExpandLargeFpConvertPass, ()) +FUNCTION_PASS("expand-powi", ExpandPowiPass, ()) FUNCTION_PASS("expand-reductions", ExpandReductionsPass, ()) FUNCTION_PASS("expandvp", ExpandVectorPredicationPass, ()) FUNCTION_PASS("lowerinvoke", LowerInvokePass, ()) diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h --- a/llvm/include/llvm/CodeGen/Passes.h +++ b/llvm/include/llvm/CodeGen/Passes.h @@ -518,6 +518,9 @@ // Expands large div/rem instructions. FunctionPass *createExpandLargeFpConvertPass(); + // Expands llvm.vp.powi intrinsic calls. + FunctionPass *createExpandPowiPass(); + // This pass expands memcmp() to load/stores. 
FunctionPass *createExpandMemCmpPass(); diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -1686,6 +1686,11 @@ [ LLVMMatchType<0>, LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, llvm_i32_ty]>; + def int_vp_powi : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; // Casts def int_vp_trunc : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ], diff --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def --- a/llvm/include/llvm/IR/VPIntrinsics.def +++ b/llvm/include/llvm/IR/VPIntrinsics.def @@ -362,6 +362,10 @@ BEGIN_REGISTER_VP(vp_nearbyint, 1, 2, VP_FNEARBYINT, -1) END_REGISTER_VP(vp_nearbyint, VP_FNEARBYINT) +// llvm.vp.powi(x, y, mask, vlen) +BEGIN_REGISTER_VP_INTRINSIC(vp_powi, 2, 3) +VP_PROPERTY_BINARYOP +END_REGISTER_VP_INTRINSIC(vp_powi) ///// } Floating-Point Arithmetic ///// Type Casts { diff --git a/llvm/include/llvm/InitializePasses.h b/llvm/include/llvm/InitializePasses.h --- a/llvm/include/llvm/InitializePasses.h +++ b/llvm/include/llvm/InitializePasses.h @@ -130,6 +130,7 @@ void initializeExpandLargeDivRemLegacyPassPass(PassRegistry&); void initializeExpandMemCmpPassPass(PassRegistry&); void initializeExpandPostRAPass(PassRegistry&); +void initializeExpandPowiLegacyPassPass(PassRegistry &); void initializeExpandReductionsPass(PassRegistry&); void initializeExpandVectorPredicationPass(PassRegistry &); void initializeMakeGuardsExplicitLegacyPassPass(PassRegistry&); diff --git a/llvm/lib/CodeGen/CMakeLists.txt b/llvm/lib/CodeGen/CMakeLists.txt --- a/llvm/lib/CodeGen/CMakeLists.txt +++ b/llvm/lib/CodeGen/CMakeLists.txt @@ -59,6 +59,7 @@ ExpandLargeFpConvert.cpp ExpandMemCmp.cpp ExpandPostRAPseudos.cpp + ExpandPowi.cpp ExpandReductions.cpp ExpandVectorPredication.cpp FaultMaps.cpp diff --git a/llvm/lib/CodeGen/ExpandPowi.cpp 
b/llvm/lib/CodeGen/ExpandPowi.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/CodeGen/ExpandPowi.cpp
@@ -0,0 +1,178 @@
+//===--- ExpandPowi.cpp - Expand Powi intrinsics ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements IR expansion for powi/vp.powi. The expansion is based on
+// compiler-rt/__powidf2.c.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+
+#define DEBUG_TYPE "expand-powi"
+
+using namespace llvm;
+
+// Helper that wraps the name of a CmpInst::Predicate in a metadata value, the
+// form in which the llvm.vp.icmp calls below take their predicate operand.
+// FIXME: Support createVPCmp in IRBuilderBase.
+static Value *getPredicateValue(LLVMContext &Context,
+                                CmpInst::Predicate Predicate) {
+  StringRef PredicateStr = CmpInst::getPredicateName(Predicate);
+  auto *PredicateMDS = MDString::get(Context, PredicateStr);
+  return MetadataAsValue::get(Context, PredicateMDS);
+}
+
+// Expand one llvm.vp.powi call into a loop built from other VP intrinsics,
+// then replace the call with the loop's result and erase it. The expansion is
+// based on the c code of compiler-rt/__powidf2.c,
+//   const int recip = b < 0;
+//   double r = 1;
+//   while (1) {
+//     if (b & 1)
+//       r *= a;
+//     b /= 2;
+//     if (b == 0)
+//       break;
+//     a *= a;
+//   }
+//   return recip ? 1 / r : r;
+static void expandPowi(IntrinsicInst *II) {
+  Value *OrigBase = II->getOperand(0);
+  Value *OrigExp = II->getOperand(1);
+  Value *Mask = II->getOperand(2);
+  Value *EVL = II->getOperand(3);
+
+  BasicBlock *PreLoopBB = II->getParent();
+  BasicBlock *PostLoopBB = PreLoopBB->splitBasicBlock(II, "powi-post-loop");
+  BasicBlock *LoopBody =
+      BasicBlock::Create(PreLoopBB->getContext(), "powi-forward-loop",
+                         PreLoopBB->getParent(), PostLoopBB);
+
+  // Replace the branch created by splitBasicBlock with one into the loop.
+  IRBuilder<> Builder(PreLoopBB->getTerminator());
+  Builder.CreateBr(LoopBody);
+  PreLoopBB->getTerminator()->eraseFromParent();
+
+  Type *BaseTy = OrigBase->getType();
+  Type *ExpTy = OrigExp->getType();
+  Type *CondTy = ExpTy->getWithNewBitWidth(1);
+  Value *True = ConstantInt::get(CondTy, 1);
+  LLVMContext &C = II->getContext();
+
+  Builder.SetInsertPoint(LoopBody);
+  // Create phi of base.
+  PHINode *Base = Builder.CreatePHI(BaseTy, 2, "base");
+  Base->addIncoming(OrigBase, PreLoopBB);
+  // Create phi of exponent.
+  PHINode *Exp = Builder.CreatePHI(ExpTy, 2, "exp");
+  Exp->addIncoming(OrigExp, PreLoopBB);
+  // Create phi of res.
+  PHINode *Res = Builder.CreatePHI(BaseTy, 2, "res");
+  Res->addIncoming(ConstantFP::get(BaseTy, 1.), PreLoopBB);
+  // Res *= Base if Exp is odd.
+  Value *Tmp = Builder.CreateIntrinsic(BaseTy, Intrinsic::vp_fmul,
+                                       {Res, Base, True, EVL});
+  Value *And1 = Builder.CreateIntrinsic(
+      ExpTy, Intrinsic::vp_and, {Exp, ConstantInt::get(ExpTy, 1), True, EVL});
+  Value *PredicateNE = getPredicateValue(C, CmpInst::ICMP_NE);
+  Value *IsOdd = Builder.CreateIntrinsic(
+      CondTy, Intrinsic::vp_icmp,
+      {And1, ConstantInt::get(ExpTy, 0), PredicateNE, True, EVL});
+  Value *NewRes = Builder.CreateIntrinsic(BaseTy, Intrinsic::vp_select,
+                                          {IsOdd, Tmp, Res, EVL});
+  Res->addIncoming(NewRes, LoopBody);
+  // Update Exp. Use a signed divide to mirror `b /= 2` in the reference code:
+  // it truncates toward zero, so a negative exponent converges to zero just
+  // like its magnitude. A logical shift would instead treat a negative
+  // exponent as a huge unsigned value and produce the wrong power.
+  Value *NewExp = Builder.CreateIntrinsic(
+      ExpTy, Intrinsic::vp_sdiv, {Exp, ConstantInt::get(ExpTy, 2), True, EVL});
+  Exp->addIncoming(NewExp, LoopBody);
+  // Update Base.
+  Value *NewBase = Builder.CreateIntrinsic(BaseTy, Intrinsic::vp_fmul,
+                                           {Base, Base, True, EVL});
+  Base->addIncoming(NewBase, LoopBody);
+  // Check whether the elements of Exp are all zeros.
+  Type *ExpScalarTy = ExpTy->getScalarType();
+  Value *ScalarZero = ConstantInt::get(ExpScalarTy, 0);
+  Value *OrSum = Builder.CreateIntrinsic(ExpScalarTy, Intrinsic::vp_reduce_or,
+                                         {ScalarZero, NewExp, Mask, EVL});
+  Builder.CreateCondBr(Builder.CreateICmpEQ(OrSum, ScalarZero), PostLoopBB,
+                       LoopBody);
+
+  Builder.SetInsertPoint(&PostLoopBB->front());
+  // Use reciprocal if power is negative.
+  Value *Recip =
+      Builder.CreateIntrinsic(BaseTy, Intrinsic::vp_fdiv,
+                              {ConstantFP::get(BaseTy, 1.), NewRes, Mask, EVL});
+  Value *PredicateSLT = getPredicateValue(C, CmpInst::ICMP_SLT);
+  Value *IsNegative = Builder.CreateIntrinsic(
+      CondTy, Intrinsic::vp_icmp,
+      {OrigExp, ConstantInt::get(ExpTy, 0), PredicateSLT, True, EVL});
+  Value *Powi = Builder.CreateIntrinsic(BaseTy, Intrinsic::vp_select,
+                                        {IsNegative, Recip, NewRes, EVL});
+
+  II->replaceAllUsesWith(Powi);
+  II->eraseFromParent();
+}
+
+// Collect every llvm.vp.powi call in F and expand each of them. Returns true
+// if any call was expanded.
+// TODO: Add cost model to skip small fixed vectors powi.
+static bool runImpl(Function &F) {
+  // Gather the calls first so F is not modified while being iterated.
+  SmallVector<IntrinsicInst *, 4> Replace;
+  for (auto &I : instructions(F)) {
+    if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
+      // TODO: Also support llvm.powi.
+      if (II->getIntrinsicID() == Intrinsic::vp_powi) {
+        Replace.push_back(II);
+      }
+    }
+  }
+
+  if (Replace.empty())
+    return false;
+
+  for (IntrinsicInst *II : Replace)
+    expandPowi(II);
+
+  return true;
+}
+
+namespace {
+// Legacy pass manager wrapper that runs the expansion on each function.
+class ExpandPowiLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  ExpandPowiLegacyPass() : FunctionPass(ID) {
+    initializeExpandPowiLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override { return runImpl(F); }
+};
+} // namespace
+
+char ExpandPowiLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(ExpandPowiLegacyPass, "expand-powi",
+                      "Expand powi functions", false, false)
+INITIALIZE_PASS_END(ExpandPowiLegacyPass, "expand-powi",
+                    "Expand powi functions", false, false)
+
+FunctionPass *llvm::createExpandPowiPass() {
+  return new ExpandPowiLegacyPass();
+}
diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp
--- a/llvm/lib/CodeGen/TargetPassConfig.cpp
+++ b/llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -1089,6 +1089,7 @@
   PM->add(createTargetTransformInfoWrapperPass(TM->getTargetIRAnalysis()));
   addPass(createExpandLargeDivRemPass());
   addPass(createExpandLargeFpConvertPass());
+  addPass(createExpandPowiPass());
   addIRPasses();
   addCodeGenPrepare();
   addPassesToHandleExceptions();
diff --git a/llvm/test/CodeGen/Generic/expand-powi.ll b/llvm/test/CodeGen/Generic/expand-powi.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/Generic/expand-powi.ll
@@ -0,0 +1,30 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -mtriple=x86_64-unknown-linux-gnu -expand-powi -S < %s | FileCheck %s
+declare <vscale x 1 x float> @llvm.vp.powi.nxv1f32.nxv1i32(<vscale x 1 x float>, <vscale x 1 x i32>, <vscale x 1 x i1>, i32)
+define <vscale x 1 x float> @foo(<vscale x 1 x float> %a, <vscale x 1 x i32> %b, <vscale x 1 x i1> %m, i32 %evl) {
+; CHECK-LABEL: @foo(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[POWI_FORWARD_LOOP:%.*]]
+; CHECK:       powi-forward-loop:
+; CHECK-NEXT:    [[BASE:%.*]] = phi <vscale x 1 x float> [ [[A:%.*]], [[ENTRY:%.*]] ], [ [[TMP5:%.*]], [[POWI_FORWARD_LOOP]] ]
+; CHECK-NEXT:    [[EXP:%.*]] = phi <vscale x 1 x i32> [ [[B:%.*]], [[ENTRY]] ], [ [[TMP4:%.*]], [[POWI_FORWARD_LOOP]] ]
+; CHECK-NEXT:    [[RES:%.*]] = phi <vscale x 1 x float> [ shufflevector (<vscale x 1 x float> insertelement (<vscale x 1 x float> poison, float 1.000000e+00, i64 0), <vscale x 1 x float> poison, <vscale x 1 x i32> zeroinitializer), [[ENTRY]] ], [ [[TMP3:%.*]], [[POWI_FORWARD_LOOP]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = call <vscale x 1 x float> @llvm.vp.fmul.nxv1f32(<vscale x 1 x float> [[RES]], <vscale x 1 x float> [[BASE]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), i32 [[EVL:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 1 x i32> @llvm.vp.and.nxv1i32(<vscale x 1 x i32> [[EXP]], <vscale x 1 x i32> shufflevector (<vscale x 1 x i32> insertelement (<vscale x 1 x i32> poison, i32 1, i64 0), <vscale x 1 x i32> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), i32 [[EVL]])
+; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 1 x i1> @llvm.vp.icmp.nxv1i32(<vscale x 1 x i32> [[TMP1]], <vscale x 1 x i32> zeroinitializer, metadata !"ne", <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), i32 [[EVL]])
+; CHECK-NEXT:    [[TMP3]] = call <vscale x 1 x float> @llvm.vp.select.nxv1f32(<vscale x 1 x i1> [[TMP2]], <vscale x 1 x float> [[TMP0]], <vscale x 1 x float> [[RES]], i32 [[EVL]])
+; CHECK-NEXT:    [[TMP4]] = call <vscale x 1 x i32> @llvm.vp.sdiv.nxv1i32(<vscale x 1 x i32> [[EXP]], <vscale x 1 x i32> shufflevector (<vscale x 1 x i32> insertelement (<vscale x 1 x i32> poison, i32 2, i64 0), <vscale x 1 x i32> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), i32 [[EVL]])
+; CHECK-NEXT:    [[TMP5]] = call <vscale x 1 x float> @llvm.vp.fmul.nxv1f32(<vscale x 1 x float> [[BASE]], <vscale x 1 x float> [[BASE]], <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), i32 [[EVL]])
+; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vp.reduce.or.nxv1i32(i32 0, <vscale x 1 x i32> [[TMP4]], <vscale x 1 x i1> [[M:%.*]], i32 [[EVL]])
+; CHECK-NEXT:    [[TMP7:%.*]] = icmp eq i32 [[TMP6]], 0
+; CHECK-NEXT:    br i1 [[TMP7]], label [[POWI_POST_LOOP:%.*]], label [[POWI_FORWARD_LOOP]]
+; CHECK:       powi-post-loop:
+; CHECK-NEXT:    [[TMP8:%.*]] = call <vscale x 1 x float> @llvm.vp.fdiv.nxv1f32(<vscale x 1 x float> shufflevector (<vscale x 1 x float> insertelement (<vscale x 1 x float> poison, float 1.000000e+00, i64 0), <vscale x 1 x float> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x float> [[TMP3]], <vscale x 1 x i1> [[M]], i32 [[EVL]])
+; CHECK-NEXT:    [[TMP9:%.*]] = call <vscale x 1 x i1> @llvm.vp.icmp.nxv1i32(<vscale x 1 x i32> [[B]], <vscale x 1 x i32> zeroinitializer, metadata !"slt", <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i64 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), i32 [[EVL]])
+; CHECK-NEXT:    [[TMP10:%.*]] = call <vscale x 1 x float> @llvm.vp.select.nxv1f32(<vscale x 1 x i1> [[TMP9]], <vscale x 1 x float> [[TMP8]], <vscale x 1 x float> [[TMP3]], i32 [[EVL]])
+; CHECK-NEXT:    ret <vscale x 1 x float> [[TMP10]]
+;
+entry:
+  %0 = call <vscale x 1 x float> @llvm.vp.powi.nxv1f32.nxv1i32(<vscale x 1 x float> %a, <vscale x 1 x i32> %b, <vscale x 1 x i1> %m, i32 %evl)
+  ret <vscale x 1 x float> %0
+}
diff --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp
--- a/llvm/tools/opt/opt.cpp
+++ b/llvm/tools/opt/opt.cpp
@@ -394,6 +394,7 @@
       "fix-irreducible",
       "expand-large-fp-convert",
       "callbrprepare",
+      "expand-powi",
   };
   for (const auto &P : PassNamePrefix)
     if (Pass.startswith(P))
@@ -443,6 +444,7 @@
   initializeExpandLargeDivRemLegacyPassPass(Registry);
   initializeExpandLargeFpConvertLegacyPassPass(Registry);
   initializeExpandMemCmpPassPass(Registry);
+  initializeExpandPowiLegacyPassPass(Registry);
   initializeScalarizeMaskedMemIntrinLegacyPassPass(Registry);
   initializeSelectOptimizePass(Registry);
   initializeCallBrPreparePass(Registry);