diff --git a/llvm/lib/Analysis/TypeMetadataUtils.cpp b/llvm/lib/Analysis/TypeMetadataUtils.cpp --- a/llvm/lib/Analysis/TypeMetadataUtils.cpp +++ b/llvm/lib/Analysis/TypeMetadataUtils.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" @@ -80,9 +81,23 @@ const Module *M = CI->getParent()->getParent()->getParent(); // Find llvm.assume intrinsics for this llvm.type.test call. + // Look through phis as multiple type tests may have been merged to feed into + // the same assume. We need to make sure to handle recursive phis. + SmallPtrSet Visited; + SmallVector Worklist; for (const Use &CIU : CI->uses()) - if (auto *Assume = dyn_cast(CIU.getUser())) + Worklist.push_back(CIU.getUser()); + while (!Worklist.empty()) { + Value *V = Worklist.pop_back_val(); + if (!Visited.insert(V).second) + continue; + if (auto *Assume = dyn_cast(V)) { Assumes.push_back(Assume); + } else if (auto *Phi = dyn_cast(V)) { + for (const Use &PIU : Phi->uses()) + Worklist.push_back(PIU.getUser()); + } + } // If we found any, search for virtual calls based on %p and add them to // DevirtCalls. diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp --- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -56,6 +56,7 @@ #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Triple.h" @@ -1896,6 +1897,8 @@ // points to a member of the type identifier %md. Group calls by (type ID, // offset) pair (effectively the identity of the virtual function) and store // to CallSlots. 
+ + SmallPtrSet DeletedAssumes; for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) { auto *CI = dyn_cast(U.getUser()); if (!CI) @@ -1918,8 +1921,10 @@ auto RemoveTypeTestAssumes = [&]() { // We no longer need the assumes or the type test. - for (auto Assume : Assumes) - Assume->eraseFromParent(); + for (auto *Assume : Assumes) { + if (DeletedAssumes.insert(Assume).second) + Assume->eraseFromParent(); + } // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we // may use the vtable argument later. if (CI->use_empty()) diff --git a/llvm/test/Transforms/WholeProgramDevirt/devirt-single-impl-assume-phi.ll b/llvm/test/Transforms/WholeProgramDevirt/devirt-single-impl-assume-phi.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/WholeProgramDevirt/devirt-single-impl-assume-phi.ll @@ -0,0 +1,83 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -S -passes=wholeprogramdevirt -whole-program-visibility %s 2>&1 | FileCheck %s + +; Check that the assume in assume(phi(typetest1, typetest2)) is removed by WPD. 
+define void @call(ptr %obj, i1 %i) {
+; CHECK-LABEL: @call(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br i1 [[I:%.*]], label [[BB1:%.*]], label [[BB2:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    [[VTABLE:%.*]] = load ptr, ptr [[OBJ:%.*]], align 8
+; CHECK-NEXT:    [[FPTR:%.*]] = load ptr, ptr [[VTABLE]], align 8
+; CHECK-NEXT:    call void [[FPTR]](ptr [[OBJ]])
+; CHECK-NEXT:    [[P:%.*]] = call i1 @llvm.type.test(ptr [[VTABLE]], metadata !"typeid")
+; CHECK-NEXT:    br label [[FIN:%.*]]
+; CHECK:       bb2:
+; CHECK-NEXT:    [[VTABLE2:%.*]] = load ptr, ptr [[OBJ]], align 8
+; CHECK-NEXT:    [[FPTR2:%.*]] = load ptr, ptr [[VTABLE2]], align 8
+; CHECK-NEXT:    call void [[FPTR2]](ptr [[OBJ]])
+; CHECK-NEXT:    [[P2:%.*]] = call i1 @llvm.type.test(ptr [[VTABLE2]], metadata !"typeid")
+; CHECK-NEXT:    br label [[FIN]]
+; CHECK:       fin:
+; CHECK-NEXT:    [[PN:%.*]] = phi i1 [ [[P]], [[BB1]] ], [ [[P2]], [[BB2]] ]
+; CHECK-NEXT:    ret void
+;
+entry:
+  br i1 %i, label %bb1, label %bb2
+bb1:
+  %vtable = load ptr, ptr %obj
+  %fptr = load ptr, ptr %vtable
+  call void %fptr(ptr %obj)
+  %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid") ; only user is the phi %pn
+  br label %fin
+bb2:
+  %vtable2 = load ptr, ptr %obj
+  %fptr2 = load ptr, ptr %vtable2
+  call void %fptr2(ptr %obj)
+  %p2 = call i1 @llvm.type.test(ptr %vtable2, metadata !"typeid") ; only user is the phi %pn
+  br label %fin
+fin:
+  %pn = phi i1 [ %p, %bb1 ], [ %p2, %bb2 ] ; both type tests merge into a single assume operand
+  call void @llvm.assume(i1 %pn) ; reachable from the type tests only via %pn; WPD is expected to erase it
+  ret void
+}
+
+; Check that a phi cycle (%z <-> %pn) between the type test and the assume terminates and still removes the assume.
+define void @call2(ptr %obj, i1 %i) {
+; CHECK-LABEL: @call2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    br label [[BB1:%.*]]
+; CHECK:       bb1:
+; CHECK-NEXT:    [[Z:%.*]] = phi i1 [ false, [[ENTRY:%.*]] ], [ [[P:%.*]], [[BB1]] ], [ [[PN:%.*]], [[BB2:%.*]] ]
+; CHECK-NEXT:    [[VTABLE:%.*]] = load ptr, ptr [[OBJ:%.*]], align 8
+; CHECK-NEXT:    [[FPTR:%.*]] = load ptr, ptr [[VTABLE]], align 8
+; CHECK-NEXT:    call void [[FPTR]](ptr [[OBJ]])
+; CHECK-NEXT:    [[P]] = call i1 @llvm.type.test(ptr [[VTABLE]], metadata !"typeid")
+; CHECK-NEXT:    br i1 [[I:%.*]], label [[BB2]], label [[BB1]]
+; CHECK:       bb2:
+; CHECK-NEXT:    [[PN]] = phi i1 [ [[Z]], [[BB1]] ]
+; CHECK-NEXT:    br i1 [[I]], label [[FIN:%.*]], label [[BB1]]
+; CHECK:       fin:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br label %bb1
+bb1:
+  %z = phi i1 [ false, %entry ], [ %p, %bb1 ], [ %pn, %bb2 ] ; first phi of the %z <-> %pn cycle
+  %vtable = load ptr, ptr %obj
+  %fptr = load ptr, ptr %vtable
+  call void %fptr(ptr %obj)
+  %p = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid") ; only user is %z
+  br i1 %i, label %bb2, label %bb1
+bb2:
+  %pn = phi i1 [ %z, %bb1 ] ; closes the cycle (feeds %z) and is the assume operand
+  br i1 %i, label %fin, label %bb1
+fin:
+  call void @llvm.assume(i1 %pn) ; reached from %p only through the phi cycle; WPD is expected to erase it
+  ret void
+}
+
+declare i1 @llvm.type.test(ptr, metadata)
+declare void @llvm.assume(i1)
+
+!0 = !{i32 0, !"typeid"} ; NOTE(review): not attached to any global here, so no vtable is a member of this type id; presumably intentional so WPD drops the assumes - confirm