Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -2338,12 +2338,40 @@
 
   // Check for truncates. If all the operands are truncated from the same
   // type, see if factoring out the truncate would permit the result to be
-  // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
+  // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
   // if the contents of the resulting outer trunc fold to something simple.
-  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
-    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
-    Type *DstType = Trunc->getType();
-    Type *SrcType = Trunc->getOperand()->getType();
+  // Returns the type of the value being truncated by the first truncate
+  // reached from Ops[Idx] (directly, or as a factor of a multiply), or
+  // nullptr when the search stops before reaching one.
+  auto FindTruncSrcType = [&]() -> Type * {
+    // Go through the available Ops to see if we have a compatible trunc() to
+    // start processing.
+    // NOTE(review): Ops[Idx] is read here without the `Idx < Ops.size()`
+    // guard the replaced loop had; presumably the caller ensures it -- confirm.
+    if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
+      return T->getOperand()->getType();
+    // Otherwise walk forward while the operand's SCEV type is <= scMulExpr.
+    for (unsigned i = Idx; i < Ops.size() && Ops[i]->getSCEVType() <= scMulExpr;
+         ++i) {
+      if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[i])) {
+        bool Ok = true;
+        for (unsigned j = 0, e = Mul->getNumOperands(); Ok && j < e; ++j) {
+          const auto *Op = Mul->getOperand(j);
+          if (const auto *T = dyn_cast<SCEVTruncateExpr>(Op)) {
+            return T->getOperand()->getType();
+          } else if (!isa<SCEVConstant>(Op)) {
+            Ok = false;
+          }
+        }
+        if (!Ok)
+          break;
+      } else {
+        break;
+      }
+    }
+    return nullptr;
+  };
+  if (auto *SrcType = FindTruncSrcType()) {
     SmallVector<const SCEV *, 8> LargeOps;
     bool Ok = true;
     // Check all the operands to see if they can be represented in the
@@ -2386,7 +2414,7 @@
         const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1);
         // If it folds to something simple, use it. Otherwise, don't.
         if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
-          return getTruncateExpr(Fold, DstType);
+          return getTruncateExpr(Fold, Ty);
       }
     }
 
Index: unittests/Analysis/ScalarEvolutionTest.cpp
===================================================================
--- unittests/Analysis/ScalarEvolutionTest.cpp
+++ unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1009,5 +1009,39 @@
   auto Result = SE.createAddRecFromPHIWithCasts(cast<SCEVUnknown>(Expr));
 }
 
+TEST_F(ScalarEvolutionsTest, SCEVFoldSumOfTruncs) {
+  // Verify that the following SCEV gets folded to a zero:
+  //  (-1 * (trunc i64 (-1 * %0) to i32)) + (-1 * (trunc i64 %0 to i32))
+  Type *ArgTy = Type::getInt64Ty(Context);
+  Type *Int32Ty = Type::getInt32Ty(Context);
+  // Build a trivial function `void f(i64)` whose argument seeds the SCEVs.
+  SmallVector<Type *, 1> Types;
+  Types.push_back(ArgTy);
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), Types, false);
+  Function *F = cast<Function>(M.getOrInsertFunction("f", FTy));
+  BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
+  ReturnInst::Create(Context, nullptr, BB);
+
+  ScalarEvolution SE = buildSE(*F);
+
+  auto *Arg = &*(F->arg_begin());
+  const auto *ArgSCEV = SE.getSCEV(Arg);
+
+  // Build the SCEV
+  // A = negation of (the negated argument truncated to i32)
+  const auto *A0 = SE.getNegativeSCEV(ArgSCEV);
+  const auto *A1 = SE.getTruncateExpr(A0, Int32Ty);
+  const auto *A = SE.getNegativeSCEV(A1);
+
+  // B = negation of (the argument truncated to i32)
+  const auto *B0 = SE.getTruncateExpr(ArgSCEV, Int32Ty);
+  const auto *B = SE.getNegativeSCEV(B0);
+
+  const auto *Expr = SE.getAddExpr(A, B);
+  // Verify that the SCEV was folded to 0
+  const auto *ZeroConst = SE.getConstant(Int32Ty, 0);
+  EXPECT_EQ(Expr, ZeroConst);
+}
+
 } // end anonymous namespace
 } // end namespace llvm