diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp --- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -785,8 +785,10 @@ /// Try to shrink a udiv/urem's width down to the smallest power of two that's /// sufficient to contain its operands. -static bool narrowUDivOrURem(BinaryOperator *Instr, const ConstantRange &XCR, - const ConstantRange &YCR) { +static bool narrowUDivOrURem(BinaryOperator *Instr, + const ConstantRange &XCRAtUse, + const ConstantRange &YCRAtUse, + const ConstantRange &YCRAtDef) { assert(Instr->getOpcode() == Instruction::UDiv || Instr->getOpcode() == Instruction::URem); assert(!Instr->getType()->isVectorTy()); @@ -796,7 +798,8 @@ // What is the smallest bit width that can accommodate the entire value ranges // of both of the operands? - unsigned MaxActiveBits = std::max(XCR.getActiveBits(), YCR.getActiveBits()); + unsigned MaxActiveBits = + std::max(XCRAtUse.getActiveBits(), YCRAtUse.getActiveBits()); // Don't shrink below 8 bits wide. unsigned NewWidth = std::max(PowerOf2Ceil(MaxActiveBits), 8); @@ -805,6 +808,13 @@ if (NewWidth >= Instr->getType()->getIntegerBitWidth()) return false; + // For divisor, if the constant range computed at define point does not + // contain zero, then we should ensure the constant range after truncating + // also does not contain zero. Otherwise we may introduce an UB. + if (!YCRAtDef.contains(APInt::getZero(YCRAtDef.getBitWidth())) && + YCRAtDef.truncate(NewWidth).contains(APInt::getZero(NewWidth))) + return false; + ++NumUDivURemsNarrowed; IRBuilder<> B{Instr}; auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth); @@ -829,12 +839,13 @@ if (Instr->getType()->isVectorTy()) return false; - ConstantRange XCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(0)); - ConstantRange YCR = LVI->getConstantRangeAtUse(Instr->getOperandUse(1)); - if (expandUDivOrURem(Instr, XCR, YCR)) + ConstantRange YCRAtDef = LVI->getConstantRange(Instr->getOperand(1), Instr); + ConstantRange XCRAtUse = LVI->getConstantRangeAtUse(Instr->getOperandUse(0)); + ConstantRange YCRAtUse = LVI->getConstantRangeAtUse(Instr->getOperandUse(1)); + if (expandUDivOrURem(Instr, XCRAtUse, YCRAtUse)) return true; - return narrowUDivOrURem(Instr, XCR, YCR); + return narrowUDivOrURem(Instr, XCRAtUse, YCRAtUse, YCRAtDef); } static bool processSRem(BinaryOperator *SDI, const ConstantRange &LCR, diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/udiv.ll b/llvm/test/Transforms/CorrelatedValuePropagation/udiv.ll --- a/llvm/test/Transforms/CorrelatedValuePropagation/udiv.ll +++ b/llvm/test/Transforms/CorrelatedValuePropagation/udiv.ll @@ -99,3 +99,4 @@ exit: ret void } + diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/urem.ll b/llvm/test/Transforms/CorrelatedValuePropagation/urem.ll --- a/llvm/test/Transforms/CorrelatedValuePropagation/urem.ll +++ b/llvm/test/Transforms/CorrelatedValuePropagation/urem.ll @@ -394,4 +394,21 @@ ret void } +define i32 @udiv_do_not_truncate(i32 noundef %call, i32 %v) { +; CHECK-LABEL: @udiv_do_not_truncate( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[CALL_NON_NULL:%.*]] = or i32 [[CALL:%.*]], 1 +; CHECK-NEXT: [[DIV:%.*]] = urem i32 8192, [[CALL_NON_NULL]] +; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i32 [[CALL_NON_NULL]], 8192 +; CHECK-NEXT: [[SELECT:%.*]] = select i1 [[CMP]], i32 1, i32 [[DIV]] +; CHECK-NEXT: ret i32 [[SELECT]] +; +entry: + %call_non_null = or i32 %call, 1 + %div = urem i32 8192, %call_non_null + %cmp = icmp ugt i32 %call_non_null, 8192 + %select = select i1 %cmp, i32 1, i32 %div + ret i32 %select +} + declare void @use(i1)