Index: lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1193,6 +1193,38 @@
   return nullptr;
 }
 
+bool InstCombiner::shouldFoldCast(CastInst *CI) {
+  Value *CastSrc = CI->getOperand(0);
+
+  // Noop casts and casts of constants should be eliminated trivially.
+  if (CI->getSrcTy() == CI->getDestTy() || isa<Constant>(CastSrc)) return false;
+
+  // If this cast is paired with another cast that can be eliminated, we prefer
+  // to have it eliminated.
+  if (const CastInst *PrecedingCI = dyn_cast<CastInst>(CastSrc))
+    if (isEliminableCastPair(PrecedingCI, CI))
+      return false;
+
+  // If this is a vector sext from a compare, then we don't want to break the
+  // idiom where each element of the extended vector is either zero or all ones.
+  if (CI->getOpcode() == Instruction::SExt &&
+      isa<CmpInst>(CastSrc) && CI->getDestTy()->isVectorTy())
+    return false;
+
+  // If this is a zext of an icmp that transformZExtICmp can eliminate, prefer
+  // that transform. Folding the cast here could also undo the split of
+  // (zext (or icmp, icmp)) performed in visitZExt, so the two would fight.
+  if (ZExtInst *ZExt = dyn_cast<ZExtInst>(CI)) {
+    if (ICmpInst *ICmp = dyn_cast<ICmpInst>(CastSrc)) {
+      if (transformZExtICmp(*ZExt, ICmp, false)) {
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
 Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) {
   auto LogicOpc = I.getOpcode();
   assert((LogicOpc == Instruction::And || LogicOpc == Instruction::Or ||
@@ -1238,11 +1270,7 @@
   Value *Cast1Src = Cast1->getOperand(0);
 
   // fold (logic (cast A), (cast B)) -> (cast (logic A, B))
-
-  // Only do this if the casts both really cause code to be generated.
-  if ((!isa<Constant>(Cast0Src) || !isa<Constant>(Cast1Src)) &&
-      ShouldOptimizeCast(CastOpcode, Cast0Src, DestTy) &&
-      ShouldOptimizeCast(CastOpcode, Cast1Src, DestTy)) {
+  if (shouldFoldCast(Cast0) && shouldFoldCast(Cast1)) {
     Value *NewOp = Builder->CreateBinOp(LogicOpc, Cast0Src, Cast1Src,
                                         I.getName());
     return CastInst::Create(CastOpcode, NewOp, DestTy);
Index: lib/Transforms/InstCombine/InstCombineCasts.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -227,20 +227,14 @@
   return InsertNewInstWith(Res, *I);
 }
 
-
-/// This function is a wrapper around CastInst::isEliminableCastPair. It
-/// simply extracts arguments and returns what that function returns.
-static Instruction::CastOps
-isEliminableCastPair(const CastInst *CI, ///< First cast instruction
-                     unsigned opcode, ///< Opcode for the second cast
-                     Type *DstTy, ///< Target type for the second cast
-                     const DataLayout &DL) {
-  Type *SrcTy = CI->getOperand(0)->getType(); // A from above
-  Type *MidTy = CI->getType(); // B from above
-
-  // Get the opcodes of the two Cast instructions
-  Instruction::CastOps firstOp = Instruction::CastOps(CI->getOpcode());
-  Instruction::CastOps secondOp = Instruction::CastOps(opcode);
+Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1,
+                                                        const CastInst *CI2) {
+  Type *SrcTy = CI1->getSrcTy();
+  Type *MidTy = CI1->getDestTy();
+  Type *DstTy = CI2->getDestTy();
+
+  Instruction::CastOps firstOp = Instruction::CastOps(CI1->getOpcode());
+  Instruction::CastOps secondOp = Instruction::CastOps(CI2->getOpcode());
   Type *SrcIntPtrTy =
       SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr;
   Type *MidIntPtrTy =
@@ -260,30 +254,6 @@
   return Instruction::CastOps(Res);
 }
 
-/// Return true if the cast from "V to Ty" actually results in any code being
-/// generated and is interesting to optimize out.
-/// If the cast can be eliminated by some other simple transformation, we prefer
-/// to do the simplification first.
-bool InstCombiner::ShouldOptimizeCast(Instruction::CastOps opc, const Value *V,
-                                      Type *Ty) {
-  // Noop casts and casts of constants should be eliminated trivially.
-  if (V->getType() == Ty || isa<Constant>(V)) return false;
-
-  // If this is another cast that can be eliminated, we prefer to have it
-  // eliminated.
-  if (const CastInst *CI = dyn_cast<CastInst>(V))
-    if (isEliminableCastPair(CI, opc, Ty, DL))
-      return false;
-
-  // If this is a vector sext from a compare, then we don't want to break the
-  // idiom where each element of the extended vector is either zero or all ones.
-  if (opc == Instruction::SExt && isa<CmpInst>(V) && Ty->isVectorTy())
-    return false;
-
-  return true;
-}
-
-
 /// @brief Implement the transforms common to all CastInst visitors.
 Instruction *InstCombiner::commonCastTransforms(CastInst &CI) {
   Value *Src = CI.getOperand(0);
@@ -292,7 +262,7 @@
   // eliminate it now.
   if (CastInst *CSrc = dyn_cast<CastInst>(Src)) {   // A->B->C cast
     if (Instruction::CastOps opc =
-            isEliminableCastPair(CSrc, CI.getOpcode(), CI.getType(), DL)) {
+            isEliminableCastPair(CSrc, &CI)) {
       // The first cast (CSrc) is eliminable so we need to fix up or replace
       // the second cast (CI). CSrc will then have a good chance of being dead.
       return CastInst::Create(opc, CSrc->getOperand(0), CI.getType());
@@ -578,10 +548,8 @@
   return nullptr;
 }
 
-/// Transform (zext icmp) to bitwise / integer operations in order to eliminate
-/// the icmp.
-Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, Instruction &CI,
-                                             bool DoXform) {
+Instruction *InstCombiner::transformZExtICmp(ZExtInst &CI, ICmpInst *ICI,
+                                             bool DoTransform) {
   // If we are just checking for a icmp eq of a single bit and zext'ing it
   // to an integer, then shift the bit to the appropriate place and then
   // cast to integer to avoid the comparison.
@@ -592,7 +560,7 @@
     // zext (x >s -1) to i32 --> (x>>u31)^1  true if signbit clear.
     if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV == 0) ||
         (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV.isAllOnesValue())) {
-      if (!DoXform) return ICI;
+      if (!DoTransform) return ICI;
 
       Value *In = ICI->getOperand(0);
       Value *Sh = ConstantInt::get(In->getType(),
@@ -627,7 +595,7 @@
 
       APInt KnownZeroMask(~KnownZero);
       if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1?
-        if (!DoXform) return ICI;
+        if (!DoTransform) return ICI;
 
         bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE;
         if (Op1CV != 0 && (Op1CV != KnownZeroMask)) {
@@ -678,7 +646,7 @@
       APInt KnownBits = KnownZeroLHS | KnownOneLHS;
       APInt UnknownBit = ~KnownBits;
       if (UnknownBit.countPopulation() == 1) {
-        if (!DoXform) return ICI;
+        if (!DoTransform) return ICI;
 
         Value *Result = Builder->CreateXor(LHS, RHS);
 
@@ -918,7 +886,7 @@
   }
 
   if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src))
-    return transformZExtICmp(ICI, CI);
+    return transformZExtICmp(CI, ICI);
 
   BinaryOperator *SrcI = dyn_cast<BinaryOperator>(Src);
   if (SrcI && SrcI->getOpcode() == Instruction::Or) {
@@ -927,8 +895,8 @@
     ICmpInst *LHS = dyn_cast<ICmpInst>(SrcI->getOperand(0));
     ICmpInst *RHS = dyn_cast<ICmpInst>(SrcI->getOperand(1));
     if (LHS && RHS && LHS->hasOneUse() && RHS->hasOneUse() &&
-        (transformZExtICmp(LHS, CI, false) ||
-         transformZExtICmp(RHS, CI, false))) {
+        (transformZExtICmp(CI, LHS, false) ||
+         transformZExtICmp(CI, RHS, false))) {
       Value *LCast = Builder->CreateZExt(LHS, CI.getType(), LHS->getName());
       Value *RCast = Builder->CreateZExt(RHS, CI.getType(), RHS->getName());
       return BinaryOperator::Create(Instruction::Or, LCast, RCast);
Index: lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- lib/Transforms/InstCombine/InstCombineInternal.h
+++ lib/Transforms/InstCombine/InstCombineInternal.h
@@ -355,14 +355,15 @@
                               SmallVectorImpl<Value *> &NewIndices);
   Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI);
 
-  /// \brief Classify whether a cast is worth optimizing.
+  /// \brief Classify whether it is worthwhile to remove a cast by folding it
+  /// with another cast.
   ///
-  /// Returns true if the cast from "V to Ty" actually results in any code
-  /// being generated and is interesting to optimize out. If the cast can be
-  /// eliminated by some other simple transformation, we prefer to do the
-  /// simplification first.
-  bool ShouldOptimizeCast(Instruction::CastOps opcode, const Value *V,
-                          Type *Ty);
+  /// This is needed to determine whether the simplification of
+  /// (logic (cast A), (cast B)) to (cast (logic A, B)) is worthwhile.
+  ///
+  /// Returns true if this cast actually results in any code being generated and
+  /// if it cannot already be eliminated by some other transformation.
+  bool shouldFoldCast(CastInst *CI);
 
   /// \brief Try to optimize a sequence of instructions checking if an operation
   /// on LHS and RHS overflows.
@@ -385,8 +386,16 @@
   bool transformConstExprCastCall(CallSite CS);
   Instruction *transformCallThroughTrampoline(CallSite CS,
                                               IntrinsicInst *Tramp);
-  Instruction *transformZExtICmp(ICmpInst *ICI, Instruction &CI,
-                                 bool DoXform = true);
+
+  /// \brief Transform (zext icmp) to bitwise / integer operations in order to
+  /// eliminate the icmp.
+  ///
+  /// Pass `false` for \p DoTransform to only test whether the given (zext icmp)
+  /// would be transformed; a non-null result (\p ICI itself) means it would.
+  /// Pass `true` to actually perform the transformation.
+  Instruction *transformZExtICmp(ZExtInst &CI, ICmpInst *ICI,
+                                 bool DoTransform = true);
+
   Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI);
   bool WillNotOverflowSignedAdd(Value *LHS, Value *RHS, Instruction &CxtI);
   bool WillNotOverflowSignedSub(Value *LHS, Value *RHS, Instruction &CxtI);
@@ -397,6 +406,11 @@
   Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask);
   Instruction *foldCastedBitwiseLogic(BinaryOperator &I);
 
+  /// \brief Wrapper around CastInst::isEliminableCastPair for a CI1 -> CI2
+  /// chain: returns the opcode of one equivalent cast, or 0 if there is none.
+  Instruction::CastOps isEliminableCastPair(const CastInst *CI1,
+                                            const CastInst *CI2);
+
 public:
   /// \brief Inserts an instruction \p New before instruction \p Old
   ///
Index: test/Transforms/InstCombine/fold-casts-of-icmps.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/fold-casts-of-icmps.ll
@@ -0,0 +1,42 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+;
+; These tests assert that the folding of expressions of the form
+; (logic (cast A), (cast B)) to (cast (logic A, B)) also works if A or B are
+; icmp instructions.
+
+; CHECK-LABEL: @test1(
+; CHECK-NEXT: %1 = icmp sgt i64 %a, %b
+; CHECK-NEXT: %2 = icmp slt i64 %a, %c
+; CHECK-NEXT: %3 = and i1 %1, %2
+; CHECK-NEXT: %4 = zext i1 %3 to i8
+; CHECK-NEXT: ret i8 %4
+define i8 @test1(i64 %a, i64 %b, i64 %c) {
+  %1 = icmp sgt i64 %a, %b
+  %2 = zext i1 %1 to i8
+  %3 = icmp slt i64 %a, %c
+  %4 = zext i1 %3 to i8
+  %5 = and i8 %2, %4
+  ret i8 %5
+}
+
+; Assert that casts are also folded across multiple logical operators.
+; CHECK-LABEL: @test2(
+; CHECK-NEXT: %1 = icmp sgt i64 %a, %b
+; CHECK-NEXT: %2 = icmp slt i64 %a, %c
+; CHECK-NEXT: %3 = and i1 %1, %2
+; CHECK-NEXT: %4 = icmp eq i64 %a, %d
+; CHECK-NEXT: %5 = or i1 %3, %4
+; CHECK-NEXT: %6 = zext i1 %5 to i8
+; CHECK-NEXT: ret i8 %6
+define i8 @test2(i64 %a, i64 %b, i64 %c, i64 %d) {
+  %1 = icmp sgt i64 %a, %b
+  %2 = zext i1 %1 to i8
+  %3 = icmp slt i64 %a, %c
+  %4 = zext i1 %3 to i8
+  %5 = and i8 %2, %4
+  %6 = icmp eq i64 %a, %d
+  %7 = zext i1 %6 to i8
+  %8 = or i8 %5, %7
+  ret i8 %8
+}
+