diff --git a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
--- a/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
+++ b/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp
@@ -62,6 +62,7 @@
 STATISTIC(NumUDivs,     "Number of udivs whose width was decreased");
 STATISTIC(NumAShrs,     "Number of ashr converted to lshr");
 STATISTIC(NumSRems,     "Number of srem converted to urem");
+STATISTIC(NumSExt,      "Number of sext converted to zext");
 STATISTIC(NumOverflows, "Number of overflow checks removed");
 STATISTIC(NumSaturating,
     "Number of saturating arithmetics converted to normal arithmetics");
@@ -637,6 +638,34 @@
   return true;
 }
 
+/// Convert a scalar sext into a zext when LazyValueInfo proves, at the point
+/// of the sext, that the operand is >= 0 (signed). Returns true if SDI was
+/// replaced and erased, false if it was left untouched.
+static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
+  // Vector-typed extensions are not handled here.
+  if (SDI->getType()->isVectorTy())
+    return false;
+
+  Value *Base = SDI->getOperand(0);
+
+  // Only rewrite when "Base >= 0" is known true in the context of SDI.
+  Constant *Zero = ConstantInt::get(Base->getType(), 0);
+  if (LVI->getPredicateAt(ICmpInst::ICMP_SGE, Base, Zero, SDI) !=
+      LazyValueInfo::True)
+    return false;
+
+  ++NumSExt;
+  // The zext is inserted before SDI and requests SDI's name while SDI still
+  // exists; the updated tests show the result renamed (EXT_WIDE -> EXT_WIDE1).
+  auto *ZExt =
+      CastInst::CreateZExtOrBitCast(Base, SDI->getType(), SDI->getName(), SDI);
+  ZExt->setDebugLoc(SDI->getDebugLoc());
+  SDI->replaceAllUsesWith(ZExt);
+  SDI->eraseFromParent();
+
+  return true;
+}
+
 static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) {
   using OBO = OverflowingBinaryOperator;
 
@@ -745,6 +774,9 @@
       case Instruction::AShr:
         BBChanged |= processAShr(cast<BinaryOperator>(II), LVI);
         break;
+      case Instruction::SExt:
+        BBChanged |= processSExt(cast<SExtInst>(II), LVI);
+        break;
       case Instruction::Add:
       case Instruction::Sub:
         BBChanged |= processBinOp(cast<BinaryOperator>(II), LVI);
diff --git a/llvm/test/Transforms/CorrelatedValuePropagation/sext.ll b/llvm/test/Transforms/CorrelatedValuePropagation/sext.ll
--- a/llvm/test/Transforms/CorrelatedValuePropagation/sext.ll
+++ b/llvm/test/Transforms/CorrelatedValuePropagation/sext.ll
@@ -18,9 +18,9 @@
 ; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[A]], 1
 ; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END:%.*]]
 ; CHECK: for.body:
-; CHECK-NEXT: [[EXT_WIDE:%.*]] = sext i32 [[A]] to i64
-; CHECK-NEXT: call void @use64(i64 [[EXT_WIDE]])
-; CHECK-NEXT: [[EXT]] = trunc i64 [[EXT_WIDE]] to i32
+; CHECK-NEXT: [[EXT_WIDE1:%.*]] = zext i32 [[A]] to i64
+; CHECK-NEXT: call void @use64(i64 [[EXT_WIDE1]])
+; CHECK-NEXT: [[EXT]] = trunc i64 [[EXT_WIDE1]] to i32
 ; CHECK-NEXT: br label [[FOR_COND]]
 ; CHECK: for.end:
 ; CHECK-NEXT: ret void
@@ -85,9 +85,9 @@
 ; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i32 [[N:%.*]], 0
 ; CHECK-NEXT: br i1 [[CMP]], label [[BB:%.*]], label [[EXIT:%.*]]
 ; CHECK: bb:
-; CHECK-NEXT: [[EXT_WIDE:%.*]] = sext i32 [[N]] to i64
-; CHECK-NEXT: call void @use64(i64 [[EXT_WIDE]])
-; CHECK-NEXT: [[EXT:%.*]] = trunc i64 [[EXT_WIDE]] to i32
+; CHECK-NEXT: [[EXT_WIDE1:%.*]] = zext i32 [[N]] to i64
+; CHECK-NEXT: call void @use64(i64 [[EXT_WIDE1]])
+; CHECK-NEXT: [[EXT:%.*]] = trunc i64 [[EXT_WIDE1]] to i32
 ; CHECK-NEXT: br label [[EXIT]]
 ; CHECK: exit:
 ; CHECK-NEXT: ret void