diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -72,14 +72,14 @@
         continue;
       }
 
-      if (isa<PHINode>(I)) {
+      if (isa<PHINode, SelectInst>(I)) {
         // We set IsOffset=true, to forbid the memcpy from occurring after the
-        // phi: If one of the phi operands is not based on the alloca, we
-        // would incorrectly omit a write.
+        // phi/select: If one of the phi or select operands is not based on the
+        // alloca, we would incorrectly omit a write.
         Worklist.emplace_back(I, true);
         continue;
       }
-      if (isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I)) {
+      if (isa<BitCastInst, AddrSpaceCastInst>(I)) {
         // If uses of the bitcast are ok, we are ok.
         Worklist.emplace_back(I, IsOffset);
         continue;
@@ -315,6 +315,23 @@
       Worklist.insert(PHI);
       if (!collectUsersRecursive(*PHI))
         return false;
+    } else if (auto *SI = dyn_cast<SelectInst>(Inst)) {
+      // Bail out unless both select arms are instructions: the Worklist
+      // lookups below cast each arm to Instruction.
+      if (!isa<Instruction>(SI->getTrueValue()) ||
+          !isa<Instruction>(SI->getFalseValue()))
+        return false;
+
+      // If either arm is not in the Worklist yet, queue the select in
+      // ValuesToRevisit and skip it for now.
+      if (!Worklist.contains(cast<Instruction>(SI->getTrueValue())) ||
+          !Worklist.contains(cast<Instruction>(SI->getFalseValue()))) {
+        ValuesToRevisit.insert(Inst);
+        continue;
+      }
+      Worklist.insert(SI);
+      if (!collectUsersRecursive(*SI))
+        return false;
     } else if (isa<GetElementPtrInst, BitCastInst>(Inst)) {
       Worklist.insert(Inst);
       if (!collectUsersRecursive(*Inst))
@@ -380,6 +397,15 @@
     IC.InsertNewInstWith(NewI, *BC);
     NewI->takeName(BC);
     WorkMap[BC] = NewI;
+  } else if (auto *SI = dyn_cast<SelectInst>(I)) {
+    // Rebuild the select on the replacement of each arm, keeping the original
+    // condition, and record the new select as the replacement for SI.
+    auto *NewSI = SelectInst::Create(
+        SI->getCondition(), getReplacement(SI->getTrueValue()),
+        getReplacement(SI->getFalseValue()), SI->getName(), nullptr, SI);
+    IC.InsertNewInstWith(NewSI, *SI);
+    NewSI->takeName(SI);
+    WorkMap[SI] = NewSI;
   } else if (auto *MemCpy = dyn_cast<MemTransferInst>(I)) {
     auto *SrcV = getReplacement(MemCpy->getRawSource());
     // The pointer may appear in the destination of a copy, but we don't want to
diff --git a/llvm/test/Transforms/InstCombine/replace-alloca-phi.ll b/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
rename from llvm/test/Transforms/InstCombine/replace-alloca-phi.ll
rename to llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
--- a/llvm/test/Transforms/InstCombine/replace-alloca-phi.ll
+++ b/llvm/test/Transforms/InstCombine/ptr-replace-alloca.ll
@@ -337,6 +337,61 @@
   ret i32 %v
 }
 
+; The memcpy source @g1 and the alloca are in the same address space: the
+; alloca and the memcpy are expected to go away and the select to use @g1.
+define i8 @select_same_addrspace_remove_alloca(i1 %cond, ptr %p) {
+; CHECK-LABEL: @select_same_addrspace_remove_alloca(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[PTR:%.*]] = select i1 %cond, ptr [[G1:@.*]], ptr [[P:%.*]]
+; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr [[PTR]], align 1
+; CHECK-NEXT: ret i8 [[LOAD]]
+;
+entry:
+  %alloca = alloca [32 x i8]
+  call void @llvm.memcpy.p0.p0.i64(ptr %alloca, ptr @g1, i64 256, i1 false)
+  %ptr = select i1 %cond, ptr %alloca, ptr %p
+  %load = load i8, ptr %ptr
+  ret i8 %load
+}
+
+; The memcpy writes through the select result, i.e. after the select: the
+; alloca is expected to stay.
+define i8 @select_after_memcpy_keep_alloca(i1 %cond, ptr %p) {
+; CHECK-LABEL: define i8 @select_after_memcpy_keep_alloca(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[ALLOCA:%.*]] = alloca [32 x i8], align 1
+; CHECK-NEXT: [[PTR:%.*]] = select i1 %cond, ptr [[ALLOCA]], ptr [[P:%.*]]
+; CHECK-NEXT: call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 1 dereferenceable(256) [[PTR]], ptr noundef nonnull align 16 dereferenceable(256) [[G1:@.*]], i64 256, i1 false)
+; CHECK-NEXT: [[LOAD:%.*]] = load i8, ptr [[PTR]], align 1
+; CHECK-NEXT: ret i8 [[LOAD]]
+;
+entry:
+  %alloca = alloca [32 x i8]
+  %ptr = select i1 %cond, ptr %alloca, ptr %p
+  call void @llvm.memcpy.p0.p0.i64(ptr %ptr, ptr @g1, i64 256, i1 false)
+  %load = load i8, ptr %ptr
+  ret i8 %load
+}
+
+; The alloca is in addrspace(1) but the memcpy source @g1 is not: the alloca
+; is expected to stay, with the select left unchanged.
+define i8 @select_diff_addrspace_keep_alloca(i1 %cond, ptr addrspace(1) %p) {
+; CHECK-LABEL: @select_diff_addrspace_keep_alloca(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: %alloca = alloca [32 x i8], align 1, addrspace(1)
+; CHECK-NEXT: call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) noundef align 1 dereferenceable(256) %alloca, ptr noundef nonnull align 16 dereferenceable(256) @g1, i64 256, i1 false)
+; CHECK-NEXT: %ptr = select i1 %cond, ptr addrspace(1) %alloca, ptr addrspace(1) %p
+; CHECK-NEXT: %load = load i8, ptr addrspace(1) %ptr, align 1
+; CHECK-NEXT: ret i8 %load
+;
+entry:
+  %alloca = alloca [32 x i8], addrspace(1)
+  call void @llvm.memcpy.p1.p0.i64(ptr addrspace(1) %alloca, ptr @g1, i64 256, i1 false)
+  %ptr = select i1 %cond, ptr addrspace(1) %alloca, ptr addrspace(1) %p
+  %load = load i8, ptr addrspace(1) %ptr
+  ret i8 %load
+}
+
 declare void @llvm.memcpy.p1.p0.i64(ptr addrspace(1), ptr, i64, i1)
 declare void @llvm.memcpy.p0.p0.i64(ptr, ptr, i64, i1)
 declare void @llvm.memcpy.p0.p1.i64(ptr, ptr addrspace(1), i64, i1)