Index: llvm/include/llvm/IR/ConstantRange.h
===================================================================
--- llvm/include/llvm/IR/ConstantRange.h
+++ llvm/include/llvm/IR/ConstantRange.h
@@ -266,6 +266,22 @@
   /// Compare set size of this range with Value.
   bool isSizeLargerThan(uint64_t MaxSize) const;
 
+  /// Return true if every value in this range is signed-less-than every value
+  /// in Other. Vacuously true if either range is the empty set.
+  bool isSignedStrictlyLowerThan(const ConstantRange &Other) const;
+
+  /// Return true if every value in this range is unsigned-less-than every
+  /// value in Other. Vacuously true if either range is the empty set.
+  bool isUnsignedStrictlyLowerThan(const ConstantRange &Other) const;
+
+  /// Return true if every value in this range is signed-greater-than every
+  /// value in Other. Vacuously true if either range is the empty set.
+  bool isSignedStrictlyHigherThan(const ConstantRange &Other) const;
+
+  /// Return true if every value in this range is unsigned-greater-than every
+  /// value in Other. Vacuously true if either range is the empty set.
+  bool isUnsignedStrictlyHigherThan(const ConstantRange &Other) const;
+
   /// Return true if all values in this range are negative.
   bool isAllNegative() const;
 
Index: llvm/lib/IR/ConstantRange.cpp
===================================================================
--- llvm/lib/IR/ConstantRange.cpp
+++ llvm/lib/IR/ConstantRange.cpp
@@ -404,6 +404,54 @@
   return (Upper - Lower).ugt(MaxSize);
 }
 
+bool ConstantRange::isSignedStrictlyLowerThan(
+    const ConstantRange &Other) const {
+  assert(getBitWidth() == Other.getBitWidth() &&
+         "ConstantRange types don't agree!");
+  // An empty set has no elements, so the predicate holds vacuously.
+  if (isEmptySet() || Other.isEmptySet())
+    return true;
+  if (isFullSet() || Other.isFullSet())
+    return false;
+  return getSignedMax().slt(Other.getSignedMin());
+}
+
+bool ConstantRange::isUnsignedStrictlyLowerThan(
+    const ConstantRange &Other) const {
+  assert(getBitWidth() == Other.getBitWidth() &&
+         "ConstantRange types don't agree!");
+  // An empty set has no elements, so the predicate holds vacuously.
+  if (isEmptySet() || Other.isEmptySet())
+    return true;
+  if (isFullSet() || Other.isFullSet())
+    return false;
+  return getUnsignedMax().ult(Other.getUnsignedMin());
+}
+
+bool ConstantRange::isSignedStrictlyHigherThan(
+    const ConstantRange &Other) const {
+  assert(getBitWidth() == Other.getBitWidth() &&
+         "ConstantRange types don't agree!");
+  // An empty set has no elements, so the predicate holds vacuously.
+  if (isEmptySet() || Other.isEmptySet())
+    return true;
+  if (isFullSet() || Other.isFullSet())
+    return false;
+  return getSignedMin().sgt(Other.getSignedMax());
+}
+
+bool ConstantRange::isUnsignedStrictlyHigherThan(
+    const ConstantRange &Other) const {
+  assert(getBitWidth() == Other.getBitWidth() &&
+         "ConstantRange types don't agree!");
+  // An empty set has no elements, so the predicate holds vacuously.
+  if (isEmptySet() || Other.isEmptySet())
+    return true;
+  if (isFullSet() || Other.isFullSet())
+    return false;
+  return getUnsignedMin().ugt(Other.getUnsignedMax());
+}
+
 bool ConstantRange::isAllNegative() const {
   // Empty set is all negative, full set is not.
if (isEmptySet())
Index: llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -59,6 +59,7 @@
 STATISTIC(NumCmps, "Number of comparisons propagated");
 STATISTIC(NumReturns, "Number of return values propagated");
 STATISTIC(NumDeadCases, "Number of switch cases removed");
+STATISTIC(NumDivEliminated, "Number of sdivs/udivs replaced with zero");
 STATISTIC(NumSDivSRemsNarrowed,
           "Number of sdivs/srems whose width was decreased");
 STATISTIC(NumSDivs, "Number of sdiv converted to udiv");
@@ -66,6 +67,7 @@
           "Number of udivs/urems whose width was decreased");
 STATISTIC(NumAShrsConverted, "Number of ashr converted to lshr");
 STATISTIC(NumAShrsRemoved, "Number of ashr removed");
+STATISTIC(NumLShrsRemoved, "Number of lshr removed");
 STATISTIC(NumSRems, "Number of srem converted to urem");
 STATISTIC(NumSExt, "Number of sext converted to zext");
 STATISTIC(NumSICmps, "Number of signed icmp preds simplified to unsigned");
@@ -519,13 +521,62 @@
   CmpInst::Predicate Pred = CmpInst::getNonStrictPredicate(MM->getPredicate());
   LazyValueInfo::Tristate Result = LVI->getPredicateAt(
       Pred, MM->getLHS(), MM->getRHS(), MM, /*UseBlockValue=*/true);
-  if (Result == LazyValueInfo::Unknown)
+  if (Result != LazyValueInfo::Unknown) {
+    ++NumMinMax;
+    MM->replaceAllUsesWith(MM->getOperand(!Result));
+    MM->eraseFromParent();
+    return true;
+  }
+
+  Value *LHS = MM->getArgOperand(0);
+  Value *RHS = MM->getArgOperand(1);
+  Value *AlwaysThisSide = nullptr;
+
+  if (MM->getType()->isVectorTy())
     return false;
 
-  ++NumMinMax;
-  MM->replaceAllUsesWith(MM->getOperand(!Result));
-  MM->eraseFromParent();
-  return true;
+  ConstantRange CRLHS = LVI->getConstantRange(LHS, MM);
+  // No need to compute CRRHS if CRLHS is already the full set.
+  if (CRLHS.isFullSet())
+    return false;
+  ConstantRange CRRHS = LVI->getConstantRange(RHS, MM);
+
+  // Use constant range to see if the max/min always picks one operand.
+  switch (MM->getIntrinsicID()) {
+  case Intrinsic::smax:
+    if (CRLHS.isSignedStrictlyHigherThan(CRRHS))
+      AlwaysThisSide = LHS;
+    else if (CRRHS.isSignedStrictlyHigherThan(CRLHS))
+      AlwaysThisSide = RHS;
+    break;
+  case Intrinsic::umax:
+    if (CRLHS.isUnsignedStrictlyHigherThan(CRRHS))
+      AlwaysThisSide = LHS;
+    else if (CRRHS.isUnsignedStrictlyHigherThan(CRLHS))
+      AlwaysThisSide = RHS;
+    break;
+  case Intrinsic::smin:
+    if (CRLHS.isSignedStrictlyLowerThan(CRRHS))
+      AlwaysThisSide = LHS;
+    else if (CRRHS.isSignedStrictlyLowerThan(CRLHS))
+      AlwaysThisSide = RHS;
+    break;
+  case Intrinsic::umin:
+    if (CRLHS.isUnsignedStrictlyLowerThan(CRRHS))
+      AlwaysThisSide = LHS;
+    else if (CRRHS.isUnsignedStrictlyLowerThan(CRLHS))
+      AlwaysThisSide = RHS;
+    break;
+  default:
+    llvm_unreachable("Invalid Intrinsic");
+  }
+  if (AlwaysThisSide) {
+    ++NumMinMax;
+    MM->replaceAllUsesWith(AlwaysThisSide);
+    MM->eraseFromParent();
+    return true;
+  }
+  return false;
 }
 
 // Rewrite this with.overflow intrinsic as non-overflowing.
@@ -831,12 +882,56 @@
   return true;
 }
 
+/// Return true if the quotient of \p Div is provably zero.
+///
+/// That is the case when the largest magnitude in the dividend's constant
+/// range is strictly smaller than the magnitude of the constant \p Divisor.
+/// A zero divisor never qualifies, so no division by zero is evaluated here.
+static bool isDivResultZero(BinaryOperator *Div, ConstantInt *Divisor,
+                            LazyValueInfo *LVI) {
+  assert((Div->getOpcode() == Instruction::UDiv ||
+          Div->getOpcode() == Instruction::SDiv) &&
+         "Expected a udiv or sdiv");
+
+  const APInt &DivisorVal = Divisor->getValue();
+  // `udiv/sdiv X, 0` is valid IR; bail out instead of dividing by zero.
+  if (DivisorVal.isZero())
+    return false;
+
+  ConstantRange CR = LVI->getConstantRange(Div->getOperand(0), Div);
+  if (Div->getOpcode() == Instruction::UDiv)
+    return CR.getUnsignedMax().ult(DivisorVal);
+
+  // Signed case: compare magnitudes. abs() of the minimum signed value keeps
+  // its bit pattern, which read as unsigned is the right magnitude, so the
+  // unsigned comparisons below remain exact.
+  APInt MaxAbs = CR.getSignedMax().abs();
+  APInt MinAbs = CR.getSignedMin().abs();
+  const APInt &LargestAbs = MaxAbs.ugt(MinAbs) ? MaxAbs : MinAbs;
+  return LargestAbs.ult(DivisorVal.abs());
+}
+
+/// Try to replace udiv with zero, otherwise try to shrink udiv/urem's width
+/// down to the smallest power of two that's sufficient to contain its operands.
 static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) {
   assert(Instr->getOpcode() == Instruction::UDiv ||
          Instr->getOpcode() == Instruction::URem);
   if (Instr->getType()->isVectorTy())
     return false;
 
+  // A udiv whose constant divisor exceeds every possible dividend is always
+  // zero, so replace it outright before trying to narrow or expand it.
+  auto *CInt = dyn_cast<ConstantInt>(Instr->getOperand(1));
+  if (CInt && Instr->getOpcode() == Instruction::UDiv &&
+      isDivResultZero(Instr, CInt, LVI)) {
+    LLVM_DEBUG(dbgs() << *Instr
+                      << " will always be zero. Replacing uses with zero.\n");
+    Instr->replaceAllUsesWith(Constant::getNullValue(Instr->getType()));
+    Instr->eraseFromParent();
+    ++NumDivEliminated;
+    return true;
+  }
+
   ConstantRange XCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(0));
   ConstantRange YCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(1));
   if (expandUDivOrURem(Instr, XCR, YCR))
@@ -899,16 +994,28 @@
   return true;
 }
 
-/// See if LazyValueInfo's ability to exploit edge conditions or range
-/// information is sufficient to prove the signs of both operands of this SDiv.
-/// If this is the case, replace the SDiv with a UDiv. Even for local
-/// conditions, this can sometimes prove conditions instcombine can't by
-/// exploiting range information.
+/// Try to eliminate SDiv, otherwise see if LazyValueInfo's ability to exploit
+/// edge conditions or range information is sufficient to prove the signs of
+/// both operands of this SDiv. If this is the case, replace the SDiv with a
+/// UDiv. Even for local conditions, this can sometimes prove conditions
+/// instcombine can't by exploiting range information.
static bool processSDiv(BinaryOperator *SDI, const ConstantRange &LCR,
                         const ConstantRange &RCR, LazyValueInfo *LVI) {
   assert(SDI->getOpcode() == Instruction::SDiv);
   assert(!SDI->getType()->isVectorTy());
 
+  // If the quotient is provably zero, fold the sdiv to a zero constant.
+  if (auto *CInt = dyn_cast<ConstantInt>(SDI->getOperand(1))) {
+    if (isDivResultZero(SDI, CInt, LVI)) {
+      LLVM_DEBUG(dbgs() << *SDI
+                        << " will always be zero. Replacing uses with zero.\n");
+      SDI->replaceAllUsesWith(Constant::getNullValue(SDI->getType()));
+      SDI->eraseFromParent();
+      ++NumDivEliminated;
+      return true;
+    }
+  }
+
   struct Operand {
     Value *V;
     Domain D;
@@ -971,6 +1078,27 @@
   return narrowSDivOrSRem(Instr, LCR, RCR);
 }
 
+/// Replace an lshr by a constant with zero when the result is provably zero.
+static bool processLShr(BinaryOperator *Instr, LazyValueInfo *LVI) {
+  if (Instr->getType()->isVectorTy())
+    return false;
+
+  // Only constant shift amounts are handled; skip the LVI query otherwise.
+  auto *CInt = dyn_cast<ConstantInt>(Instr->getOperand(1));
+  if (!CInt)
+    return false;
+
+  ConstantRange LRange = LVI->getConstantRange(Instr->getOperand(0), Instr);
+  if (!LRange.getUnsignedMax().lshr(CInt->getValue()).isZero())
+    return false;
+  LLVM_DEBUG(dbgs() << *Instr
+                    << " will always be zero. Replace uses with zero.\n");
+  ++NumLShrsRemoved;
+  Instr->replaceAllUsesWith(Constant::getNullValue(Instr->getType()));
+  Instr->eraseFromParent();
+  return true;
+}
+
 static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
   if (SDI->getType()->isVectorTy())
     return false;
@@ -990,6 +1118,21 @@
   if (!LRange.isAllNonNegative())
     return false;
 
+  // LRange is known to be non-negative at this point (see the check above),
+  // so the shifted value's unsigned maximum bounds every possible result: if
+  // that maximum shifts down to zero, the ashr is always zero.
+  if (auto *CInt = dyn_cast<ConstantInt>(SDI->getOperand(1))) {
+    APInt UpperBound = LRange.getUnsignedMax();
+    if (UpperBound.ashr(CInt->getValue()).isZero()) {
+      LLVM_DEBUG(dbgs() << *SDI
+                        << " will always be zero. Replace uses with zero.\n");
+      ++NumAShrsRemoved;
+      SDI->replaceAllUsesWith(Constant::getNullValue(SDI->getType()));
+      SDI->eraseFromParent();
+      return true;
+    }
+  }
+
   ++NumAShrsConverted;
   auto *BO = BinaryOperator::CreateLShr(SDI->getOperand(0), SDI->getOperand(1),
                                         "", SDI);
@@ -1147,6 +1290,9 @@
       case Instruction::AShr:
         BBChanged |= processAShr(cast<BinaryOperator>(&II), LVI);
         break;
+      case Instruction::LShr:
+        BBChanged |= processLShr(cast<BinaryOperator>(&II), LVI);
+        break;
       case Instruction::SExt:
         BBChanged |= processSExt(cast<SExtInst>(&II), LVI);
         break;
Index: llvm/test/Transforms/CorrelatedValuePropagation/ashr.ll
===================================================================
--- llvm/test/Transforms/CorrelatedValuePropagation/ashr.ll
+++ llvm/test/Transforms/CorrelatedValuePropagation/ashr.ll
@@ -140,3 +140,13 @@
   %2 = select i1 %s, i32 %0, i32 %1
   ret i32 %2
 }
+
+; Check that an ashr whose result is always zero (127 >> 7 == 0) is removed.
+; CHECK-LABEL: @test9
+define void @test9() {
+; CHECK-NOT: ashr
+  %shr = ashr i8 127, 7
+; CHECK: %add = add nuw nsw i8 0, 1
+  %add = add i8 %shr, 1
+  ret void
+}
Index: llvm/test/Transforms/CorrelatedValuePropagation/lshr.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/CorrelatedValuePropagation/lshr.ll
@@ -0,0 +1,25 @@
+; RUN: opt < %s -passes=correlated-propagation -S | FileCheck %s
+
+; Check that debug locations are preserved.
For more info see:
+; https://llvm.org/docs/SourceLevelDebugging.html#fixing-errors
+; RUN: opt < %s -enable-debugify -passes=correlated-propagation -S 2>&1 | \
+; RUN: FileCheck %s -check-prefix=DEBUG
+; DEBUG: CheckModuleDebugify: PASS
+
+; Check that an lshr whose result may be non-zero (255 >> 7 == 1) is kept.
+; CHECK-LABEL: @test1
+define void @test1() {
+  %shr = lshr i8 255, 7
+; CHECK: %shr = lshr i8 -1, 7
+  ret void
+}
+
+; Check that an lshr whose result is always zero (127 >> 7 == 0) is removed.
+; CHECK-LABEL: @test2
+define void @test2() {
+; CHECK-NOT: lshr
+  %shr = lshr i8 127, 7
+; CHECK: %add = add nuw nsw i8 0, 1
+  %add = add i8 %shr, 1
+  ret void
+}
Index: llvm/test/Transforms/CorrelatedValuePropagation/min-max.ll
===================================================================
--- llvm/test/Transforms/CorrelatedValuePropagation/min-max.ll
+++ llvm/test/Transforms/CorrelatedValuePropagation/min-max.ll
@@ -235,3 +235,39 @@
   %r = call i8 @llvm.smax(i8 %x, i8 42)
   ret i8 %r
 }
+
+; CHECK-LABEL: @test20
+define void @test20() {
+; CHECK-NOT: smax
+  %max = call i8 @llvm.smax(i8 1, i8 255)
+; CHECK: %add = add nuw nsw i8 1, 3
+  %add = add i8 %max, 3
+  ret void
+}
+
+; CHECK-LABEL: @test21
+define void @test21() {
+; CHECK-NOT: smin
+  %min = call i8 @llvm.smin(i8 1, i8 255)
+; CHECK: %add = add nsw i8 -1, 3
+  %add = add i8 %min, 3
+  ret void
+}
+
+; CHECK-LABEL: @test22
+define void @test22() {
+; CHECK-NOT: umax
+  %max = call i8 @llvm.umax(i8 1, i8 255)
+; CHECK: %add = add nsw i8 -1, 3
+  %add = add i8 %max, 3
+  ret void
+}
+
+; CHECK-LABEL: @test23
+define void @test23() {
+; CHECK-NOT: umin
+  %min = call i8 @llvm.umin(i8 1, i8 255)
+; CHECK: %add = add nuw nsw i8 1, 3
+  %add = add i8 %min, 3
+  ret void
+}
Index: llvm/test/Transforms/CorrelatedValuePropagation/sdiv.ll
===================================================================
--- llvm/test/Transforms/CorrelatedValuePropagation/sdiv.ll
+++ llvm/test/Transforms/CorrelatedValuePropagation/sdiv.ll
@@ -632,3 +632,12 @@
   %div = sdiv exact i64 %x, %y
   ret i64 %div
 }
+
+; CHECK-LABEL: @test23
+define void @test23() {
+; CHECK-NOT: sdiv
+  %div = sdiv i64 -1, 2
+; CHECK: %add = add nuw nsw i64 0, 1
+  %add = add i64 %div, 1
+  ret void
+}
Index: llvm/test/Transforms/CorrelatedValuePropagation/udiv.ll
===================================================================
--- llvm/test/Transforms/CorrelatedValuePropagation/udiv.ll
+++ llvm/test/Transforms/CorrelatedValuePropagation/udiv.ll
@@ -99,3 +99,26 @@
 exit:
   ret void
 }
+
+; CHECK-LABEL: @test7
+define void @test7() {
+if.cond:
+  %cmp = icmp slt i64 1, 2
+  br i1 %cmp, label %if, label %else
+
+if:
+  %x = add i64 0, 1
+  br label %if.end
+
+else:
+  %y = add i64 1, 1
+  br label %if.end
+
+if.end:
+  %phi = phi i64 [ %x, %if ], [ %y, %else ]
+; CHECK-NOT: udiv
+  %div = udiv i64 %phi, 3
+; CHECK: %add = add nuw nsw i64 0, 1
+  %add = add nuw nsw i64 %div, 1
+  ret void
+}