Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2279,12 +2279,43 @@
 /// across all DemandedElts.
 bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
                                 APInt &UndefElts) {
-  if (!DemandedElts)
-    return false; // No demanded elts, better to assume we don't know anything.
-
   EVT VT = V.getValueType();
   assert(VT.isVector() && "Vector type expected");
 
+  if (!VT.isScalableVector() && !DemandedElts)
+    return false; // No demanded elts, better to assume we don't know anything.
+
+  // Deal with some common cases here that work for both fixed and scalable
+  // vector types.
+  switch (V.getOpcode()) {
+  case ISD::SPLAT_VECTOR:
+    // Always initialise UndefElts to the width of DemandedElts: callers and
+    // the ADD/SUB/AND case below combine it with other masks, which asserts
+    // on mismatched widths for fixed-length SPLAT_VECTORs.
+    UndefElts = V.getOperand(0).isUndef()
+                    ? APInt::getAllOnesValue(DemandedElts.getBitWidth())
+                    : APInt(DemandedElts.getBitWidth(), 0);
+    return true;
+  case ISD::ADD:
+  case ISD::SUB:
+  case ISD::AND: {
+    APInt UndefLHS, UndefRHS;
+    SDValue LHS = V.getOperand(0);
+    SDValue RHS = V.getOperand(1);
+    if (isSplatValue(LHS, DemandedElts, UndefLHS) &&
+        isSplatValue(RHS, DemandedElts, UndefRHS)) {
+      UndefElts = UndefLHS | UndefRHS;
+      return true;
+    }
+    break;
+  }
+  }
+
+  // We don't support other cases than those above for scalable vectors at
+  // the moment.
+  if (VT.isScalableVector())
+    return false;
+
   unsigned NumElts = VT.getVectorNumElements();
   assert(NumElts == DemandedElts.getBitWidth() && "Vector size mismatch");
   UndefElts = APInt::getNullValue(NumElts);
@@ -2341,19 +2372,6 @@
     }
     break;
   }
-  case ISD::ADD:
-  case ISD::SUB:
-  case ISD::AND: {
-    APInt UndefLHS, UndefRHS;
-    SDValue LHS = V.getOperand(0);
-    SDValue RHS = V.getOperand(1);
-    if (isSplatValue(LHS, DemandedElts, UndefLHS) &&
-        isSplatValue(RHS, DemandedElts, UndefRHS)) {
-      UndefElts = UndefLHS | UndefRHS;
-      return true;
-    }
-    break;
-  }
   }
 
   return false;
@@ -2363,10 +2381,13 @@
 bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) {
   EVT VT = V.getValueType();
   assert(VT.isVector() && "Vector type expected");
-  unsigned NumElts = VT.getVectorNumElements();
 
   APInt UndefElts;
-  APInt DemandedElts = APInt::getAllOnesValue(NumElts);
+  APInt DemandedElts;
+
+  // For now we don't support this with scalable vectors.
+  if (!VT.isScalableVector())
+    DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements());
   return isSplatValue(V, DemandedElts, UndefElts) &&
          (AllowUndefs || !UndefElts);
 }
@@ -2379,19 +2400,35 @@
   switch (Opcode) {
   default: {
     APInt UndefElts;
-    APInt DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements());
+    APInt DemandedElts;
+
+    if (!VT.isScalableVector())
+      DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements());
+
     if (isSplatValue(V, DemandedElts, UndefElts)) {
-      // Handle case where all demanded elements are UNDEF.
-      if (DemandedElts.isSubsetOf(UndefElts)) {
+      if (VT.isScalableVector()) {
+        // DemandedElts and UndefElts are ignored for scalable vectors: only
+        // SPLAT_VECTORs (and ADD/SUB/AND of them) are recognised as splats.
         SplatIdx = 0;
-        return getUNDEF(VT);
+      } else {
+        // Handle case where all demanded elements are UNDEF.
+        if (DemandedElts.isSubsetOf(UndefElts)) {
+          SplatIdx = 0;
+          return getUNDEF(VT);
+        }
+        SplatIdx = (UndefElts & DemandedElts).countTrailingOnes();
       }
-      SplatIdx = (UndefElts & DemandedElts).countTrailingOnes();
       return V;
     }
     break;
   }
+  case ISD::SPLAT_VECTOR:
+    SplatIdx = 0;
+    return V;
   case ISD::VECTOR_SHUFFLE: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     // Check if this is a shuffle node doing a splat.
     // TODO - remove this and rely purely on SelectionDAG::isSplatValue,
     // getTargetVShiftNode currently struggles without the splat source.
Index: llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
===================================================================
--- llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -199,4 +199,100 @@
   EXPECT_EQ(Known.One, APInt(8, 0x1));
 }
 
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Fixed_BUILD_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, false);
+  auto Op = DAG->getConstant(1, Loc, VecVT);
+  EXPECT_EQ(Op->getOpcode(), ISD::BUILD_VECTOR);
+  EXPECT_EQ(DAG->isSplatValue(Op, false), true);
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_EQ(DAG->isSplatValue(Op, DemandedElts, UndefElts), false);
+
+  DemandedElts = APInt(16, 3);
+  EXPECT_EQ(DAG->isSplatValue(Op, DemandedElts, UndefElts), true);
+}
+
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Fixed_ADD_of_BUILD_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, false);
+
+  // Should create BUILD_VECTORs
+  auto Val1 = DAG->getConstant(1, Loc, VecVT);
+  auto Val2 = DAG->getConstant(3, Loc, VecVT);
+  EXPECT_EQ(Val1->getOpcode(), ISD::BUILD_VECTOR);
+  auto Op = DAG->getNode(ISD::ADD, Loc, VecVT, Val1, Val2);
+
+  EXPECT_EQ(DAG->isSplatValue(Op, false), true);
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_EQ(DAG->isSplatValue(Op, DemandedElts, UndefElts), false);
+
+  DemandedElts = APInt(16, 3);
+  EXPECT_EQ(DAG->isSplatValue(Op, DemandedElts, UndefElts), true);
+}
+
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Scalable_SPLAT_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, true);
+  auto Op = DAG->getConstant(1, Loc, VecVT);
+  EXPECT_EQ(Op->getOpcode(), ISD::SPLAT_VECTOR);
+  EXPECT_EQ(DAG->isSplatValue(Op, false), true);
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_EQ(DAG->isSplatValue(Op, DemandedElts, UndefElts), true);
+
+  // These bits should be ignored.
+  DemandedElts = APInt(16, 3);
+  EXPECT_EQ(DAG->isSplatValue(Op, DemandedElts, UndefElts), true);
+}
+
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Scalable_ADD_of_SPLAT_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, true);
+
+  // Should create SPLAT_VECTORS
+  auto Val1 = DAG->getConstant(1, Loc, VecVT);
+  auto Val2 = DAG->getConstant(3, Loc, VecVT);
+  EXPECT_EQ(Val1->getOpcode(), ISD::SPLAT_VECTOR);
+  auto Op = DAG->getNode(ISD::ADD, Loc, VecVT, Val1, Val2);
+
+  EXPECT_EQ(DAG->isSplatValue(Op, false), true);
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_EQ(DAG->isSplatValue(Op, DemandedElts, UndefElts), true);
+
+  // These bits should be ignored.
+  DemandedElts = APInt(16, 3);
+  EXPECT_EQ(DAG->isSplatValue(Op, DemandedElts, UndefElts), true);
+}
+
 } // end anonymous namespace