diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2276,15 +2276,48 @@
 }
 
 /// isSplatValue - Return true if the vector V has the same value
-/// across all DemandedElts.
+/// across all DemandedElts. For scalable vectors it does not make
+/// sense to specify which elements are demanded or undefined, therefore
+/// they are simply ignored.
 bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
                                 APInt &UndefElts) {
-  if (!DemandedElts)
-    return false; // No demanded elts, better to assume we don't know anything.
-
   EVT VT = V.getValueType();
   assert(VT.isVector() && "Vector type expected");
 
+  if (!VT.isScalableVector() && !DemandedElts)
+    return false; // No demanded elts, better to assume we don't know anything.
+
+  // Deal with some common cases here that work for both fixed and scalable
+  // vector types.
+  switch (V.getOpcode()) {
+  case ISD::SPLAT_VECTOR:
+    // All lanes hold the same scalar, so they are either all undef or all
+    // defined. Size the mask like DemandedElts so it can be combined with
+    // the masks of other operands (see the ADD/SUB/AND case below).
+    UndefElts = V.getOperand(0).isUndef()
+                    ? APInt::getAllOnesValue(DemandedElts.getBitWidth())
+                    : APInt::getNullValue(DemandedElts.getBitWidth());
+    return true;
+  case ISD::ADD:
+  case ISD::SUB:
+  case ISD::AND: {
+    APInt UndefLHS, UndefRHS;
+    SDValue LHS = V.getOperand(0);
+    SDValue RHS = V.getOperand(1);
+    if (isSplatValue(LHS, DemandedElts, UndefLHS) &&
+        isSplatValue(RHS, DemandedElts, UndefRHS)) {
+      UndefElts = UndefLHS | UndefRHS;
+      return true;
+    }
+    break;
+  }
+  }
+
+  // We don't support other cases than those above for scalable vectors at
+  // the moment.
+  if (VT.isScalableVector())
+    return false;
+
   unsigned NumElts = VT.getVectorNumElements();
   assert(NumElts == DemandedElts.getBitWidth() && "Vector size mismatch");
   UndefElts = APInt::getNullValue(NumElts);
@@ -2341,19 +2374,6 @@
     }
     break;
   }
-  case ISD::ADD:
-  case ISD::SUB:
-  case ISD::AND: {
-    APInt UndefLHS, UndefRHS;
-    SDValue LHS = V.getOperand(0);
-    SDValue RHS = V.getOperand(1);
-    if (isSplatValue(LHS, DemandedElts, UndefLHS) &&
-        isSplatValue(RHS, DemandedElts, UndefRHS)) {
-      UndefElts = UndefLHS | UndefRHS;
-      return true;
-    }
-    break;
-  }
   }
 
   return false;
@@ -2363,10 +2383,13 @@
 bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) {
   EVT VT = V.getValueType();
   assert(VT.isVector() && "Vector type expected");
-  unsigned NumElts = VT.getVectorNumElements();
 
   APInt UndefElts;
-  APInt DemandedElts = APInt::getAllOnesValue(NumElts);
+  APInt DemandedElts;
+
+  // DemandedElts is ignored for scalable vectors, so leave it unset there.
+  if (!VT.isScalableVector())
+    DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements());
   return isSplatValue(V, DemandedElts, UndefElts) &&
          (AllowUndefs || !UndefElts);
 }
@@ -2379,19 +2402,35 @@
   switch (Opcode) {
   default: {
     APInt UndefElts;
-    APInt DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements());
+    APInt DemandedElts;
+
+    if (!VT.isScalableVector())
+      DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements());
+
     if (isSplatValue(V, DemandedElts, UndefElts)) {
-      // Handle case where all demanded elements are UNDEF.
-      if (DemandedElts.isSubsetOf(UndefElts)) {
+      if (VT.isScalableVector()) {
+        // DemandedElts and UndefElts carry no per-lane information for
+        // scalable vectors, so the splat value is always taken from lane 0.
         SplatIdx = 0;
-        return getUNDEF(VT);
+      } else {
+        // Handle case where all demanded elements are UNDEF.
+        if (DemandedElts.isSubsetOf(UndefElts)) {
+          SplatIdx = 0;
+          return getUNDEF(VT);
+        }
+        SplatIdx = (UndefElts & DemandedElts).countTrailingOnes();
       }
-      SplatIdx = (UndefElts & DemandedElts).countTrailingOnes();
       return V;
     }
     break;
   }
+  case ISD::SPLAT_VECTOR:
+    SplatIdx = 0;
+    return V;
   case ISD::VECTOR_SHUFFLE: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     // Check if this is a shuffle node doing a splat.
     // TODO - remove this and rely purely on SelectionDAG::isSplatValue,
     // getTargetVShiftNode currently struggles without the splat source.
diff --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
--- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -199,4 +199,209 @@
   EXPECT_EQ(Known.One, APInt(8, 0x1));
 }
 
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Fixed_BUILD_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, false);
+  // Create a BUILD_VECTOR
+  SDValue Op = DAG->getConstant(1, Loc, VecVT);
+  EXPECT_EQ(Op->getOpcode(), ISD::BUILD_VECTOR);
+  EXPECT_TRUE(DAG->isSplatValue(Op, /*AllowUndefs=*/false));
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_FALSE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+
+  // Width=16, Mask=3
+  DemandedElts = APInt(16, 3);
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+}
+
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Fixed_ADD_of_BUILD_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, false);
+
+  // Should create BUILD_VECTORs
+  SDValue Val1 = DAG->getConstant(1, Loc, VecVT);
+  SDValue Val2 = DAG->getConstant(3, Loc, VecVT);
+  EXPECT_EQ(Val1->getOpcode(), ISD::BUILD_VECTOR);
+  SDValue Op = DAG->getNode(ISD::ADD, Loc, VecVT, Val1, Val2);
+
+  EXPECT_TRUE(DAG->isSplatValue(Op, /*AllowUndefs=*/false));
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_FALSE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+
+  // Width=16, Mask=3
+  DemandedElts = APInt(16, 3);
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+}
+
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Scalable_SPLAT_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, true);
+  // Create a SPLAT_VECTOR
+  SDValue Op = DAG->getConstant(1, Loc, VecVT);
+  EXPECT_EQ(Op->getOpcode(), ISD::SPLAT_VECTOR);
+  EXPECT_TRUE(DAG->isSplatValue(Op, /*AllowUndefs=*/false));
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+
+  // Width=16, Mask=3. These bits should be ignored.
+  DemandedElts = APInt(16, 3);
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+}
+
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Scalable_ADD_of_SPLAT_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, true);
+
+  // Should create SPLAT_VECTORS
+  SDValue Val1 = DAG->getConstant(1, Loc, VecVT);
+  SDValue Val2 = DAG->getConstant(3, Loc, VecVT);
+  EXPECT_EQ(Val1->getOpcode(), ISD::SPLAT_VECTOR);
+  SDValue Op = DAG->getNode(ISD::ADD, Loc, VecVT, Val1, Val2);
+
+  EXPECT_TRUE(DAG->isSplatValue(Op, /*AllowUndefs=*/false));
+
+  APInt UndefElts;
+  APInt DemandedElts;
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+
+  // Width=16, Mask=3. These bits should be ignored.
+  DemandedElts = APInt(16, 3);
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+}
+
+TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Fixed_BUILD_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, false);
+  // Create a BUILD_VECTOR
+  SDValue Op = DAG->getConstant(1, Loc, VecVT);
+  EXPECT_EQ(Op->getOpcode(), ISD::BUILD_VECTOR);
+
+  int SplatIdx = -1;
+  EXPECT_EQ(DAG->getSplatSourceVector(Op, SplatIdx), Op);
+  EXPECT_EQ(SplatIdx, 0);
+}
+
+TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Fixed_ADD_of_BUILD_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, false);
+
+  // Should create BUILD_VECTORs
+  SDValue Val1 = DAG->getConstant(1, Loc, VecVT);
+  SDValue Val2 = DAG->getConstant(3, Loc, VecVT);
+  EXPECT_EQ(Val1->getOpcode(), ISD::BUILD_VECTOR);
+  SDValue Op = DAG->getNode(ISD::ADD, Loc, VecVT, Val1, Val2);
+
+  int SplatIdx = -1;
+  EXPECT_EQ(DAG->getSplatSourceVector(Op, SplatIdx), Op);
+  EXPECT_EQ(SplatIdx, 0);
+}
+
+TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Scalable_SPLAT_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, true);
+  // Create a SPLAT_VECTOR
+  SDValue Op = DAG->getConstant(1, Loc, VecVT);
+  EXPECT_EQ(Op->getOpcode(), ISD::SPLAT_VECTOR);
+
+  int SplatIdx = -1;
+  EXPECT_EQ(DAG->getSplatSourceVector(Op, SplatIdx), Op);
+  EXPECT_EQ(SplatIdx, 0);
+}
+
+TEST_F(AArch64SelectionDAGTest, getSplatSourceVector_Scalable_ADD_of_SPLAT_VECTOR) {
+  if (!TM)
+    return;
+
+  TargetLowering TL(*TM);
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, true);
+
+  // Should create SPLAT_VECTORS
+  SDValue Val1 = DAG->getConstant(1, Loc, VecVT);
+  SDValue Val2 = DAG->getConstant(3, Loc, VecVT);
+  EXPECT_EQ(Val1->getOpcode(), ISD::SPLAT_VECTOR);
+  SDValue Op = DAG->getNode(ISD::ADD, Loc, VecVT, Val1, Val2);
+
+  int SplatIdx = -1;
+  EXPECT_EQ(DAG->getSplatSourceVector(Op, SplatIdx), Op);
+  EXPECT_EQ(SplatIdx, 0);
+}
+
+TEST_F(AArch64SelectionDAGTest, isSplatValue_Fixed_ADD_of_SPLAT_VECTOR) {
+  if (!TM)
+    return;
+
+  SDLoc Loc;
+  auto IntVT = EVT::getIntegerVT(Context, 8);
+  auto VecVT = EVT::getVectorVT(Context, IntVT, 16, false);
+
+  // Mix a fixed-length SPLAT_VECTOR with a BUILD_VECTOR; the undef masks of
+  // both operands must have the same width to be combined.
+  SDValue Splat = DAG->getNode(ISD::SPLAT_VECTOR, Loc, VecVT,
+                               DAG->getConstant(1, Loc, IntVT));
+  SDValue Build = DAG->getConstant(3, Loc, VecVT);
+  EXPECT_EQ(Build->getOpcode(), ISD::BUILD_VECTOR);
+  SDValue Op = DAG->getNode(ISD::ADD, Loc, VecVT, Splat, Build);
+
+  APInt UndefElts;
+  APInt DemandedElts = APInt::getAllOnesValue(16);
+  EXPECT_TRUE(DAG->isSplatValue(Op, DemandedElts, UndefElts));
+  EXPECT_EQ(UndefElts.getBitWidth(), 16u);
+  EXPECT_TRUE(UndefElts.isNullValue());
+
+  int SplatIdx = -1;
+  EXPECT_EQ(DAG->getSplatSourceVector(Op, SplatIdx), Op);
+  EXPECT_EQ(SplatIdx, 0);
+}
+
 } // end anonymous namespace