Index: lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1217,6 +1217,78 @@
   return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType());
 }
 
+/// Try to eliminate select instructions that test the returned flag of cmpxchg
+/// instructions.
+///
+/// If a select instruction tests the returned flag of a cmpxchg instruction and
+/// selects between the returned value of the cmpxchg instruction and its
+/// compare operand, the result of the select always equals its false value.
+/// For example:
+///
+///   %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+///   %1 = extractvalue { i64, i1 } %0, 1
+///   %2 = extractvalue { i64, i1 } %0, 0
+///   %3 = select i1 %1, i64 %compare, i64 %2
+///   ret i64 %3
+///
+/// The returned value of the cmpxchg instruction (%2) is the original value
+/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2
+/// must have been equal to %compare. Thus, the result of the select is always
+/// equal to %2, and the code can be simplified to:
+///
+///   %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+///   %2 = extractvalue { i64, i1 } %0, 0
+///   ret i64 %2
+///
+static Instruction *foldSelectCmpXchg(SelectInst &SI) {
+  // A helper that determines if V is an extractvalue instruction whose
+  // aggregate operand is a cmpxchg instruction and whose single index is equal
+  // to I. If such conditions are true, the helper returns the cmpxchg
+  // instruction; otherwise, a nullptr is returned.
+  auto isExtractFromCmpXchg = [](Value *V, unsigned I) -> AtomicCmpXchgInst * {
+    auto *Extract = dyn_cast<ExtractValueInst>(V);
+    if (!Extract)
+      return nullptr;
+    if (Extract->getIndices()[0] != I)
+      return nullptr;
+    return dyn_cast<AtomicCmpXchgInst>(Extract->getAggregateOperand());
+  };
+
+  // If the select has a single user, and this user is a select instruction that
+  // we can simplify, skip the cmpxchg simplification for now.
+  if (SI.hasOneUse())
+    if (auto *Select = dyn_cast<SelectInst>(SI.user_back()))
+      if (Select->getCondition() == SI.getCondition())
+        if (Select->getFalseValue() == SI.getTrueValue() ||
+            Select->getTrueValue() == SI.getFalseValue())
+          return nullptr;
+
+  // Ensure the select condition is the returned flag of a cmpxchg instruction.
+  auto *CmpXchg = isExtractFromCmpXchg(SI.getCondition(), 1);
+  if (!CmpXchg)
+    return nullptr;
+
+  // Check the true value case: The true value of the select is the returned
+  // value of the same cmpxchg used by the condition, and the false value is the
+  // cmpxchg instruction's compare operand.
+  if (auto *X = isExtractFromCmpXchg(SI.getTrueValue(), 0))
+    if (X == CmpXchg && X->getCompareOperand() == SI.getFalseValue()) {
+      SI.setTrueValue(SI.getFalseValue());
+      return &SI;
+    }
+
+  // Check the false value case: The false value of the select is the returned
+  // value of the same cmpxchg used by the condition, and the true value is the
+  // cmpxchg instruction's compare operand.
+  if (auto *X = isExtractFromCmpXchg(SI.getFalseValue(), 0))
+    if (X == CmpXchg && X->getCompareOperand() == SI.getTrueValue()) {
+      SI.setTrueValue(SI.getFalseValue());
+      return &SI;
+    }
+
+  return nullptr;
+}
+
 Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -1624,5 +1696,9 @@
   if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, Builder))
     return BitCastSel;
 
+  // Simplify selects that test the returned flag of cmpxchg instructions.
+  if (Instruction *Select = foldSelectCmpXchg(SI))
+    return Select;
+
   return nullptr;
 }
Index: test/Transforms/InstCombine/select-cmpxchg.ll
===================================================================
--- /dev/null
+++ test/Transforms/InstCombine/select-cmpxchg.ll
@@ -0,0 +1,39 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+define i64 @cmpxchg_0(i64* %ptr, i64 %compare, i64 %new_value) {
+; CHECK-LABEL: @cmpxchg_0(
+; CHECK-NEXT:    %tmp0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+; CHECK-NEXT:    %tmp2 = extractvalue { i64, i1 } %tmp0, 0
+; CHECK-NEXT:    ret i64 %tmp2
+;
+  %tmp0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+  %tmp1 = extractvalue { i64, i1 } %tmp0, 1
+  %tmp2 = extractvalue { i64, i1 } %tmp0, 0
+  %tmp3 = select i1 %tmp1, i64 %compare, i64 %tmp2
+  ret i64 %tmp3
+}
+
+define i64 @cmpxchg_1(i64* %ptr, i64 %compare, i64 %new_value) {
+; CHECK-LABEL: @cmpxchg_1(
+; CHECK-NEXT:    %tmp0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+; CHECK-NEXT:    ret i64 %compare
+;
+  %tmp0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst
+  %tmp1 = extractvalue { i64, i1 } %tmp0, 1
+  %tmp2 = extractvalue { i64, i1 } %tmp0, 0
+  %tmp3 = select i1 %tmp1, i64 %tmp2, i64 %compare
+  ret i64 %tmp3
+}
+
+define i64 @cmpxchg_2(i64* %ptr, i64 %compare, i64 %new_value) {
+; CHECK-LABEL: @cmpxchg_2(
+; CHECK-NEXT:    %tmp0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value acq_rel monotonic
+; CHECK-NEXT:    ret i64 %compare
+;
+  %tmp0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value acq_rel monotonic
+  %tmp1 = extractvalue { i64, i1 } %tmp0, 1
+  %tmp2 = extractvalue { i64, i1 } %tmp0, 0
+  %tmp3 = select i1 %tmp1, i64 %compare, i64 %tmp2
+  %tmp4 = select i1 %tmp1, i64 %tmp3, i64 %compare
+  ret i64 %tmp4
+}