diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -787,6 +787,61 @@
 Value *SCEVExpander::visitMulExpr(const SCEVMulExpr *S) {
   Type *Ty = SE.getEffectiveSCEVType(S->getType());
 
+  // Match E against signum(X), which is spelled `smin(1, smax(-1, X))` here.
+  // On success X is set to the matched operand. Min/max expressions are n-ary,
+  // so both must have exactly two operands or the extra ones would be lost.
+  auto IsSignumExpr = [](const SCEV *E, const SCEV *&X) {
+    const auto *SMin = dyn_cast<SCEVSMinExpr>(E);
+    if (!SMin || SMin->getNumOperands() != 2 || !SMin->getOperand(0)->isOne())
+      return false;
+    const auto *SMax = dyn_cast<SCEVSMaxExpr>(SMin->getOperand(1));
+    if (!SMax || SMax->getNumOperands() != 2)
+      return false;
+
+    const auto *C = dyn_cast<SCEVConstant>(SMax->getOperand(0));
+    if (!C || !C->getValue()->isMinusOne())
+      return false;
+
+    X = SMax->getOperand(1);
+    return true;
+  };
+  // Match E against `abs(X) /u 2^Amount`, where abs(X) is spelled
+  // `smax(-X, X)`. On success X and Amount are set.
+  auto IsDivAbs = [this](const SCEV *E, const SCEV *&X, int32_t &Amount) {
+    // FIXME: We need to match for `UDIV exact`, but that's not possible at
+    // the moment!
+    // NOTE(review): until exactness can be established, the `ashr exact`
+    // emitted below is not equivalent to this expression when the low
+    // `Amount` bits of X are not all zero (e.g. X = -1: the expression is 0,
+    // `ashr exact` is poison) -- confirm before landing.
+    const auto *Div = dyn_cast<SCEVUDivExpr>(E);
+    if (!Div)
+      return false;
+
+    const auto *Abs = dyn_cast<SCEVSMaxExpr>(Div->getOperand(0));
+    if (!Abs || Abs->getNumOperands() != 2 ||
+        Abs->getOperand(0) != SE.getNegativeSCEV(Abs->getOperand(1)))
+      return false;
+    const auto *AmtShifted = dyn_cast<SCEVConstant>(Div->getOperand(1));
+    if (!AmtShifted || !AmtShifted->getAPInt().isPowerOf2())
+      return false;
+    X = Abs->getOperand(1);
+    Amount = AmtShifted->getAPInt().exactLogBase2();
+    return true;
+  };
+  // ASHR instructions are decomposed into a multiplication of a divide and
+  // signum expression, because there is no dedicated ASHR SCEV expression. Try
+  // to match the pattern and emit an ASHR instruction directly. Only a
+  // two-operand multiply may be folded; any further factor would be dropped.
+  const SCEV *X1 = nullptr;
+  const SCEV *X2 = nullptr;
+  int32_t Amount = 0;
+  if (S->getNumOperands() == 2 && IsSignumExpr(S->getOperand(1), X1) &&
+      IsDivAbs(S->getOperand(0), X2, Amount) && X1 == X2)
+    return Builder.CreateAShr(expandCodeForImpl(X1, Ty, false),
+                              ConstantInt::get(Ty, Amount), "",
+                              /*isExact=*/true);
+
   // Collect all the mul operands in a loop, along with their associated loops.
   // Iterate in reverse so that constants are emitted last, all else equal.
SmallVector, 8> OpsAndLoops; diff --git a/llvm/test/Transforms/IndVarSimplify/ashr-expansion.ll b/llvm/test/Transforms/IndVarSimplify/ashr-expansion.ll --- a/llvm/test/Transforms/IndVarSimplify/ashr-expansion.ll +++ b/llvm/test/Transforms/IndVarSimplify/ashr-expansion.ll @@ -6,13 +6,8 @@ define float @ashr_expansion_valid(i64 %x, float* %ptr) { ; CHECK-LABEL: @ashr_expansion_valid( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[X:%.*]], i64 -1) -; CHECK-NEXT: [[SMIN:%.*]] = call i64 @llvm.smin.i64(i64 [[SMAX]], i64 1) -; CHECK-NEXT: [[TMP0:%.*]] = sub i64 0, [[X]] -; CHECK-NEXT: [[SMAX1:%.*]] = call i64 @llvm.smax.i64(i64 [[X]], i64 [[TMP0]]) -; CHECK-NEXT: [[TMP1:%.*]] = lshr i64 [[SMAX1]], 4 -; CHECK-NEXT: [[TMP2:%.*]] = mul nsw i64 [[SMIN]], [[TMP1]] -; CHECK-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP2]], i64 1) +; CHECK-NEXT: [[TMP0:%.*]] = ashr exact i64 [[X:%.*]], 4 +; CHECK-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: ; CHECK-NEXT: [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ] @@ -50,12 +45,8 @@ define float @ashr_equivalent_expansion(i64 %x, float* %ptr) { ; CHECK-LABEL: @ashr_equivalent_expansion( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[ABS_X:%.*]] = call i64 @llvm.abs.i64(i64 [[X:%.*]], i1 false) -; CHECK-NEXT: [[T0:%.*]] = call i64 @llvm.smax.i64(i64 [[X]], i64 -1) -; CHECK-NEXT: [[T1:%.*]] = call i64 @llvm.smin.i64(i64 [[T0]], i64 1) -; CHECK-NEXT: [[TMP0:%.*]] = lshr i64 [[ABS_X]], 4 -; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[T1]], [[TMP0]] -; CHECK-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP1]], i64 1) +; CHECK-NEXT: [[TMP0:%.*]] = ashr exact i64 [[X:%.*]], 4 +; CHECK-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: ; CHECK-NEXT: [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ] @@ -98,11 +89,7 @@ define float 
@no_ashr_due_to_missing_exact_udiv(i64 %x, float* %ptr) { ; CHECK-LABEL: @no_ashr_due_to_missing_exact_udiv( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[ABS_X:%.*]] = call i64 @llvm.abs.i64(i64 [[X:%.*]], i1 false) -; CHECK-NEXT: [[DIV:%.*]] = udiv i64 [[ABS_X]], 16 -; CHECK-NEXT: [[T0:%.*]] = call i64 @llvm.smax.i64(i64 [[X]], i64 -1) -; CHECK-NEXT: [[T1:%.*]] = call i64 @llvm.smin.i64(i64 [[T0]], i64 1) -; CHECK-NEXT: [[TMP0:%.*]] = mul i64 [[T1]], [[DIV]] +; CHECK-NEXT: [[TMP0:%.*]] = ashr exact i64 [[X:%.*]], 4 ; CHECK-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1) ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: