diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -265,9 +265,11 @@ void replacePointer(Instruction &I, Value *V); private: + bool collectUsersRecursive(Instruction &I); void replace(Instruction *I); Value *getReplacement(Value *I); + SmallPtrSet ValuesToRevisit; SmallSetVector Worklist; MapVector WorkMap; InstCombinerImpl &IC; @@ -275,15 +277,47 @@ } // end anonymous namespace bool PointerReplacer::collectUsers(Instruction &I) { + if (!collectUsersRecursive(I)) + return false; + + // Every PHI deferred into ValuesToRevisit must have reached the + // Worklist by now. If one has not, some (indirect) user of I was not + // collected, so return false. + for (auto *Inst : ValuesToRevisit) + if (!Worklist.contains(Inst)) + return false; + return true; +} + +bool PointerReplacer::collectUsersRecursive(Instruction &I) { for (auto *U : I.users()) { auto *Inst = cast(&*U); if (auto *Load = dyn_cast(Inst)) { if (Load->isVolatile()) return false; Worklist.insert(Load); - } else if (isa(Inst) || isa(Inst)) { + } else if (auto *PHI = dyn_cast(Inst)) { + // All incoming values must be instructions for replaceability. + if (any_of(PHI->incoming_values(), + [](Value *V) { return !isa(V); })) + return false; + + // If at least one incoming value of the PHI is not in Worklist, + // store the PHI for revisiting and skip this iteration of the + // loop. 
+ if (any_of(PHI->incoming_values(), [this](Value *V) { + return !Worklist.contains(cast(V)); + })) { + ValuesToRevisit.insert(Inst); + continue; + } + + Worklist.insert(PHI); + if (!collectUsersRecursive(*PHI)) + return false; + } else if (isa(Inst)) { Worklist.insert(Inst); - if (!collectUsers(*Inst)) + if (!collectUsersRecursive(*Inst)) return false; } else if (auto *MI = dyn_cast(Inst)) { if (MI->isVolatile()) @@ -318,6 +352,14 @@ IC.InsertNewInstWith(NewI, *LT); IC.replaceInstUsesWith(*LT, NewI); WorkMap[LT] = NewI; + } else if (auto *PHI = dyn_cast(I)) { + Type *NewTy = getReplacement(PHI->getIncomingValue(0))->getType(); + auto *NewPHI = PHINode::Create(NewTy, PHI->getNumIncomingValues(), + PHI->getName(), PHI); + for (unsigned int I = 0; I < PHI->getNumIncomingValues(); ++I) + NewPHI->addIncoming(getReplacement(PHI->getIncomingValue(I)), + PHI->getIncomingBlock(I)); + WorkMap[PHI] = NewPHI; } else if (auto *GEP = dyn_cast(I)) { auto *V = getReplacement(GEP->getPointerOperand()); assert(V && "Operand not replaced"); diff --git a/llvm/test/Transforms/InstCombine/replace-alloca-phi.ll b/llvm/test/Transforms/InstCombine/replace-alloca-phi.ll --- a/llvm/test/Transforms/InstCombine/replace-alloca-phi.ll +++ b/llvm/test/Transforms/InstCombine/replace-alloca-phi.ll @@ -9,18 +9,14 @@ define i8 @remove_alloca_use_arg(i1 %cond) { ; CHECK-LABEL: @remove_alloca_use_arg( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [32 x i8], align 4, addrspace(1) -; CHECK-NEXT: call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noundef align 4 dereferenceable(256) [[ALLOCA]], ptr noundef nonnull align 16 dereferenceable(256) @g1, i64 256, i1 false) ; CHECK-NEXT: br i1 [[COND:%.*]], label [[IF:%.*]], label [[ELSE:%.*]] ; CHECK: if: -; CHECK-NEXT: [[VAL_IF:%.*]] = getelementptr inbounds [32 x i8], ptr addrspace(1) [[ALLOCA]], i64 0, i64 2 ; CHECK-NEXT: br label [[SINK:%.*]] ; CHECK: else: -; CHECK-NEXT: [[VAL_ELSE:%.*]] = getelementptr inbounds [32 x i8], ptr addrspace(1) 
[[ALLOCA]], i64 0, i64 1 ; CHECK-NEXT: br label [[SINK]] ; CHECK: sink: -; CHECK-NEXT: [[PTR:%.*]] = phi ptr addrspace(1) [ [[VAL_IF]], [[IF]] ], [ [[VAL_ELSE]], [[ELSE]] ] -; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr addrspace(1) [[PTR]], align 1 +; CHECK-NEXT: [[PTR1:%.*]] = phi ptr [ getelementptr inbounds ([32 x i8], ptr @g1, i64 0, i64 2), [[IF]] ], [ getelementptr inbounds ([32 x i8], ptr @g1, i64 0, i64 1), [[ELSE]] ] +; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr [[PTR1]], align 1 ; CHECK-NEXT: ret i8 [[LOAD]] ; entry: @@ -116,18 +112,14 @@ define i8 @loop_phi_remove_alloca(i1 %cond) { ; CHECK-LABEL: @loop_phi_remove_alloca( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [32 x i8], align 4, addrspace(1) -; CHECK-NEXT: call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noundef align 4 dereferenceable(256) [[ALLOCA]], ptr noundef nonnull align 16 dereferenceable(256) @g1, i64 256, i1 false) -; CHECK-NEXT: [[VAL1:%.*]] = getelementptr inbounds [32 x i8], ptr addrspace(1) [[ALLOCA]], i64 0, i64 1 ; CHECK-NEXT: br label [[BB_0:%.*]] ; CHECK: bb.0: -; CHECK-NEXT: [[PTR:%.*]] = phi ptr addrspace(1) [ [[VAL1]], [[ENTRY:%.*]] ], [ [[VAL2:%.*]], [[BB_1:%.*]] ] +; CHECK-NEXT: [[PTR1:%.*]] = phi ptr [ getelementptr inbounds ([32 x i8], ptr @g1, i64 0, i64 1), [[ENTRY:%.*]] ], [ getelementptr inbounds ([32 x i8], ptr @g1, i64 0, i64 2), [[BB_1:%.*]] ] ; CHECK-NEXT: br i1 [[COND:%.*]], label [[BB_1]], label [[EXIT:%.*]] ; CHECK: bb.1: -; CHECK-NEXT: [[VAL2]] = getelementptr inbounds [32 x i8], ptr addrspace(1) [[ALLOCA]], i64 0, i64 2 ; CHECK-NEXT: br label [[BB_0]] ; CHECK: exit: -; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr addrspace(1) [[PTR]], align 1 +; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr [[PTR1]], align 1 ; CHECK-NEXT: ret i8 [[LOAD]] ; entry: @@ -174,21 +166,17 @@ ret i32 %v } -define i8 @loop_phi_late_memtransfer(i1 %cond) { -; CHECK-LABEL: @loop_phi_late_memtransfer( +define i8 @loop_phi_late_memtransfer_remove_alloca(i1 %cond) { +; CHECK-LABEL: 
@loop_phi_late_memtransfer_remove_alloca( ; CHECK-NEXT: entry: -; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [32 x i8], align 4, addrspace(1) -; CHECK-NEXT: [[VAL1:%.*]] = getelementptr inbounds [32 x i8], ptr addrspace(1) [[ALLOCA]], i64 0, i64 1 ; CHECK-NEXT: br label [[BB_0:%.*]] ; CHECK: bb.0: -; CHECK-NEXT: [[PTR:%.*]] = phi ptr addrspace(1) [ [[VAL1]], [[ENTRY:%.*]] ], [ [[VAL2:%.*]], [[BB_1:%.*]] ] +; CHECK-NEXT: [[PTR1:%.*]] = phi ptr [ getelementptr inbounds ([32 x i8], ptr @g1, i64 0, i64 1), [[ENTRY:%.*]] ], [ getelementptr inbounds ([32 x i8], ptr @g1, i64 0, i64 2), [[BB_1:%.*]] ] ; CHECK-NEXT: br i1 [[COND:%.*]], label [[BB_1]], label [[EXIT:%.*]] ; CHECK: bb.1: -; CHECK-NEXT: [[VAL2]] = getelementptr inbounds [32 x i8], ptr addrspace(1) [[ALLOCA]], i64 0, i64 2 -; CHECK-NEXT: call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noundef align 4 dereferenceable(256) [[ALLOCA]], ptr noundef nonnull align 16 dereferenceable(256) @g1, i64 256, i1 false) ; CHECK-NEXT: br label [[BB_0]] ; CHECK: exit: -; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr addrspace(1) [[PTR]], align 1 +; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr [[PTR1]], align 1 ; CHECK-NEXT: ret i8 [[LOAD]] ; entry: