Index: lib/Transforms/InstCombine/InstCombineCasts.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombineInternal.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/PatternMatch.h"
@@ -1779,6 +1780,130 @@
   return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand());
 }
 
+/// Check if all users of \p CI are StoreInsts; vacuously true if it has none.
+static bool hasStoreUsersOnly(CastInst &CI) {
+  for (User *U : CI.users()) {
+    if (!isa<StoreInst>(U))
+      return false;
+  }
+  return true;
+}
+
+/// This function handles following case
+///
+///     A  ->  B    cast
+///     PHI
+///     B  ->  A    cast
+///
+/// All the related PHI nodes can be replaced by new PHI nodes with type A.
+/// The uses of \p CI can be changed to the new PHI node corresponding to \p PN.
+Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
+  // BitCast used by Store can be handled in InstCombineLoadStoreAlloca.cpp.
+  if (hasStoreUsersOnly(CI))
+    return nullptr;
+
+  Value *Src = CI.getOperand(0);
+  Type *SrcTy = Src->getType();         // Type B
+  Type *DestTy = CI.getType();          // Type A
+
+  SmallVector<PHINode *, 4> PhiWorklist;
+  SmallSetVector<PHINode *, 4> OldPhiNodes;
+
+  // Find all of the A->B casts and PHI nodes.
+  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
+  // OldPhiNodes is used to track all known PHI nodes, before adding a new
+  // PHI to PhiWorklist, it is checked against and added to OldPhiNodes first.
+  PhiWorklist.push_back(PN);
+  OldPhiNodes.insert(PN);
+  while (!PhiWorklist.empty()) {
+    auto *OldPN = PhiWorklist.pop_back_val();
+    for (Value *IncValue : OldPN->incoming_values()) {
+      if (isa<Constant>(IncValue))
+        continue;
+
+      if (auto *LI = dyn_cast<LoadInst>(IncValue)) {
+        // If there is a sequence of one or more load instructions, each loaded
+        // value is used as address of later load instruction, bitcast is
+        // necessary to change the value type, don't optimize it. For
+        // simplicity we give up if the load address comes from another load.
+        Value *Addr = LI->getOperand(0);
+        if (Addr == &CI || isa<LoadInst>(Addr))
+          return nullptr;
+        if (LI->hasOneUse() && LI->isSimple())
+          continue;
+        // If a LoadInst has more than one use, changing the type of loaded
+        // value may create another bitcast.
+        return nullptr;
+      }
+
+      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
+        if (OldPhiNodes.insert(PNode))
+          PhiWorklist.push_back(PNode);
+        continue;
+      }
+
+      auto *BCI = dyn_cast<BitCastInst>(IncValue);
+      // We can't handle other instructions.
+      if (!BCI)
+        return nullptr;
+
+      // Verify it's a A->B cast.
+      Type *TyA = BCI->getOperand(0)->getType();
+      Type *TyB = BCI->getType();
+      if (TyA != DestTy || TyB != SrcTy)
+        return nullptr;
+    }
+  }
+
+  // For each old PHI node, create a corresponding new PHI node with a type A.
+  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
+  for (auto *OldPN : OldPhiNodes) {
+    Builder->SetInsertPoint(OldPN);
+    PHINode *NewPN = Builder->CreatePHI(DestTy, OldPN->getNumOperands());
+    NewPNodes[OldPN] = NewPN;
+  }
+
+  // Fill in the operands of new PHI nodes.
+  for (auto *OldPN : OldPhiNodes) {
+    PHINode *NewPN = NewPNodes[OldPN];
+    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
+      Value *V = OldPN->getOperand(j);
+      Value *NewV = nullptr;
+      if (auto *C = dyn_cast<Constant>(V)) {
+        NewV = Builder->CreateBitCast(C, DestTy);
+      } else if (auto *LI = dyn_cast<LoadInst>(V)) {
+        // LI's only use is this edge, so it dominates the end of the incoming
+        // block; emit the bitcast to type A there.
+        Builder->SetInsertPoint(OldPN->getIncomingBlock(j)->getTerminator());
+        NewV = Builder->CreateBitCast(LI, DestTy);
+        Worklist.Add(LI);
+      } else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
+        NewV = BCI->getOperand(0);
+      } else if (auto *PrevPN = dyn_cast<PHINode>(V)) {
+        NewV = NewPNodes[PrevPN];
+      }
+      assert(NewV);
+      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
+    }
+  }
+
+  // If there is a store with type B, change it to type A.
+  // Advance the iterator before rewriting: setOperand() moves the current
+  // Use onto NewBC's use list, and stepping from there skips PN's users.
+  for (auto UI = PN->user_begin(), UE = PN->user_end(); UI != UE;) {
+    auto *SI = dyn_cast<StoreInst>(*UI++);
+    if (SI && SI->isSimple() && SI->getOperand(0) == PN) {
+      Builder->SetInsertPoint(SI);
+      Value *NewBC = Builder->CreateBitCast(NewPNodes[PN], SrcTy);
+      SI->setOperand(0, NewBC);
+      Worklist.Add(SI);
+      assert(hasStoreUsersOnly(*cast<CastInst>(NewBC)));
+    }
+  }
+
+  return replaceInstUsesWith(CI, NewPNodes[PN]);
+}
+
 Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
   // If the operands are integer typed then apply the integer transforms,
   // otherwise just apply the common ones.
@@ -1902,6 +2027,11 @@
     }
   }
 
+  // Handle the A->B->A cast, and there is an intervening PHI node.
+  if (PHINode *PN = dyn_cast<PHINode>(Src))
+    if (Instruction *I = optimizeBitCastFromPhi(CI, PN))
+      return I;
+
   if (Instruction *I = canonicalizeBitCastExtElt(CI, *this, DL))
     return I;
 
Index: lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- lib/Transforms/InstCombine/InstCombineInternal.h
+++ lib/Transforms/InstCombine/InstCombineInternal.h
@@ -362,6 +362,7 @@
   Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN);
   Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask);
   Instruction *foldCastedBitwiseLogic(BinaryOperator &I);
+  Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
 
   /// Determine if a pair of casts can be replaced by a single cast.
   ///
Index: test/Transforms/InstCombine/pr25342.ll
===================================================================
--- test/Transforms/InstCombine/pr25342.ll
+++ test/Transforms/InstCombine/pr25342.ll
@@ -0,0 +1,94 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+%"struct.std::complex" = type { { float, float } }
+@dd = external global %"struct.std::complex", align 4
+@dd2 = external global %"struct.std::complex", align 4
+
+define void @_Z3fooi(i32 signext %n) {
+entry:
+  br label %for.cond
+
+for.cond:
+  %ldd.sroa.0.0 = phi i32 [ 0, %entry ], [ %5, %for.body ]
+  %ldd.sroa.6.0 = phi i32 [ 0, %entry ], [ %7, %for.body ]
+  %i.0 = phi i32 [ 0, %entry ], [ %inc, %for.body ]
+  %cmp = icmp slt i32 %i.0, %n
+  br i1 %cmp, label %for.body, label %for.end
+
+for.body:
+  %0 = load float, float* getelementptr inbounds (%"struct.std::complex", %"struct.std::complex"* @dd, i64 0, i32 0, i32 0), align 4
+  %1 = load float, float* getelementptr inbounds (%"struct.std::complex", %"struct.std::complex"* @dd, i64 0, i32 0, i32 1), align 4
+  %2 = load float, float* getelementptr inbounds (%"struct.std::complex", %"struct.std::complex"* @dd2, i64 0, i32 0, i32 0), align 4
+  %3 = load float, float* getelementptr inbounds (%"struct.std::complex", %"struct.std::complex"* @dd2, i64 0, i32 0, i32 1), align 4
+  %mul.i = fmul float %0, %2
+  %mul4.i = fmul float %1, %3
+  %sub.i = fsub float %mul.i, %mul4.i
+  %mul5.i = fmul float %1, %2
+  %mul6.i = fmul float %0, %3
+  %add.i4 = fadd float %mul5.i, %mul6.i
+  %4 = bitcast i32 %ldd.sroa.0.0 to float
+  %add.i = fadd float %sub.i, %4
+  %5 = bitcast float %add.i to i32
+  %6 = bitcast i32 %ldd.sroa.6.0 to float
+  %add4.i = fadd float %add.i4, %6
+  %7 = bitcast float %add4.i to i32
+  %inc = add nsw i32 %i.0, 1
+  br label %for.cond
+
+for.end:
+  store i32 %ldd.sroa.0.0, i32* bitcast (%"struct.std::complex"* @dd to i32*), align 4
+  store i32 %ldd.sroa.6.0, i32* bitcast (float* getelementptr inbounds (%"struct.std::complex", %"struct.std::complex"* @dd, i64 0, i32 0, i32 1) to i32*), align 4
+  ret void
+
+; CHECK-LABEL: @_Z3fooi(
+; CHECK: phi float
+; CHECK: store float
+; CHECK-NOT: bitcast
+}
+
+
+define void @multi_phi(i32 signext %n) {
+entry:
+  br label %for.cond
+
+for.cond:
+  %ldd.sroa.0.0 = phi i32 [ 0, %entry ], [ %9, %odd.bb ]
+  %i.0 = phi i32 [ 0, %entry ], [ %inc, %odd.bb ]
+  %cmp = icmp slt i32 %i.0, %n
+  br i1 %cmp, label %for.body, label %for.end
+
+for.body:
+  %0 = load float, float* getelementptr inbounds (%"struct.std::complex", %"struct.std::complex"* @dd, i64 0, i32 0, i32 0), align 4
+  %1 = load float, float* getelementptr inbounds (%"struct.std::complex", %"struct.std::complex"* @dd, i64 0, i32 0, i32 1), align 4
+  %2 = load float, float* getelementptr inbounds (%"struct.std::complex", %"struct.std::complex"* @dd2, i64 0, i32 0, i32 0), align 4
+  %3 = load float, float* getelementptr inbounds (%"struct.std::complex", %"struct.std::complex"* @dd2, i64 0, i32 0, i32 1), align 4
+  %mul.i = fmul float %0, %2
+  %mul4.i = fmul float %1, %3
+  %sub.i = fsub float %mul.i, %mul4.i
+  %4 = bitcast i32 %ldd.sroa.0.0 to float
+  %add.i = fadd float %sub.i, %4
+  %5 = bitcast float %add.i to i32
+  %inc = add nsw i32 %i.0, 1
+  %bit0 = and i32 %inc, 1
+  %even = icmp slt i32 %bit0, 1
+  br i1 %even, label %even.bb, label %odd.bb
+
+even.bb:
+  %6 = bitcast i32 %5 to float
+  %7 = fadd float %sub.i, %6
+  %8 = bitcast float %7 to i32
+  br label %odd.bb
+
+odd.bb:
+  %9 = phi i32 [ %5, %for.body ], [ %8, %even.bb ]
+  br label %for.cond
+
+for.end:
+  store i32 %ldd.sroa.0.0, i32* bitcast (%"struct.std::complex"* @dd to i32*), align 4
+  ret void
+
+; CHECK-LABEL: @multi_phi(
+; CHECK: phi float
+; CHECK: store float
+; CHECK-NOT: bitcast
+}
Index: test/Transforms/InstCombine/pr27703.ll
===================================================================
--- test/Transforms/InstCombine/pr27703.ll
+++ test/Transforms/InstCombine/pr27703.ll
@@ -0,0 +1,20 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+define void @mem() {
+bb:
+  br label %bb6
+
+bb6:
+  %.0 = phi i8** [ undef, %bb ], [ %t2, %bb6 ]
+  %tmp = load i8*, i8** %.0, align 8
+  %bc = bitcast i8* %tmp to i8**
+  %t1 = load i8*, i8** %bc, align 8
+  %t2 = bitcast i8* %t1 to i8**
+  br label %bb6
+
+bb206:
+  ret void
+; CHECK: phi
+; CHECK: bitcast
+; CHECK: load
+}
Index: test/Transforms/InstCombine/pr27996.ll
===================================================================
--- test/Transforms/InstCombine/pr27996.ll
+++ test/Transforms/InstCombine/pr27996.ll
@@ -0,0 +1,41 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+
+@i = constant i32 1, align 4
+@f = constant float 0x3FF19999A0000000, align 4
+@cmp = common global i32 0, align 4
+@resf = common global float* null, align 8
+@resi = common global i32* null, align 8
+
+define i32 @foo() {
+entry:
+  br label %while.cond
+
+while.cond:
+  %res.0 = phi i32* [ null, %entry ], [ @i, %if.then ], [ bitcast (float* @f to i32*), %if.else ]
+  %0 = load i32, i32* @cmp, align 4
+  %shr = ashr i32 %0, 1
+  store i32 %shr, i32* @cmp, align 4
+  %tobool = icmp ne i32 %shr, 0
+  br i1 %tobool, label %while.body, label %while.end
+
+while.body:
+  %and = and i32 %shr, 1
+  %tobool1 = icmp ne i32 %and, 0
+  br i1 %tobool1, label %if.then, label %if.else
+
+if.then:
+  br label %while.cond
+
+if.else:
+  br label %while.cond
+
+while.end:
+  %1 = bitcast i32* %res.0 to float*
+  store float* %1, float** @resf, align 8
+  store i32* %res.0, i32** @resi, align 8
+  ret i32 0
+
+; CHECK-NOT: bitcast i32
+}
+