diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -6538,7 +6538,8 @@ case Instruction::Sub: case Instruction::And: case Instruction::Or: - case Instruction::Mul: { + case Instruction::Mul: + case Instruction::FMul: { Value *LL = LU->getOperand(0); Value *LR = LU->getOperand(1); // Find a recurrence. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -672,6 +673,15 @@ } } + // Simplify FMUL recurrences starting with 0.0 to 0.0 if nnan and nsz are set. + // Given a phi node whose entry value is 0 and that is used in an fmul, + // we can safely replace the fmul with 0 and eliminate the loop operation. 
+ PHINode *PN = nullptr; + Value *Start = nullptr, *Step = nullptr; + if (matchSimpleRecurrence(&I, PN, Start, Step) && I.hasNoNaNs() && + I.hasNoSignedZeros() && match(Start, m_Zero())) + return replaceInstUsesWith(I, Start); + return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/remove-loop-phi-fastmul.ll b/llvm/test/Transforms/InstCombine/remove-loop-phi-fastmul.ll --- a/llvm/test/Transforms/InstCombine/remove-loop-phi-fastmul.ll +++ b/llvm/test/Transforms/InstCombine/remove-loop-phi-fastmul.ll @@ -6,15 +6,11 @@ ; CHECK-NEXT: br label [[FOR_BODY:%.*]] ; CHECK: for.body: ; CHECK-NEXT: [[I_02:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INC:%.*]], [[FOR_BODY]] ] -; CHECK-NEXT: [[F_PROD_01:%.*]] = phi double [ 0.000000e+00, [[ENTRY]] ], [ [[MUL:%.*]], [[FOR_BODY]] ] -; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [1000 x double], ptr [[ARR_D:%.*]], i64 0, i64 [[I_02]] -; CHECK-NEXT: [[TMP0:%.*]] = load double, ptr [[ARRAYIDX]], align 8 -; CHECK-NEXT: [[MUL]] = fmul fast double [[F_PROD_01]], [[TMP0]] ; CHECK-NEXT: [[INC]] = add i64 [[I_02]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[INC]], 1000 ; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[END:%.*]] ; CHECK: end: -; CHECK-NEXT: ret double [[MUL]] +; CHECK-NEXT: ret double 0.000000e+00 ; entry: br label %for.body @@ -40,15 +36,11 @@ ; CHECK-NEXT: br label [[FOR_BODY:%.*]] ; CHECK: for.body: ; CHECK-NEXT: [[I_02:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INC:%.*]], [[FOR_BODY]] ] -; CHECK-NEXT: [[F_PROD_01:%.*]] = phi double [ 0.000000e+00, [[ENTRY]] ], [ [[MUL:%.*]], [[FOR_BODY]] ] -; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds [1000 x double], ptr [[ARR_D:%.*]], i64 0, i64 [[I_02]] -; CHECK-NEXT: [[TMP0:%.*]] = load double, ptr [[ARRAYIDX]], align 8 -; CHECK-NEXT: [[MUL]] = fmul nnan nsz double [[F_PROD_01]], [[TMP0]] ; CHECK-NEXT: [[INC]] = add i64 [[I_02]], 1 ; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[INC]], 1000 ; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label 
[[END:%.*]] ; CHECK: end: -; CHECK-NEXT: ret double [[MUL]] +; CHECK-NEXT: ret double 0.000000e+00 ; entry: br label %for.body