Index: include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfo.h
+++ include/llvm/Analysis/TargetTransformInfo.h
@@ -313,6 +313,10 @@
   /// by referencing its sub-register AX.
   bool isTruncateFree(Type *Ty1, Type *Ty2) const;
 
+  /// \brief Return true if it is profitable to hoist instruction \p I from
+  /// the then/else blocks to before the if.
+  bool isProfitableToHoist(Instruction *I) const;
+
   /// \brief Return true if this type is legal.
   bool isTypeLegal(Type *Ty) const;
 
@@ -521,6 +525,7 @@
                                      int64_t BaseOffset, bool HasBaseReg,
                                      int64_t Scale) = 0;
   virtual bool isTruncateFree(Type *Ty1, Type *Ty2) = 0;
+  virtual bool isProfitableToHoist(Instruction *I) = 0;
   virtual bool isTypeLegal(Type *Ty) = 0;
   virtual unsigned getJumpBufAlignment() = 0;
   virtual unsigned getJumpBufSize() = 0;
@@ -633,6 +638,9 @@
   bool isTruncateFree(Type *Ty1, Type *Ty2) override {
     return Impl.isTruncateFree(Ty1, Ty2);
   }
+  bool isProfitableToHoist(Instruction *I) override {
+    return Impl.isProfitableToHoist(I);
+  }
   bool isTypeLegal(Type *Ty) override { return Impl.isTypeLegal(Ty); }
   unsigned getJumpBufAlignment() override { return Impl.getJumpBufAlignment(); }
   unsigned getJumpBufSize() override { return Impl.getJumpBufSize(); }
Index: include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfoImpl.h
+++ include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -225,6 +225,8 @@
 
   bool isTruncateFree(Type *Ty1, Type *Ty2) { return false; }
 
+  bool isProfitableToHoist(Instruction *I) { return true; }
+
   bool isTypeLegal(Type *Ty) { return false; }
 
   unsigned getJumpBufAlignment() { return 0; }
Index: include/llvm/CodeGen/BasicTTIImpl.h
===================================================================
--- include/llvm/CodeGen/BasicTTIImpl.h
+++ include/llvm/CodeGen/BasicTTIImpl.h
@@ -145,6 +145,10 @@
     return getTLI()->isTruncateFree(Ty1, Ty2);
   }
 
+  bool isProfitableToHoist(Instruction *I) {
+    return getTLI()->isProfitableToHoist(I);
+  }
+
   bool isTypeLegal(Type *Ty) {
     EVT VT = getTLI()->getValueType(Ty);
     return getTLI()->isTypeLegal(VT);
Index: include/llvm/Target/TargetLowering.h
===================================================================
--- include/llvm/Target/TargetLowering.h
+++ include/llvm/Target/TargetLowering.h
@@ -1456,6 +1456,8 @@
     return false;
   }
 
+  virtual bool isProfitableToHoist(Instruction *I) const { return true; }
+
   /// Return true if any actual instruction that defines a value of type Ty1
   /// implicitly zero-extends the value to Ty2 in the result register.
   ///
Index: lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- lib/Analysis/TargetTransformInfo.cpp
+++ lib/Analysis/TargetTransformInfo.cpp
@@ -123,6 +123,10 @@
   return TTIImpl->isTruncateFree(Ty1, Ty2);
 }
 
+bool TargetTransformInfo::isProfitableToHoist(Instruction *I) const {
+  return TTIImpl->isProfitableToHoist(I);
+}
+
 bool TargetTransformInfo::isTypeLegal(Type *Ty) const {
   return TTIImpl->isTypeLegal(Ty);
 }
Index: lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.h
+++ lib/Target/AArch64/AArch64ISelLowering.h
@@ -18,6 +18,7 @@
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/IR/CallingConv.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/Target/TargetLowering.h"
 
 namespace llvm {
@@ -286,6 +287,8 @@
   bool isTruncateFree(Type *Ty1, Type *Ty2) const override;
   bool isTruncateFree(EVT VT1, EVT VT2) const override;
 
+  bool isProfitableToHoist(Instruction *I) const override;
+
   bool isZExtFree(Type *Ty1, Type *Ty2) const override;
   bool isZExtFree(EVT VT1, EVT VT2) const override;
   bool isZExtFree(SDValue Val, EVT VT2) const override;
Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6535,6 +6535,40 @@
   return NumBits1 > NumBits2;
 }
 
+/// Return true if it is profitable to hoist instruction \p I out of the
+/// then/else blocks to before the if.
+/// Hoisting is not profitable when \p I is an FMul whose single user is an
+/// FAdd/FSub and the pair can form an FMA instruction, because we prefer
+/// FMSUB/FMADD.
+bool AArch64TargetLowering::isProfitableToHoist(Instruction *I) const {
+  if (I->getOpcode() != Instruction::FMul)
+    return true;
+
+  // Only an FMul with exactly one use is considered a fusion candidate.
+  if (!I->hasOneUse())
+    return true;
+
+  Instruction *User = I->user_back();
+
+  // Bail out before User is dereferenced below if there is no user
+  // instruction or it is neither an FSub nor an FAdd.
+  if (!User || !(User->getOpcode() == Instruction::FSub ||
+                 User->getOpcode() == Instruction::FAdd))
+    return true;
+
+  const TargetOptions &Options = getTargetMachine().Options;
+  EVT VT = getValueType(User->getOperand(0)->getType());
+
+  // Keep the FMul next to its user when the target reports FMA as faster and
+  // legal/custom for this type and the FP options allow fusing.
+  if (isFMAFasterThanFMulAndFAdd(VT) &&
+      isOperationLegalOrCustom(ISD::FMA, VT) &&
+      (Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath))
+    return false;
+
+  return true;
+}
+
 // All 32-bit GPR operations implicitly zero the high-half of the corresponding
 // 64-bit GPR.
 bool AArch64TargetLowering::isZExtFree(Type *Ty1, Type *Ty2) const {
Index: lib/Transforms/Utils/SimplifyCFG.cpp
===================================================================
--- lib/Transforms/Utils/SimplifyCFG.cpp
+++ lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1053,7 +1053,8 @@
 /// HoistThenElseCodeToIf - Given a conditional branch that goes to BB1 and
 /// BB2, hoist any common code in the two blocks up into the branch block. The
 /// caller of this function guarantees that BI's block dominates BB1 and BB2.
-static bool HoistThenElseCodeToIf(BranchInst *BI, const DataLayout *DL) {
+static bool HoistThenElseCodeToIf(BranchInst *BI, const DataLayout *DL,
+                                  const TargetTransformInfo &TTI) {
   // This does very trivial matching, with limited scanning, to find identical
   // instructions in the two blocks. In particular, we don't want to get into
   // O(M*N) situations here where M and N are the sizes of BB1 and BB2. As
@@ -1088,6 +1089,9 @@
     if (isa<TerminatorInst>(I1))
       goto HoistTerminator;
 
+    if (!TTI.isProfitableToHoist(I1) || !TTI.isProfitableToHoist(I2))
+      return Changed;
+
     // For a normal instruction, we just move one to right before the branch,
     // then replace all uses of the other with the first. Finally, we remove
     // the now redundant second instruction.
@@ -4444,7 +4448,7 @@
     // can hoist it up to the branching block.
     if (BI->getSuccessor(0)->getSinglePredecessor()) {
       if (BI->getSuccessor(1)->getSinglePredecessor()) {
-        if (HoistThenElseCodeToIf(BI, DL))
+        if (HoistThenElseCodeToIf(BI, DL, TTI))
           return SimplifyCFG(BB, TTI, BonusInstThreshold, DL, AC) | true;
       } else {
         // If Successor #1 has multiple preds, we may be able to conditionally
Index: test/Transforms/SimplifyCFG/AArch64/lit.local.cfg
===================================================================
--- /dev/null
+++ test/Transforms/SimplifyCFG/AArch64/lit.local.cfg
@@ -0,0 +1,5 @@
+config.suffixes = ['.ll']
+
+targets = set(config.root.targets_to_build.split())
+if 'AArch64' not in targets:
+    config.unsupported = True
Index: test/Transforms/SimplifyCFG/AArch64/prefer-fma.ll
===================================================================
--- /dev/null
+++ test/Transforms/SimplifyCFG/AArch64/prefer-fma.ll
@@ -0,0 +1,72 @@
+; RUN: opt < %s -mtriple=aarch64-linux-gnu -simplifycfg -enable-unsafe-fp-math -S >%t
+; RUN: FileCheck %s < %t
+; ModuleID = 't.cc'
+
+; Function Attrs: nounwind
+define double @_Z3fooRdS_S_S_(double* dereferenceable(8) %x, double* dereferenceable(8) %y, double* dereferenceable(8) %a) #0 {
+entry:
+  %0 = load double* %y, align 8
+  %cmp = fcmp oeq double %0, 0.000000e+00
+  %1 = load double* %x, align 8
+  br i1 %cmp, label %if.then, label %if.else
+
+; fadd (const, (fmul x, y))
+if.then:                                          ; preds = %entry
+; CHECK-LABEL: if.then:
+; CHECK:   %3 = fmul fast double %1, %2
+; CHECK-NEXT:   %mul = fadd fast double 1.000000e+00, %3
+  %2 = load double* %a, align 8
+  %3 = fmul fast double %1, %2
+  %mul = fadd fast double 1.000000e+00, %3
+  store double %mul, double* %y, align 8
+  br label %if.end
+
+; fsub ((fmul x, y), z)
+if.else:                                          ; preds = %entry
+; CHECK-LABEL: if.else:
+; CHECK:   %mul1 = fmul fast double %1, %2
+; CHECK-NEXT:   %sub1 = fsub fast double %mul1, %0
+  %4 = load double* %a, align 8
+  %mul1 = fmul fast double %1, %4
+  %sub1 = fsub fast double %mul1, %0
+  store double %sub1, double* %y, align 8
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %5 = load double* %y, align 8
+  %cmp2 = fcmp oeq double %5, 2.000000e+00
+  %6 = load double* %x, align 8
+  br i1 %cmp2, label %if.then2, label %if.else2
+
+; fsub (x, (fmul y, z))
+if.then2:                                         ; preds = %if.end
+; CHECK-LABEL: if.then2:
+; CHECK:   %7 = fmul fast double %5, 3.000000e+00
+; CHECK-NEXT:   %mul2 = fsub fast double %6, %7
+  %7 = load double* %a, align 8
+  %8 = fmul fast double %6, 3.0000000e+00
+  %mul2 = fsub fast double %7, %8
+  store double %mul2, double* %y, align 8
+  br label %if.end2
+
+; fsub (fneg((fmul x, y)), const)
+if.else2:                                         ; preds = %if.end
+; CHECK-LABEL: if.else2:
+; CHECK:   %mul3 = fmul fast double %5, 3.000000e+00
+; CHECK-NEXT:   %neg = fsub fast double 0.000000e+00, %mul3
+; CHECK-NEXT:   %sub2 = fsub fast double %neg, 3.000000e+00
+  %mul3 = fmul fast double %6, 3.0000000e+00
+  %neg = fsub fast double 0.0000000e+00, %mul3
+  %sub2 = fsub fast double %neg, 3.0000000e+00
+  store double %sub2, double* %y, align 8
+  br label %if.end2
+
+if.end2:                                          ; preds = %if.else2, %if.then2
+  %9 = load double* %x, align 8
+  %10 = load double* %y, align 8
+  %add = fadd fast double %9, %10
+  %11 = load double* %a, align 8
+  %add2 = fadd fast double %add, %11
+  ret double %add2
+}
+