diff --git a/llvm/include/llvm/Transforms/Utils/PredicateInfo.h b/llvm/include/llvm/Transforms/Utils/PredicateInfo.h --- a/llvm/include/llvm/Transforms/Utils/PredicateInfo.h +++ b/llvm/include/llvm/Transforms/Utils/PredicateInfo.h @@ -70,6 +70,13 @@ enum PredicateType { PT_Branch, PT_Assume, PT_Switch }; +/// Constraint for a predicate of the form "cmp Pred Op, OtherOp", where Op +/// is the value the constraint applies to (the ssa.copy result). +struct PredicateConstraint { + CmpInst::Predicate Predicate; + Value *OtherOp; +}; + // Base class for all predicate information we provide. // All of our predicate information has at least a comparison. class PredicateBase : public ilist_node<PredicateBase> { @@ -95,6 +102,9 @@ PB->Type == PT_Switch; } + /// Fetch condition in the form of PredicateConstraint, if possible. + Optional<PredicateConstraint> getConstraint() const; + protected: PredicateBase(PredicateType PT, Value *Op, Value *Condition) : Type(PT), OriginalOp(Op), Condition(Condition) {} diff --git a/llvm/lib/Transforms/Scalar/NewGVN.cpp b/llvm/lib/Transforms/Scalar/NewGVN.cpp --- a/llvm/lib/Transforms/Scalar/NewGVN.cpp +++ b/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -1539,86 +1539,39 @@ LLVM_DEBUG(dbgs() << "Found predicate info from instruction !\n"); - auto *CopyOf = I->getOperand(0); - auto *Cond = PI->Condition; - - // If this a copy of the condition, it must be either true or false depending - // on the predicate info type and edge. - if (CopyOf == Cond) { - // We should not need to add predicate users because the predicate info is - // already a use of this operand. 
- if (isa<PredicateAssume>(PI)) - return createConstantExpression(ConstantInt::getTrue(Cond->getType())); - if (auto *PBranch = dyn_cast<PredicateBranch>(PI)) { - if (PBranch->TrueEdge) - return createConstantExpression(ConstantInt::getTrue(Cond->getType())); - return createConstantExpression(ConstantInt::getFalse(Cond->getType())); - } - if (auto *PSwitch = dyn_cast<PredicateSwitch>(PI)) - return createConstantExpression(cast<Constant>(PSwitch->CaseValue)); - } - - // Not a copy of the condition, so see what the predicates tell us about this - // value. First, though, we check to make sure the value is actually a copy - // of one of the condition operands. It's possible, in certain cases, for it - // to be a copy of a predicateinfo copy. In particular, if two branch - // operations use the same condition, and one branch dominates the other, we - // will end up with a copy of a copy. This is currently a small deficiency in - // predicateinfo. What will end up happening here is that we will value - // number both copies the same anyway. - - // Everything below relies on the condition being a comparison. - auto *Cmp = dyn_cast<CmpInst>(Cond); - if (!Cmp) + const Optional<PredicateConstraint> &Constraint = PI->getConstraint(); + if (!Constraint) return nullptr; - if (CopyOf != Cmp->getOperand(0) && CopyOf != Cmp->getOperand(1)) { - LLVM_DEBUG(dbgs() << "Copy is not of any condition operands!\n"); - return nullptr; - } - Value *FirstOp = lookupOperandLeader(Cmp->getOperand(0)); - Value *SecondOp = lookupOperandLeader(Cmp->getOperand(1)); - bool SwappedOps = false; + CmpInst::Predicate Predicate = Constraint->Predicate; + Value *CmpOp0 = I->getOperand(0); + Value *CmpOp1 = Constraint->OtherOp; + + Value *FirstOp = lookupOperandLeader(CmpOp0); + Value *SecondOp = lookupOperandLeader(CmpOp1); + Value *AdditionallyUsedValue = CmpOp0; + // Sort the ops. 
if (shouldSwapOperands(FirstOp, SecondOp)) { std::swap(FirstOp, SecondOp); - SwappedOps = true; + Predicate = CmpInst::getSwappedPredicate(Predicate); + AdditionallyUsedValue = CmpOp1; } - CmpInst::Predicate Predicate = - SwappedOps ? Cmp->getSwappedPredicate() : Cmp->getPredicate(); - - if (isa<PredicateAssume>(PI)) { - // If we assume the operands are equal, then they are equal. - if (Predicate == CmpInst::ICMP_EQ) { - addPredicateUsers(PI, I); - addAdditionalUsers(SwappedOps ? Cmp->getOperand(1) : Cmp->getOperand(0), - I); - return createVariableOrConstant(FirstOp); - } + + if (Predicate == CmpInst::ICMP_EQ) { + addPredicateUsers(PI, I); + addAdditionalUsers(AdditionallyUsedValue, I); + return createVariableOrConstant(FirstOp); } - if (const auto *PBranch = dyn_cast<PredicateBranch>(PI)) { - // If we are *not* a copy of the comparison, we may equal to the other - // operand when the predicate implies something about equality of - // operations. In particular, if the comparison is true/false when the - // operands are equal, and we are on the right edge, we know this operation - // is equal to something. - if ((PBranch->TrueEdge && Predicate == CmpInst::ICMP_EQ) || - (!PBranch->TrueEdge && Predicate == CmpInst::ICMP_NE)) { - addPredicateUsers(PI, I); - addAdditionalUsers(SwappedOps ? Cmp->getOperand(1) : Cmp->getOperand(0), - I); - return createVariableOrConstant(FirstOp); - } - // Handle the special case of floating point. - if (((PBranch->TrueEdge && Predicate == CmpInst::FCMP_OEQ) || - (!PBranch->TrueEdge && Predicate == CmpInst::FCMP_UNE)) && - isa<ConstantFP>(FirstOp) && !cast<ConstantFP>(FirstOp)->isZero()) { - addPredicateUsers(PI, I); - addAdditionalUsers(SwappedOps ? Cmp->getOperand(1) : Cmp->getOperand(0), - I); - return createConstantExpression(cast<Constant>(FirstOp)); - } + + // Handle the special case of floating point. 
+ if (Predicate == CmpInst::FCMP_OEQ && isa<ConstantFP>(FirstOp) && + !cast<ConstantFP>(FirstOp)->isZero()) { + addPredicateUsers(PI, I); + addAdditionalUsers(AdditionallyUsedValue, I); + return createConstantExpression(cast<Constant>(FirstOp)); } + return nullptr; } diff --git a/llvm/lib/Transforms/Scalar/SCCP.cpp b/llvm/lib/Transforms/Scalar/SCCP.cpp --- a/llvm/lib/Transforms/Scalar/SCCP.cpp +++ b/llvm/lib/Transforms/Scalar/SCCP.cpp @@ -1262,55 +1262,22 @@ auto *PI = getPredicateInfoFor(&CB); assert(PI && "Missing predicate info for ssa.copy"); - CmpInst *Cmp; - bool TrueEdge; - if (auto *PBranch = dyn_cast<PredicateBranch>(PI)) { - Cmp = dyn_cast<CmpInst>(PBranch->Condition); - TrueEdge = PBranch->TrueEdge; - } else if (auto *PAssume = dyn_cast<PredicateAssume>(PI)) { - Cmp = dyn_cast<CmpInst>(PAssume->Condition); - TrueEdge = true; - } else { + const Optional<PredicateConstraint> &Constraint = PI->getConstraint(); + if (!Constraint) { mergeInValue(ValueState[&CB], &CB, CopyOfVal); return; } - // Everything below relies on the condition being a comparison. - if (!Cmp) { - mergeInValue(ValueState[&CB], &CB, CopyOfVal); - return; - } + CmpInst::Predicate Pred = Constraint->Predicate; + Value *OtherOp = Constraint->OtherOp; - Value *RenamedOp = PI->RenamedOp; - Value *CmpOp0 = Cmp->getOperand(0); - Value *CmpOp1 = Cmp->getOperand(1); - // Bail out if neither of the operands matches RenamedOp. - if (CmpOp0 != RenamedOp && CmpOp1 != RenamedOp) { - mergeInValue(ValueState[&CB], &CB, getValueState(CopyOf)); + // Wait until OtherOp is resolved. + if (getValueState(OtherOp).isUnknown()) { + addAdditionalUser(OtherOp, &CB); return; } - auto Pred = Cmp->getPredicate(); - if (CmpOp1 == RenamedOp) { - std::swap(CmpOp0, CmpOp1); - Pred = Cmp->getSwappedPredicate(); - } - - // Wait until CmpOp1 is resolved. 
- if (getValueState(CmpOp1).isUnknown()) { - addAdditionalUser(CmpOp1, &CB); - return; - } - - // The code below relies on PredicateInfo only inserting copies for the - // true branch when the branch condition is an AND and only inserting - // copies for the false branch when the branch condition is an OR. This - // ensures we can intersect the range from the condition with the range of - // CopyOf. - if (!TrueEdge) - Pred = CmpInst::getInversePredicate(Pred); - - ValueLatticeElement CondVal = getValueState(CmpOp1); + ValueLatticeElement CondVal = getValueState(OtherOp); ValueLatticeElement &IV = ValueState[&CB]; if (CondVal.isConstantRange() || CopyOfVal.isConstantRange()) { auto ImposedCR = @@ -1334,7 +1301,7 @@ if (!CopyOfCR.contains(NewCR) && CopyOfCR.getSingleMissingElement()) NewCR = CopyOfCR; - addAdditionalUser(CmpOp1, &CB); + addAdditionalUser(OtherOp, &CB); // TODO: Actually filp MayIncludeUndef for the created range to false, // once most places in the optimizer respect the branches on // undef/poison are UB rule. The reason why the new range cannot be @@ -1351,7 +1318,7 @@ } else if (Pred == CmpInst::ICMP_EQ && CondVal.isConstant()) { // For non-integer values or integer constant expressions, only // propagate equal constants. - addAdditionalUser(CmpOp1, &CB); + addAdditionalUser(OtherOp, &CB); mergeInValue(IV, &CB, CondVal); return; } diff --git a/llvm/lib/Transforms/Utils/PredicateInfo.cpp b/llvm/lib/Transforms/Utils/PredicateInfo.cpp --- a/llvm/lib/Transforms/Utils/PredicateInfo.cpp +++ b/llvm/lib/Transforms/Utils/PredicateInfo.cpp @@ -822,6 +822,57 @@ } } +Optional<PredicateConstraint> PredicateBase::getConstraint() const { + switch (Type) { + case PT_Assume: + case PT_Branch: { + bool TrueEdge = true; + if (auto *PBranch = dyn_cast<PredicateBranch>(this)) + TrueEdge = PBranch->TrueEdge; + + if (Condition == RenamedOp) { + return {{CmpInst::ICMP_EQ, + TrueEdge ? 
ConstantInt::getTrue(Condition->getType()) + : ConstantInt::getFalse(Condition->getType())}}; + } + + CmpInst *Cmp = dyn_cast<CmpInst>(Condition); + if (!Cmp) { + // Bail out instead of asserting: a null Cmp would be dereferenced below + // in builds with assertions disabled. + return None; + } + + CmpInst::Predicate Pred; + Value *OtherOp; + if (Cmp->getOperand(0) == RenamedOp) { + Pred = Cmp->getPredicate(); + OtherOp = Cmp->getOperand(1); + } else if (Cmp->getOperand(1) == RenamedOp) { + Pred = Cmp->getSwappedPredicate(); + OtherOp = Cmp->getOperand(0); + } else { + // TODO: Make this an assertion once RenamedOp is fully accurate. + return None; + } + + // Invert predicate along false edge. + if (!TrueEdge) + Pred = CmpInst::getInversePredicate(Pred); + + return {{Pred, OtherOp}}; + } + case PT_Switch: + if (Condition != RenamedOp) { + // TODO: Make this an assertion once RenamedOp is fully accurate. + return None; + } + + return {{CmpInst::ICMP_EQ, cast<PredicateSwitch>(this)->CaseValue}}; + } + llvm_unreachable("Unknown predicate type"); +} + void PredicateInfo::verifyPredicateInfo() const {} char PredicateInfoPrinterLegacyPass::ID = 0; diff --git a/llvm/test/Transforms/SCCP/predicateinfo-cond.ll b/llvm/test/Transforms/SCCP/predicateinfo-cond.ll --- a/llvm/test/Transforms/SCCP/predicateinfo-cond.ll +++ b/llvm/test/Transforms/SCCP/predicateinfo-cond.ll @@ -11,16 +11,13 @@ ; CHECK-NEXT: i32 2, label [[CASE_2:%.*]] ; CHECK-NEXT: ] ; CHECK: case.0: -; CHECK-NEXT: [[ADD:%.*]] = add i32 [[X]], 1 ; CHECK-NEXT: br label [[END:%.*]] ; CHECK: case.2: -; CHECK-NEXT: [[SUB:%.*]] = sub i32 [[X]], 1 ; CHECK-NEXT: br label [[END]] ; CHECK: case.default: ; CHECK-NEXT: br label [[END]] ; CHECK: end: -; CHECK-NEXT: [[PHI:%.*]] = phi i32 [ [[ADD]], [[CASE_0]] ], [ [[SUB]], [[CASE_2]] ], [ 1, [[CASE_DEFAULT]] ] -; CHECK-NEXT: ret i32 [[PHI]] +; CHECK-NEXT: ret i32 1 ; switch i32 %x, label %case.default [ i32 0, label %case.0 @@ -47,7 +44,7 @@ ; CHECK-LABEL: @assume( ; CHECK-NEXT: [[CMP:%.*]] = icmp sge i32 [[X:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[CMP]]) -; CHECK-NEXT: ret i1 [[CMP]] +; 
CHECK-NEXT: ret i1 true ; %cmp = icmp sge i32 %x, 0 call void @llvm.assume(i1 %cmp) @@ -59,23 +56,17 @@ ; CHECK-NEXT: [[CMP:%.*]] = icmp sge i32 [[X:%.*]], 0 ; CHECK-NEXT: br i1 [[CMP]], label [[IF_THEN1:%.*]], label [[IF_THEN2:%.*]] ; CHECK: if.then1: -; CHECK-NEXT: br i1 [[CMP]], label [[IF2_THEN1:%.*]], label [[IF2_THEN2:%.*]] +; CHECK-NEXT: br label [[IF2_THEN1:%.*]] ; CHECK: if2.then1: ; CHECK-NEXT: br label [[IF2_END:%.*]] -; CHECK: if2.then2: -; CHECK-NEXT: br label [[IF2_END]] ; CHECK: if2.end: -; CHECK-NEXT: [[PHI:%.*]] = phi i32 [ 0, [[IF2_THEN1]] ], [ 1, [[IF2_THEN2]] ] -; CHECK-NEXT: ret i32 [[PHI]] +; CHECK-NEXT: ret i32 0 ; CHECK: if.then2: -; CHECK-NEXT: br i1 [[CMP]], label [[IF3_THEN1:%.*]], label [[IF3_THEN2:%.*]] -; CHECK: if3.then1: -; CHECK-NEXT: br label [[IF3_END:%.*]] +; CHECK-NEXT: br label [[IF3_THEN2:%.*]] ; CHECK: if3.then2: -; CHECK-NEXT: br label [[IF3_END]] +; CHECK-NEXT: br label [[IF3_END:%.*]] ; CHECK: if3.end: -; CHECK-NEXT: [[PHI2:%.*]] = phi i32 [ 0, [[IF3_THEN1]] ], [ 1, [[IF3_THEN2]] ] -; CHECK-NEXT: ret i32 [[PHI2]] +; CHECK-NEXT: ret i32 1 ; %cmp = icmp sge i32 %x, 0 br i1 %cmp, label %if.then1, label %if.then2