diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -623,6 +623,20 @@ /// BB3: phi [BO, BB1], [(binop C1, C2), BB2] Instruction *foldBinopWithPhiOperands(BinaryOperator &BO); + /// Given an extractelement instruction with a select as base which provides + /// two constant vectors, the extractelement instruction can be replaced by + /// a select instruction which directly selects the corresponding elements + /// from the vectors: + /// extractelt (select %cond, <vec1>, <vec2>), %c -> + /// select %cond, <vec1>[c], <vec2>[c] + /// The index %c must be a constant that is in range for the vectors; + /// otherwise nothing is folded and nullptr is returned. + /// TODO: This can be extended to support arbitrary vectors: + /// extractelt (select %cond, <vec1>, <vec2>), %c -> + /// select %cond, (extractelt <vec1>, %c), (extractelt <vec2>, %c) + Instruction *FoldExtractElementSelectConstVector(ExtractElementInst *EI, + SelectInst *SI); + /// Given an instruction with a select as one operand and a constant as the /// other operand, try to fold the binary operator into the select arguments. /// This also works for Cast instructions, which obviously do not have a diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -398,6 +398,10 @@ SQ.getWithInstruction(&EI))) return replaceInstUsesWith(EI, V); + if (SelectInst *SI = dyn_cast<SelectInst>(EI.getVectorOperand())) + if (Instruction *R = FoldOpIntoSelect(EI, SI)) + return R; + // If extracting a specified index from the vector, see if we can recursively // find a previously computed scalar that was inserted into the vector. 
auto *IndexC = dyn_cast(Index); @@ -587,6 +591,7 @@ } } } + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -1085,12 +1085,59 @@ return NewBO; } +/// Fold an extractelement whose vector operand is a select between two +/// constant vectors into a select between the two indexed elements. +/// Returns the newly created select, or nullptr if the select arms are not +/// constant vectors, the select condition is a vector, the index is not an +/// in-range constant, or an extracted element is not of the expected kind. +Instruction * +InstCombinerImpl::FoldExtractElementSelectConstVector(ExtractElementInst *EI, + SelectInst *SI) { + // A vector condition picks lanes individually and is only valid together + // with vector arms, so it cannot drive a select between two scalars. + if (SI->getCondition()->getType()->isVectorTy()) + return nullptr; + + auto *TrueVal = dyn_cast(SI->getTrueValue()); + auto *FalseVal = dyn_cast(SI->getFalseValue()); + if (!TrueVal || !FalseVal) + return nullptr; + + ConstantInt *Idx = dyn_cast<ConstantInt>(EI->getIndexOperand()); + if (!Idx) + return nullptr; + + // Range-check on the APInt itself: the index type may be wider than 64 + // bits, and getZExtValue() asserts on values that do not fit in 64 bits. + const uint64_t NumElements = TrueVal->getType()->getNumElements(); + if (Idx->getValue().uge(NumElements)) + return nullptr; + const uint64_t IndexVal = Idx->getZExtValue(); + + auto *TrueConst = + dyn_cast(TrueVal->getAggregateElement(IndexVal)); + auto *FalseConst = + dyn_cast(FalseVal->getAggregateElement(IndexVal)); + + if (TrueConst && FalseConst) + return SelectInst::Create(SI->getCondition(), TrueConst, FalseConst); + + return nullptr; +} + Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, bool FoldWithMultiUse) { // Don't modify shared select instructions unless set FoldWithMultiUse if (!SI->hasOneUse() && !FoldWithMultiUse) return nullptr; + // If Op is an ExtractElement instruction and the select + // instruction has constant vector operands, see if we + // can directly extract the constant values from it. 
+ if (auto *EI = dyn_cast<ExtractElementInst>(&Op)) + if (auto *Select = FoldExtractElementSelectConstVector(EI, SI)) + return Select; + Value *TV = SI->getTrueValue(); Value *FV = SI->getFalseValue(); if (!(isa(TV) || isa(FV))) diff --git a/llvm/test/Transforms/InstCombine/extractelement.ll b/llvm/test/Transforms/InstCombine/extractelement.ll --- a/llvm/test/Transforms/InstCombine/extractelement.ll +++ b/llvm/test/Transforms/InstCombine/extractelement.ll @@ -785,8 +785,7 @@ define i32 @extelt_select_const_operand_vector(i1 %c) { ; ANY-LABEL: @extelt_select_const_operand_vector( -; ANY-NEXT: [[S:%.*]] = select i1 [[C:%.*]], <3 x i32> , <3 x i32> -; ANY-NEXT: [[R:%.*]] = extractelement <3 x i32> [[S]], i64 2 +; ANY-NEXT: [[R:%.*]] = select i1 [[C:%.*]], i32 4, i32 7 ; ANY-NEXT: ret i32 [[R]] ; %s = select i1 %c, <3 x i32> , <3 x i32> @@ -796,11 +795,10 @@ define i32 @extelt_select_const_operand_extractelt_use(i1 %c) { ; ANY-LABEL: @extelt_select_const_operand_extractelt_use( -; ANY-NEXT: [[S:%.*]] = select i1 [[C:%.*]], <3 x i32> , <3 x i32> -; ANY-NEXT: [[E:%.*]] = extractelement <3 x i32> [[S]], i64 2 -; ANY-NEXT: [[M:%.*]] = shl i32 [[E]], 1 -; ANY-NEXT: [[M_2:%.*]] = shl i32 [[E]], 2 -; ANY-NEXT: [[R:%.*]] = mul i32 [[M]], [[M_2]] +; ANY-NEXT: [[E:%.*]] = select i1 [[C:%.*]], i32 4, i32 7 +; ANY-NEXT: [[M:%.*]] = shl nuw nsw i32 [[E]], 1 +; ANY-NEXT: [[M_2:%.*]] = shl nuw nsw i32 [[E]], 2 +; ANY-NEXT: [[R:%.*]] = mul nuw nsw i32 [[M]], [[M_2]] ; ANY-NEXT: ret i32 [[R]] ; %s = select i1 %c, <3 x i32> , <3 x i32>