diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -12686,11 +12686,33 @@
 }
 
 // Match the mathematical pattern A - (A / B) * B, where A and B can be
-// arbitrary expressions.
+// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
+// for URem with constant power-of-2 second operands.
 // It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
 // 4, A / B becomes X / 8).
 bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
                                 const SCEV *&RHS) {
+  // Try to match 'zext (trunc A to iB) to iY', which is used
+  // for URem with constant power-of-2 second operands. On success, LHS is A
+  // (zero-extended to iY if narrower) and RHS is 2^B, both of type iY.
+  if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr)) {
+    if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
+      const SCEV *TruncOp = Trunc->getOperand();
+      Type *Ty = Expr->getType();
+      // A may be wider than iY, e.g. zext (trunc i64 %x to i8) to i32. It
+      // cannot be zero-extended to iY then, so bail out for now and leave
+      // LHS and RHS untouched.
+      if (getTypeSizeInBits(TruncOp->getType()) > getTypeSizeInBits(Ty))
+        return false;
+      LHS = TruncOp;
+      if (LHS->getType() != Ty)
+        LHS = getZeroExtendExpr(LHS, Ty);
+      // B < Y because a SCEV zext strictly widens, so 2^B fits in iY.
+      RHS = getConstant(APInt(getTypeSizeInBits(Ty), 1)
+                        << getTypeSizeInBits(Trunc->getType()));
+      return true;
+    }
+  }
   const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
   if (Add == nullptr || Add->getNumOperands() != 2)
     return false;
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -63,6 +63,11 @@
                                                    const SCEV *RHS) {
     return SE.computeConstantDifference(LHS, RHS);
   }
+
+  bool matchURem(ScalarEvolution &SE, const SCEV *Expr, const SCEV *&LHS,
+                 const SCEV *&RHS) {
+    return SE.matchURem(Expr, LHS, RHS);
+  }
 };
 
 TEST_F(ScalarEvolutionsTest, SCEVUnknownRAUW) {
@@ -1316,4 +1321,67 @@
   });
 }
 
+TEST_F(ScalarEvolutionsTest, MatchURem) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
+      " "
+      "define void @test(i32 %a, i32 %b, i16 %c, i64 %d) {"
+      "entry: "
+      "  %rem1 = urem i32 %a, 2"
+      "  %rem2 = urem i32 %a, 5"
+      "  %rem3 = urem i32 %a, %b"
+      "  %c.ext = zext i16 %c to i32"
+      "  %rem4 = urem i32 %c.ext, 2"
+      "  %ext = zext i32 %rem4 to i64"
+      "  %rem5 = urem i64 %d, 17179869184"
+      "  %d.trunc = trunc i64 %d to i8"
+      "  %d.trunc.ext = zext i8 %d.trunc to i32"
+      "  ret void "
+      "} ",
+      Err, C);
+
+  assert(M && "Could not parse module?");
+  assert(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "test", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    for (auto *N : {"rem1", "rem2", "rem3", "rem5"}) {
+      auto *URemI = getInstructionByName(F, N);
+      auto *S = SE.getSCEV(URemI);
+      const SCEV *LHS, *RHS;
+      EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
+      EXPECT_EQ(LHS, SE.getSCEV(URemI->getOperand(0)));
+      EXPECT_EQ(RHS, SE.getSCEV(URemI->getOperand(1)));
+      EXPECT_EQ(LHS->getType(), S->getType());
+      EXPECT_EQ(RHS->getType(), S->getType());
+    }
+
+    // Check the case where the urem operand is zero-extended. Make sure the
+    // match results are extended to the size of the input expression.
+    auto *Ext = getInstructionByName(F, "ext");
+    auto *URem1 = getInstructionByName(F, "rem4");
+    auto *S = SE.getSCEV(Ext);
+    const SCEV *LHS, *RHS;
+    EXPECT_TRUE(matchURem(SE, S, LHS, RHS));
+    EXPECT_NE(LHS, SE.getSCEV(URem1->getOperand(0)));
+    // RHS and URem1->getOperand(1) have different widths, so compare the
+    // integer values.
+    EXPECT_EQ(cast<SCEVConstant>(RHS)->getValue()->getZExtValue(),
+              cast<SCEVConstant>(SE.getSCEV(URem1->getOperand(1)))
+                  ->getValue()
+                  ->getZExtValue());
+    EXPECT_EQ(LHS->getType(), S->getType());
+    EXPECT_EQ(RHS->getType(), S->getType());
+
+    // zext (trunc A to iB) to iY where A is wider than iY: A cannot be
+    // zero-extended to iY, so the match must fail and leave LHS/RHS alone.
+    auto *WideTruncExt = getInstructionByName(F, "d.trunc.ext");
+    const SCEV *WideLHS = nullptr, *WideRHS = nullptr;
+    EXPECT_FALSE(matchURem(SE, SE.getSCEV(WideTruncExt), WideLHS, WideRHS));
+    EXPECT_EQ(nullptr, WideLHS);
+    EXPECT_EQ(nullptr, WideRHS);
+  });
+}
+
 } // end namespace llvm