diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -52,20 +52,29 @@
           "original aggregate");
 
 /// Return true if the value is cheaper to scalarize than it is to leave as a
-/// vector operation. IsConstantExtractIndex indicates whether we are extracting
-/// one known element from a vector constant.
+/// vector operation. If the extract index \p EI is a constant integer then
+/// some operations may be cheap to scalarize.
 ///
 /// FIXME: It's possible to create more instructions than previously existed.
-static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) {
+static bool cheapToScalarize(Value *V, Value *EI) {
+  ConstantInt *CEI = dyn_cast<ConstantInt>(EI);
+
   // If we can pick a scalar constant value out of a vector, that is free.
   if (auto *C = dyn_cast<Constant>(V))
-    return IsConstantExtractIndex || C->getSplatValue();
+    return CEI || C->getSplatValue();
+
+  if (CEI && match(V, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
+    ElementCount EC = cast<VectorType>(V->getType())->getElementCount();
+    // Index needs to be lower than the minimum known size of the vector,
+    // because for a scalable vector the full size is only known at run time.
+    return CEI->getValue().ult(EC.getKnownMinValue());
+  }
 
   // An insertelement to the same constant index as our extract will simplify
   // to the scalar inserted element. An insertelement to a different constant
   // index is irrelevant to our extract.
   if (match(V, m_InsertElt(m_Value(), m_Value(), m_ConstantInt())))
-    return IsConstantExtractIndex;
+    return CEI;
 
   if (match(V, m_OneUse(m_Load(m_Value()))))
     return true;
@@ -75,14 +84,12 @@
 
   Value *V0, *V1;
   if (match(V, m_OneUse(m_BinOp(m_Value(V0), m_Value(V1)))))
-    if (cheapToScalarize(V0, IsConstantExtractIndex) ||
-        cheapToScalarize(V1, IsConstantExtractIndex))
+    if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI))
       return true;
 
   CmpInst::Predicate UnusedPred;
   if (match(V, m_OneUse(m_Cmp(UnusedPred, m_Value(V0), m_Value(V1)))))
-    if (cheapToScalarize(V0, IsConstantExtractIndex) ||
-        cheapToScalarize(V1, IsConstantExtractIndex))
+    if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI))
       return true;
 
   return false;
@@ -119,7 +126,8 @@
   // and that it is a binary operation which is cheap to scalarize.
   // otherwise return nullptr.
   if (!PHIUser->hasOneUse() || !(PHIUser->user_back() == PN) ||
-      !(isa<BinaryOperator>(PHIUser)) || !cheapToScalarize(PHIUser, true))
+      !(isa<BinaryOperator>(PHIUser)) ||
+      !cheapToScalarize(PHIUser, EI.getIndexOperand()))
     return nullptr;
 
   // Create a scalar PHI node that will replace the vector PHI node
@@ -415,7 +423,7 @@
     // TODO come up with a n-ary matcher that subsumes both unary and
     // binary matchers.
     UnaryOperator *UO;
-    if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, IndexC)) {
+    if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, Index)) {
       // extelt (unop X), Index --> unop (extelt X, Index)
       Value *X = UO->getOperand(0);
       Value *E = Builder.CreateExtractElement(X, Index);
@@ -423,7 +431,7 @@
     }
 
     BinaryOperator *BO;
-    if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, IndexC)) {
+    if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, Index)) {
       // extelt (binop X, Y), Index --> binop (extelt X, Index), (extelt Y, Index)
       Value *X = BO->getOperand(0), *Y = BO->getOperand(1);
       Value *E0 = Builder.CreateExtractElement(X, Index);
@@ -434,7 +442,7 @@
     Value *X, *Y;
     CmpInst::Predicate Pred;
     if (match(SrcVec, m_Cmp(Pred, m_Value(X), m_Value(Y))) &&
-        cheapToScalarize(SrcVec, IndexC)) {
+        cheapToScalarize(SrcVec, Index)) {
       // extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)
       Value *E0 = Builder.CreateExtractElement(X, Index);
       Value *E1 = Builder.CreateExtractElement(Y, Index);
diff --git a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
--- a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
@@ -243,6 +243,35 @@
   ret i8 %1
 }
 
+; Check that we can extract more complex cases where the stepvector is
+; involved in a binary operation prior to the lane being extracted.
+
+define i64 @ext_lane0_from_add_with_stepvec(i64 %i) {
+; CHECK-LABEL: @ext_lane0_from_add_with_stepvec(
+; CHECK-NEXT:    ret i64 [[I:%.*]]
+;
+  %tmp = insertelement <vscale x 2 x i64> poison, i64 %i, i32 0
+  %splatofi = shufflevector <vscale x 2 x i64> %tmp, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %stepvec = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+  %add = add <vscale x 2 x i64> %splatofi, %stepvec
+  %res = extractelement <vscale x 2 x i64> %add, i32 0
+  ret i64 %res
+}
+
+define i1 @ext_lane1_from_cmp_with_stepvec(i64 %i) {
+; CHECK-LABEL: @ext_lane1_from_cmp_with_stepvec(
+; CHECK-NEXT:    [[RES:%.*]] = icmp eq i64 [[I:%.*]], 1
+; CHECK-NEXT:    ret i1 [[RES]]
+;
+  %tmp = insertelement <vscale x 2 x i64> poison, i64 %i, i32 0
+  %splatofi = shufflevector <vscale x 2 x i64> %tmp, <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+  %stepvec = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+  %cmp = icmp eq <vscale x 2 x i64> %splatofi, %stepvec
+  %res = extractelement <vscale x 2 x i1> %cmp, i32 1
+  ret i1 %res
+}
+
+declare <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
 declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
 declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
 declare <vscale x 512 x i8> @llvm.experimental.stepvector.nxv512i8()
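
Note for reviewers: the UnOp/BinOp/Cmp hunks above only change the cost query; the rewrite itself is pre-existing. As a minimal standalone sketch of that rewrite for the binop case, using the same IRBuilder calls as the hunk (scalarizeBinOpExtract is a hypothetical name, not a helper added by this patch):

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Hypothetical standalone version of the existing
// "extelt (binop X, Y), Index --> binop (extelt X, Index), (extelt Y, Index)"
// rewrite: extract the lane from each operand, then redo the op on scalars.
static Value *scalarizeBinOpExtract(IRBuilder<> &Builder, BinaryOperator *BO,
                                    Value *Index) {
  Value *E0 = Builder.CreateExtractElement(BO->getOperand(0), Index);
  Value *E1 = Builder.CreateExtractElement(BO->getOperand(1), Index);
  return Builder.CreateBinOp(BO->getOpcode(), E0, E1);
}

The unop and cmp cases have the same shape, with CreateUnOp/CreateCmp in place of CreateBinOp.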
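The two new tests pin down the lane-wise semantics that make these folds valid: lane L of (splat(i) add stepvector) is i + L, and lane L of (splat(i) icmp eq stepvector) is (i == L). Here is a self-contained fixed-width model of that arithmetic in plain C++ (no LLVM API; Lanes stands in for one possible runtime length of a <vscale x 2 x i64> vector):

#include <cassert>
#include <cstdint>
#include <vector>

// Lane L of (splat(i) + stepvector) is i + L; lane L of
// (splat(i) == stepvector) is (i == L). So extracting a known lane
// never needs the whole vector.
static int64_t extractLaneOfAdd(int64_t I, unsigned Lane, unsigned Lanes) {
  assert(Lane < Lanes && "lane must be below the known minimum count");
  std::vector<int64_t> Splat(Lanes, I); // models %splatofi
  std::vector<int64_t> Step(Lanes);     // models %stepvec: 0, 1, 2, ...
  for (unsigned J = 0; J != Lanes; ++J)
    Step[J] = static_cast<int64_t>(J);
  return Splat[Lane] + Step[Lane]; // == I + Lane
}

static bool extractLaneOfCmpEq(int64_t I, unsigned Lane, unsigned Lanes) {
  assert(Lane < Lanes && "lane must be below the known minimum count");
  std::vector<int64_t> Splat(Lanes, I);
  std::vector<int64_t> Step(Lanes);
  for (unsigned J = 0; J != Lanes; ++J)
    Step[J] = static_cast<int64_t>(J);
  return Splat[Lane] == Step[Lane]; // == (I == Lane)
}

int main() {
  // Mirrors @ext_lane0_from_add_with_stepvec: lane 0 folds to %i itself.
  assert(extractLaneOfAdd(42, 0, 8) == 42);
  // Mirrors @ext_lane1_from_cmp_with_stepvec: lane 1 is (i == 1).
  assert(extractLaneOfCmpEq(1, 1, 8));
  assert(!extractLaneOfCmpEq(5, 1, 8));
  return 0;
}

Extracting lane 1 in the second test is safe for any vscale because 1 is below the known minimum element count (2) of a <vscale x 2 x i64> vector, which is exactly what the new ult(EC.getKnownMinValue()) guard in cheapToScalarize checks.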