diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -391,6 +391,53 @@
                           IndexC->getValue().zextOrTrunc(64));
 }
 
+/// Fold an extractelement with a constant, in-range index out of a select
+/// whose true and false operands are both constant vectors:
+///
+///   extractelement (select i1 %c, <C1>, <C2>), Idx
+///     --> select i1 %c, C1[Idx], C2[Idx]
+///
+/// Returns the new, not yet inserted select, or nullptr if the fold does not
+/// apply. The extract's own use count is irrelevant: the caller replaces it.
+static SelectInst *foldExtractElementConstSelect(ExtractElementInst & /*EI*/,
+                                                 Value *Vec, ConstantInt *Idx) {
+  // Only fold if the vector select dies with the extract; otherwise we would
+  // add an instruction instead of replacing one.
+  auto *Select = dyn_cast<SelectInst>(Vec);
+  if (!Select || !Select->hasOneUse())
+    return nullptr;
+
+  // A vector condition chooses per lane; pairing it with scalar operands
+  // would form an invalid select, so only scalar conditions are handled.
+  Value *Cond = Select->getCondition();
+  if (Cond->getType()->isVectorTy())
+    return nullptr;
+
+  auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
+  auto *TrueVec = dyn_cast<Constant>(Select->getTrueValue());
+  auto *FalseVec = dyn_cast<Constant>(Select->getFalseValue());
+  if (!VecTy || !TrueVec || !FalseVec)
+    return nullptr;
+
+  // Compare as an APInt: getZExtValue() asserts on indices wider than 64
+  // bits. Out-of-range extracts yield poison and are left to other folds.
+  const APInt &IdxVal = Idx->getValue();
+  if (IdxVal.uge(VecTy->getNumElements()))
+    return nullptr;
+
+  // getAggregateElement() can return null (e.g. for constant expressions),
+  // so use dyn_cast_or_null. TODO: fold FP and poison lanes as well.
+  const unsigned Lane = IdxVal.getZExtValue();
+  auto *TrueElt =
+      dyn_cast_or_null<ConstantInt>(TrueVec->getAggregateElement(Lane));
+  auto *FalseElt =
+      dyn_cast_or_null<ConstantInt>(FalseVec->getAggregateElement(Lane));
+  if (!TrueElt || !FalseElt)
+    return nullptr;
+
+  return SelectInst::Create(Cond, TrueElt, FalseElt);
+}
+
 Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
   Value *SrcVec = EI.getVectorOperand();
   Value *Index = EI.getIndexOperand();
@@ -402,6 +449,11 @@
   // find a previously computed scalar that was inserted into the vector.
   auto *IndexC = dyn_cast<ConstantInt>(Index);
   if (IndexC) {
+    // extractelt (select i1 %c, <C1>, <C2>), Idx -->
+    //   select i1 %c, C1[Idx], C2[Idx]
+    if (auto *Select = foldExtractElementConstSelect(EI, SrcVec, IndexC))
+      return Select;
+
     // Canonicalize type of constant indices to i64 to simplify CSE
     if (auto *NewIdx = getPreferredVectorIndex(IndexC))
       return replaceOperand(EI, 1, NewIdx);
diff --git a/llvm/test/Transforms/InstCombine/extractelement.ll b/llvm/test/Transforms/InstCombine/extractelement.ll
--- a/llvm/test/Transforms/InstCombine/extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/extractelement.ll
@@ -785,8 +785,7 @@
 define i32 @extelt_select_const_operand_vector(i1 %c) {
 ; ANY-LABEL: @extelt_select_const_operand_vector(
-; ANY-NEXT:    [[S:%.*]] = select i1 [[C:%.*]], <3 x i32> <i32 2, i32 3, i32 4>, <3 x i32> <i32 5, i32 6, i32 7>
-; ANY-NEXT:    [[R:%.*]] = extractelement <3 x i32> [[S]], i64 2
+; ANY-NEXT:    [[R:%.*]] = select i1 [[C:%.*]], i32 4, i32 7
 ; ANY-NEXT:    ret i32 [[R]]
 ;
   %s = select i1 %c, <3 x i32> <i32 2, i32 3, i32 4>, <3 x i32> <i32 5, i32 6, i32 7>