diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -692,6 +692,9 @@ /// then/else to before if. bool isProfitableToHoist(Instruction *I) const; + /// Return true if it is profitable to generate fma instructions. + bool isProfitableToGenerateFMA(Instruction *I) const; + bool useAA() const; /// Return true if this type is legal. @@ -1406,6 +1409,7 @@ virtual bool LSRWithInstrQueries() = 0; virtual bool isTruncateFree(Type *Ty1, Type *Ty2) = 0; virtual bool isProfitableToHoist(Instruction *I) = 0; + virtual bool isProfitableToGenerateFMA(Instruction *I) = 0; virtual bool useAA() = 0; virtual bool isTypeLegal(Type *Ty) = 0; virtual bool shouldBuildLookupTables() = 0; @@ -1757,6 +1761,9 @@ bool isProfitableToHoist(Instruction *I) override { return Impl.isProfitableToHoist(I); } + bool isProfitableToGenerateFMA(Instruction *I) override { + return Impl.isProfitableToGenerateFMA(I); + } bool useAA() override { return Impl.useAA(); } bool isTypeLegal(Type *Ty) override { return Impl.isTypeLegal(Ty); } bool shouldBuildLookupTables() override { diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -252,6 +252,7 @@ bool isTruncateFree(Type *Ty1, Type *Ty2) { return false; } bool isProfitableToHoist(Instruction *I) { return true; } + bool isProfitableToGenerateFMA(Instruction *I) { return false; } bool useAA() { return false; } diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -287,6 +287,10 @@ return getTLI()->isProfitableToHoist(I); } + bool isProfitableToGenerateFMA(Instruction 
*I) { + return getTLI()->isProfitableToGenerateFMA(I); + } + bool useAA() const { return getST()->useAA(); } bool isTypeLegal(Type *Ty) { diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -2452,6 +2452,7 @@ } virtual bool isProfitableToHoist(Instruction *I) const { return true; } + virtual bool isProfitableToGenerateFMA(Instruction *I) const { return false; } /// Return true if the extension represented by \p I is free. /// Unlikely the is[Z|FP]ExtFree family which is based on types, diff --git a/llvm/include/llvm/Transforms/Scalar/Reassociate.h b/llvm/include/llvm/Transforms/Scalar/Reassociate.h --- a/llvm/include/llvm/Transforms/Scalar/Reassociate.h +++ b/llvm/include/llvm/Transforms/Scalar/Reassociate.h @@ -25,6 +25,7 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/ValueHandle.h" #include <deque> @@ -95,6 +96,7 @@ public: PreservedAnalyses run(Function &F, FunctionAnalysisManager &); + bool runImpl(Function &F, const TargetTransformInfo *TTI); private: void BuildRankMap(Function &F, ReversePostOrderTraversal<Function *> &RPOT); @@ -126,6 +128,7 @@ Value *OtherOp); Instruction *canonicalizeNegFPConstants(Instruction *I); void BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT); + const TargetTransformInfo *TTI; }; } // end namespace llvm diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -471,6 +471,10 @@ return TTIImpl->isProfitableToHoist(I); } +bool TargetTransformInfo::isProfitableToGenerateFMA(Instruction *I) const { + return TTIImpl->isProfitableToGenerateFMA(I); +} + bool TargetTransformInfo::useAA() const { return TTIImpl->useAA(); }
bool TargetTransformInfo::isTypeLegal(Type *Ty) const { diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.h b/llvm/lib/Target/PowerPC/PPCISelLowering.h --- a/llvm/lib/Target/PowerPC/PPCISelLowering.h +++ b/llvm/lib/Target/PowerPC/PPCISelLowering.h @@ -924,6 +924,8 @@ /// FMA instruction, because Powerpc prefers FMADD. bool isProfitableToHoist(Instruction *I) const override; + bool isProfitableToGenerateFMA(Instruction *I) const override; + const MCPhysReg *getScratchRegisters(CallingConv::ID CC) const override; // Should we expand the build vector with shuffles? diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp --- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp +++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp @@ -15561,6 +15561,17 @@ return true; } +bool PPCTargetLowering::isProfitableToGenerateFMA(Instruction *I) const { + const TargetOptions &Options = getTargetMachine().Options; + const Function *F = I->getFunction(); + const DataLayout &DL = F->getParent()->getDataLayout(); + Type *Ty = I->getOperand(0)->getType(); + + return isFMAFasterThanFMulAndFAdd(*F, Ty) && + isOperationLegalOrCustom(ISD::FMA, getValueType(DL, Ty)) && + (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath); +} + const MCPhysReg * PPCTargetLowering::getScratchRegisters(CallingConv::ID) const { // LR is a callee-save register, but we must treat it as clobbered by any call diff --git a/llvm/lib/Transforms/Scalar/Reassociate.cpp b/llvm/lib/Transforms/Scalar/Reassociate.cpp --- a/llvm/lib/Transforms/Scalar/Reassociate.cpp +++ b/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -1522,6 +1522,11 @@ } } + // On some targets the add-mul (FMA) pattern is preferable; commoning out a + // mul factor gains nothing and breaks the add-mul folding, so bail out early. + if (TTI->isProfitableToGenerateFMA(I)) + return nullptr; + // Scan the operand list, checking to see if there are any common factors // between operands.
Consider something like A*A+A*B*C+D. We would like to // reassociate this to A*(A+B*C)+D, which reduces the number of multiplies. @@ -2389,7 +2394,7 @@ } } -PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { +bool ReassociatePass::runImpl(Function &F, const TargetTransformInfo *NewTTI) { // Get the functions basic blocks in Reverse Post Order. This order is used by // BuildRankMap to pre calculate ranks correctly. It also excludes dead basic // blocks (it has been seen that the analysis in this pass could hang when @@ -2411,6 +2416,7 @@ BuildPairMap(RPOT); MadeChange = false; + TTI = NewTTI; // Traverse the same blocks that were analysed by BuildRankMap. for (BasicBlock *BI : RPOT) { @@ -2458,7 +2464,15 @@ for (auto &Entry : PairMap) Entry.clear(); - if (MadeChange) { + return MadeChange; + +} + +PreservedAnalyses ReassociatePass::run(Function &F, + FunctionAnalysisManager &AM) { + const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F); + + if (runImpl(F, TTI)) { PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); @@ -2485,10 +2499,10 @@ bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; + const TargetTransformInfo *TTI = + &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - FunctionAnalysisManager DummyFAM; - auto PA = Impl.run(F, DummyFAM); - return !PA.areAllPreserved(); + return Impl.runImpl(F, TTI); } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -2496,6 +2510,7 @@ AU.addPreserved<AAResultsWrapperPass>(); AU.addPreserved<BasicAAWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); } }; diff --git a/llvm/test/Transforms/Reassociate/PowerPC/lit.local.cfg b/llvm/test/Transforms/Reassociate/PowerPC/lit.local.cfg new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/Reassociate/PowerPC/lit.local.cfg @@ -0,0 +1,2 @@ +if 'PowerPC' not in config.root.targets: + config.unsupported = True diff --git a/llvm/test/Transforms/Reassociate/PowerPC/prefer-fma.ll b/llvm/test/Transforms/Reassociate/PowerPC/prefer-fma.ll new file mode 100644 --- /dev/null +++
b/llvm/test/Transforms/Reassociate/PowerPC/prefer-fma.ll @@ -0,0 +1,51 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -reassociate -S | FileCheck %s + +target datalayout = "e-m:e-i64:64-n32:64" +target triple = "powerpc64le-unknown-linux-gnu" + +; Check that on the PowerPC target we do not perform the mul-factor commoning optimization. + +; a * b + a * c -> a * (b + c) has no gain on PowerPC. +; before transform: a * b + a * c -> fma(a*b, a, c) +; after transform: a * (b + c) -> fmul(a, b + c) +; fma, fmul, and fadd all have the same latency on PowerPC. +define double @foo(double %0, double %1, double %2) { +; CHECK-LABEL: @foo( +; CHECK-NEXT: [[TMP4:%.*]] = fmul double [[TMP0:%.*]], [[TMP1:%.*]] +; CHECK-NEXT: [[TMP5:%.*]] = fmul double [[TMP0]], [[TMP2:%.*]] +; CHECK-NEXT: [[TMP6:%.*]] = fadd double [[TMP4]], [[TMP5]] +; CHECK-NEXT: ret double [[TMP6]] +; + %4 = fmul double %0, %1 + %5 = fmul double %0, %2 + %6 = fadd double %4, %5 + ret double %6 +} + +; a + 1.47 * b - 1.47 * c + 2.47 * d - 2.47 * e -> a + 1.47 * (b - c) + 2.47 * (d - e) also has no gain. +; before transform: fma( fma( fma( fma(a, 1.47, b), -1.47, c), -2.47, d), 2.47, e) +; after transform: fma( fma(a, 1.47, sub(b, c)), 2.47, sub(d, e)) +; fma, fmul, and fsub all have the same latency on PowerPC; we also lose the fma folding opportunity.
+define double @fmaChain(double %0, double %1, double %2, double %3, double %4) { +; CHECK-LABEL: @fmaChain( +; CHECK-NEXT: [[TMP6:%.*]] = fmul double [[TMP1:%.*]], 1.470000e+00 +; CHECK-NEXT: [[TMP7:%.*]] = fadd double [[TMP0:%.*]], [[TMP6]] +; CHECK-NEXT: [[TMP8:%.*]] = fmul double [[TMP2:%.*]], 1.470000e+00 +; CHECK-NEXT: [[TMP9:%.*]] = fsub double [[TMP7]], [[TMP8]] +; CHECK-NEXT: [[TMP10:%.*]] = fmul double [[TMP3:%.*]], 2.470000e+00 +; CHECK-NEXT: [[TMP11:%.*]] = fadd double [[TMP9]], [[TMP10]] +; CHECK-NEXT: [[TMP12:%.*]] = fmul double [[TMP4:%.*]], 2.470000e+00 +; CHECK-NEXT: [[TMP13:%.*]] = fsub double [[TMP11]], [[TMP12]] +; CHECK-NEXT: ret double [[TMP13]] +; + %6 = fmul double %1, 1.470000e+00 + %7 = fadd double %6, %0 + %8 = fmul double %2, 1.470000e+00 + %9 = fsub double %7, %8 + %10 = fmul double %3, 2.470000e+00 + %11 = fadd double %9, %10 + %12 = fmul double %4, 2.470000e+00 + %13 = fsub double %11, %12 + ret double %13 +}