diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12750,11 +12750,29 @@
 }
 
 // Match the mathematical pattern A - (A / B) * B, where A and B can be
-// arbitrary expressions.
+// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
+// for URem with constant power-of-2 second operands.
 // It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
 // 4, A / B becomes X / 8).
 bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
                                 const SCEV *&RHS) {
+  // Try to match 'zext (trunc A to iB) to iY', which is used for URem with a
+  // constant power-of-2 second operand (2^B). Only match when A has the same
+  // width as the whole expression, so LHS has the same width as Expr and the
+  // constant built below does too.
+  if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr)) {
+    const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0));
+    if (Trunc && getTypeSizeInBits(Trunc->getOperand()->getType()) ==
+                     getTypeSizeInBits(Expr->getType())) {
+      LHS = Trunc->getOperand();
+      // Build 2^B as an APInt of Expr's width. B can be 32 or more (e.g. a
+      // trunc from i64 to i34), where '1u << B' is undefined behavior and the
+      // result would not fit in 'unsigned' anyway.
+      RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
+                        << getTypeSizeInBits(Trunc->getType()));
+      return true;
+    }
+  }
   const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
   if (Add == nullptr || Add->getNumOperands() != 2)
     return false;
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -63,6 +63,11 @@
                                             const SCEV *RHS) {
     return SE.computeConstantDifference(LHS, RHS);
   }
+
+  bool matchURem(ScalarEvolution &SE, const SCEV *Expr, const SCEV *&LHS,
+                 const SCEV *&RHS) {
+    return SE.matchURem(Expr, LHS, RHS);
+  }
 };
 
 TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) {
@@ -1316,4 +1321,39 @@
   });
 }
 
+TEST_F(ScalarEvolutionsTest, MatchURem) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
+      " "
+      "define void @test(i32 %a, i32 %b, i64 %d) {"
+      "entry: "
+      "  %rem1 = urem i32 %a, 2"
+      "  %rem2 = urem i32 %a, 5"
+      "  %rem3 = urem i32 %a, %b"
+      "  %rem4 = urem i64 %d, 17179869184"
+      "  ret void "
+      "} ",
+      Err, C);
+
+  assert(M && "Could not parse module?");
+  assert(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    // rem1 and rem4 divide by powers of 2 and take the zext/trunc path; the
+    // divisor of rem4 (2^34) does not fit in 32 bits.
+    for (auto *N : {"rem1", "rem2", "rem3", "rem4"}) {
+      auto *URemI = getInstructionByName(F, N);
+      auto *S = SE.getSCEV(URemI);
+      const SCEV *LHS, *RHS;
+      EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
+      EXPECT_EQ(LHS, SE.getSCEV(URemI->getOperand(0)));
+      EXPECT_EQ(RHS, SE.getSCEV(URemI->getOperand(1)));
+      EXPECT_EQ(LHS->getType(), S->getType());
+      EXPECT_EQ(RHS->getType(), S->getType());
+    }
+  });
+}
+
 }  // end namespace llvm