Index: lib/Transforms/Scalar/GVN.cpp
===================================================================
--- lib/Transforms/Scalar/GVN.cpp
+++ lib/Transforms/Scalar/GVN.cpp
@@ -608,6 +608,8 @@
     DenseMap<uint32_t, LeaderTableEntry> LeaderTable;
     BumpPtrAllocator TableAllocator;
 
+    // Map used to replace values with constants in the same BB.
+    SmallMapVector<Value *, Constant *, 4> ReplaceWithConstMap;
     SmallVector<Instruction*, 8> InstrsToErase;
 
     typedef SmallVector<NonLocalDepResult, 64> LoadDepVect;
@@ -719,6 +721,7 @@
     void verifyRemoved(const Instruction *I) const;
     bool splitCriticalEdges();
     BasicBlock *splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ);
+    bool replaceOperandsWithConsts(Instruction *Instr) const;
     bool propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root);
     bool processFoldableCondBr(BranchInst *BI);
     void addDeadBlock(BasicBlock *BB);
@@ -2031,6 +2034,22 @@
   return Pred != nullptr;
 }
 
+/// Replace each operand of \p Instr that has an entry in ReplaceWithConstMap
+/// with the constant recorded for it. The map holds equalities learned from
+/// llvm.assume calls earlier in the current basic block. Returns true if any
+/// operand was replaced.
+bool GVN::replaceOperandsWithConsts(Instruction *Instr) const {
+  bool Changed = false;
+  for (unsigned OpNum = 0, E = Instr->getNumOperands(); OpNum != E; ++OpNum) {
+    auto It = ReplaceWithConstMap.find(Instr->getOperand(OpNum));
+    if (It == ReplaceWithConstMap.end())
+      continue;
+    Instr->setOperand(OpNum, It->second);
+    Changed = true;
+  }
+  return Changed;
+}
+
 /// The given values are known to be equal in every block
 /// dominated by 'Root'.  Exploit this, for example by replacing 'LHS' with
 /// 'RHS' everywhere in the scope.  Returns whether a change was made.
@@ -2047,11 +2066,13 @@
     std::pair<Value*, Value*> Item = Worklist.pop_back_val();
     LHS = Item.first; RHS = Item.second;
 
-    if (LHS == RHS) continue;
+    if (LHS == RHS)
+      continue;
     assert(LHS->getType() == RHS->getType() && "Equality but unequal types!");
 
     // Don't try to propagate equalities between constants.
-    if (isa<Constant>(LHS) && isa<Constant>(RHS)) continue;
+    if (isa<Constant>(LHS) && isa<Constant>(RHS))
+      continue;
 
     // Prefer a constant on the right-hand side, or an Argument if no constants.
     if (isa<Constant>(LHS) || (isa<Argument>(LHS) && !isa<Constant>(RHS)))
@@ -2202,6 +2223,43 @@
     return true;
   }
 
+  // llvm.assume(icmp eq X, C) tells us that X equals C from this point on, so
+  // treat it as a known equality between X and the constant C.
+  if (IntrinsicInst *IntrinsicI = dyn_cast<IntrinsicInst>(I)) {
+    if (IntrinsicI->getIntrinsicID() == Intrinsic::assume) {
+      Value *V = IntrinsicI->getArgOperand(0);
+      if (ICmpInst *ICmpI = dyn_cast<ICmpInst>(V)) {
+        if (ICmpI->getPredicate() == ICmpInst::ICMP_EQ) {
+          bool Changed = false;
+          Value *Lhs = ICmpI->getOperand(0);
+          Value *Rhs = ICmpI->getOperand(1);
+          Constant *LhsConst = dyn_cast<Constant>(Lhs);
+          Constant *RhsConst = dyn_cast<Constant>(Rhs);
+
+          // Propagate the equality only if exactly one operand is a constant.
+          if ((LhsConst != nullptr) != (RhsConst != nullptr)) {
+            BasicBlock *AssumeBB = I->getParent();
+            for (BasicBlock *Successor : successors(AssumeBB)) {
+              BasicBlockEdge Edge(AssumeBB, Successor);
+
+              // Equality propagation can't be done for every successor,
+              // but propagateEquality checks it.
+              Changed |= propagateEquality(Lhs, Rhs, Edge);
+            }
+
+            // Record the constant so that replaceOperandsWithConsts rewrites
+            // the instructions that follow in this block.
+            if (LhsConst)
+              ReplaceWithConstMap[Rhs] = LhsConst;
+            else
+              ReplaceWithConstMap[Lhs] = RhsConst;
+          }
+          return Changed;
+        }
+      }
+    }
+  }
+
   if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
     if (processLoad(LI))
       return true;
@@ -2266,7 +2324,8 @@
 
   // Instructions with void type don't return a value, so there's
   // no point in trying to find redundancies in them.
-  if (I->getType()->isVoidTy()) return false;
+  if (I->getType()->isVoidTy())
+    return false;
 
   uint32_t NextNum = VN.getNextUnusedValueNumber();
   unsigned Num = VN.lookup_or_add(I);
@@ -2373,10 +2432,14 @@
   if (DeadBlocks.count(BB))
     return false;
 
+  // The map is only valid within a single BB, so reset it for each block.
+  ReplaceWithConstMap.clear();
   bool ChangedFunction = false;
 
   for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
        BI != BE;) {
+    // Apply constants learned from earlier assumes before processing.
+    ChangedFunction |= replaceOperandsWithConsts(BI);
     ChangedFunction |= processInstruction(BI);
     if (InstrsToErase.empty()) {
       ++BI;
Index: test/Transforms/GVN/assume-ptr-equal-same-bb.ll
===================================================================
--- /dev/null
+++ test/Transforms/GVN/assume-ptr-equal-same-bb.ll
@@ -0,0 +1,31 @@
+; RUN: opt < %s -gvn -S | FileCheck %s
+
+; Checks if %2 will be replaced with @_ZN1A3fooEv, assuming
+; that %vtable == @_ZTV1A (with alignment).
+; CHECK: call void @_ZN1A3fooEv(
+
+%struct.A = type { i32 (...)** }
+
+@_ZTV1A = available_externally unnamed_addr constant [3 x i8*] [i8* null, i8* bitcast (i8** @_ZTI1A to i8*), i8* bitcast (void (%struct.A*)* @_ZN1A3fooEv to i8*)], align 8
+@_ZTI1A = external constant i8*
+
+define i32 @main() #0 {
+entry:
+  %call = tail call noalias i8* @_Znwm(i64 8)
+  %0 = bitcast i8* %call to %struct.A*
+  %1 = bitcast i8* %call to i8***
+
+  %vtable = load i8**, i8*** %1, align 8
+  %cmp.vtables = icmp eq i8** %vtable, getelementptr inbounds ([3 x i8*], [3 x i8*]* @_ZTV1A, i64 0, i64 2)
+  tail call void @llvm.assume(i1 %cmp.vtables)
+
+  %vtable1.cast = bitcast i8** %vtable to void (%struct.A*)**
+  %2 = load void (%struct.A*)*, void (%struct.A*)** %vtable1.cast, align 8
+  tail call void %2(%struct.A* %0)
+
+  ret i32 0
+}
+
+declare noalias i8* @_Znwm(i64)
+declare void @llvm.assume(i1)
+declare void @_ZN1A3fooEv(%struct.A*)
Index: test/Transforms/GVN/assume-ptr-equal-successors.ll
===================================================================
--- /dev/null
+++ test/Transforms/GVN/assume-ptr-equal-successors.ll
@@ -0,0 +1,47 @@
+; RUN: opt < %s -gvn -S | FileCheck %s
+
+; Checks if %2 and %3 will be replaced with direct calls to foo and bar by
+; assuming that %vtable == @_ZTV1A (with alignment).
+; CHECK: call i32 @_ZN1A3fooEv(
+; CHECK: call i32 @_ZN1A3barEv(
+
+
+%struct.A = type { i32 (...)** }
+
+@_ZTV1A = available_externally unnamed_addr constant [4 x i8*] [i8* null, i8* bitcast (i8** @_ZTI1A to i8*), i8* bitcast (i32 (%struct.A*)* @_ZN1A3fooEv to i8*), i8* bitcast (i32 (%struct.A*)* @_ZN1A3barEv to i8*)], align 8
+@_ZTI1A = external constant i8*
+
+define void @_Z1gb(i1 zeroext %p) {
+entry:
+  %call = tail call noalias i8* @_Znwm(i64 8) #4
+  %0 = bitcast i8* %call to %struct.A*
+  tail call void @_ZN1AC1Ev(%struct.A* %0) #1
+  %1 = bitcast i8* %call to i8***
+  %vtable = load i8**, i8*** %1, align 8
+  %cmp.vtables = icmp eq i8** %vtable, getelementptr inbounds ([4 x i8*], [4 x i8*]* @_ZTV1A, i64 0, i64 2)
+  tail call void @llvm.assume(i1 %cmp.vtables)
+  br i1 %p, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  %vtable1.cast = bitcast i8** %vtable to i32 (%struct.A*)**
+  %2 = load i32 (%struct.A*)*, i32 (%struct.A*)** %vtable1.cast, align 8
+  %call2 = tail call i32 %2(%struct.A* %0) #1
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %vfn47 = getelementptr inbounds i8*, i8** %vtable, i64 1
+  %vfn4 = bitcast i8** %vfn47 to i32 (%struct.A*)**
+  %3 = load i32 (%struct.A*)*, i32 (%struct.A*)** %vfn4, align 8
+  %call5 = tail call i32 %3(%struct.A* %0) #1
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  ret void
+}
+
+declare noalias i8* @_Znwm(i64)
+declare void @_ZN1AC1Ev(%struct.A*)
+declare void @llvm.assume(i1)
+declare i32 @_ZN1A3fooEv(%struct.A*)
+declare i32 @_ZN1A3barEv(%struct.A*)
+