diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -634,16 +634,10 @@
                                           bool AssumeSingleUse) const {
   EVT VT = Op.getValueType();
 
-  // TODO: We can probably do more work on calculating the known bits and
-  // simplifying the operations for scalable vectors, but for now we just
-  // bail out.
-  if (VT.isScalableVector()) {
-    // Pretend we don't know anything for now.
-    Known = KnownBits(DemandedBits.getBitWidth());
-    return false;
-  }
-
-  APInt DemandedElts = VT.isVector()
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes. This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
   return SimplifyDemandedBits(Op, DemandedBits, DemandedElts, Known, TLO, Depth,
@@ -656,12 +650,6 @@
     SelectionDAG &DAG, unsigned Depth) const {
   EVT VT = Op.getValueType();
 
-  // Pretend we don't know anything about scalable vectors for now.
-  // TODO: We can probably do more work on simplifying the operations for
-  // scalable vectors, but for now we just bail out.
-  if (VT.isScalableVector())
-    return SDValue();
-
   // Limit search depth.
   if (Depth >= SelectionDAG::MaxRecursionDepth)
     return SDValue();
@@ -680,6 +668,9 @@
   KnownBits LHSKnown, RHSKnown;
   switch (Op.getOpcode()) {
   case ISD::BITCAST: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     SDValue Src = peekThroughBitcasts(Op.getOperand(0));
     EVT SrcVT = Src.getValueType();
     EVT DstVT = Op.getValueType();
@@ -825,6 +816,9 @@
   case ISD::ANY_EXTEND_VECTOR_INREG:
   case ISD::SIGN_EXTEND_VECTOR_INREG:
   case ISD::ZERO_EXTEND_VECTOR_INREG: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     // If we only want the lowest element and none of extended bits, then we can
     // return the bitcasted source vector.
     SDValue Src = Op.getOperand(0);
@@ -838,6 +832,9 @@
     break;
   }
   case ISD::INSERT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     // If we don't demand the inserted element, return the base vector.
     SDValue Vec = Op.getOperand(0);
     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
@@ -848,6 +845,9 @@
     break;
   }
   case ISD::INSERT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return SDValue();
+
     SDValue Vec = Op.getOperand(0);
     SDValue Sub = Op.getOperand(1);
     uint64_t Idx = Op.getConstantOperandVal(2);
@@ -868,6 +868,7 @@
     break;
   }
   case ISD::VECTOR_SHUFFLE: {
+    assert(!VT.isScalableVector());
     ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
 
     // If all the demanded elts are from one operand and are inline,
@@ -891,6 +892,11 @@
     break;
   }
   default:
+    // TODO: Probably okay to remove after audit; here to reduce change size
+    // in initial enablement patch for scalable vectors
+    if (VT.isScalableVector())
+      return SDValue();
+
     if (Op.getOpcode() >= ISD::BUILTIN_OP_END)
       if (SDValue V = SimplifyMultipleUseDemandedBitsForTargetNode(
               Op, DemandedBits, DemandedElts, DAG, Depth))
@@ -904,14 +910,10 @@
     SDValue Op, const APInt &DemandedBits, SelectionDAG &DAG,
     unsigned Depth) const {
   EVT VT = Op.getValueType();
-
-  // Pretend we don't know anything about scalable vectors for now.
-  // TODO: We can probably do more work on simplifying the operations for
-  // scalable vectors, but for now we just bail out.
-  if (VT.isScalableVector())
-    return SDValue();
-
-  APInt DemandedElts = VT.isVector()
+  // Since the number of lanes in a scalable vector is unknown at compile time,
+  // we track one bit which is implicitly broadcast to all lanes. This means
+  // that all lanes in a scalable vector are considered demanded.
+  APInt DemandedElts = VT.isFixedLengthVector()
                            ? APInt::getAllOnes(VT.getVectorNumElements())
                            : APInt(1, 1);
   return SimplifyMultipleUseDemandedBits(Op, DemandedBits, DemandedElts, DAG,
@@ -1070,16 +1072,10 @@
   // Don't know anything.
   Known = KnownBits(BitWidth);
 
-  // TODO: We can probably do more work on calculating the known bits and
-  // simplifying the operations for scalable vectors, but for now we just
-  // bail out.
   EVT VT = Op.getValueType();
-  if (VT.isScalableVector())
-    return false;
-
   bool IsLE = TLO.DAG.getDataLayout().isLittleEndian();
   unsigned NumElts = OriginalDemandedElts.getBitWidth();
-  assert((!VT.isVector() || NumElts == VT.getVectorNumElements()) &&
+  assert((!VT.isFixedLengthVector() || NumElts == VT.getVectorNumElements()) &&
          "Unexpected vector size");
 
   APInt DemandedBits = OriginalDemandedBits;
@@ -1130,6 +1126,8 @@
   KnownBits Known2;
   switch (Op.getOpcode()) {
   case ISD::SCALAR_TO_VECTOR: {
+    if (VT.isScalableVector())
+      return false;
     if (!DemandedElts[0])
       return TLO.CombineTo(Op, TLO.DAG.getUNDEF(VT));
 
@@ -1167,6 +1165,8 @@
     break;
   }
   case ISD::INSERT_VECTOR_ELT: {
+    if (VT.isScalableVector())
+      return false;
     SDValue Vec = Op.getOperand(0);
     SDValue Scl = Op.getOperand(1);
     auto *CIdx = dyn_cast<ConstantSDNode>(Op.getOperand(2));
@@ -1203,6 +1203,8 @@
     return false;
   }
   case ISD::INSERT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return false;
    // Demand any elements from the subvector and the remainder from the src its
    // inserted into.
     SDValue Src = Op.getOperand(0);
@@ -1246,6 +1248,8 @@
     break;
   }
   case ISD::EXTRACT_SUBVECTOR: {
+    if (VT.isScalableVector())
+      return false;
     // Offset the demanded elts by the subvector index.
     SDValue Src = Op.getOperand(0);
     if (Src.getValueType().isScalableVector())
@@ -1271,6 +1275,8 @@
     break;
   }
   case ISD::CONCAT_VECTORS: {
+    if (VT.isScalableVector())
+      return false;
     Known.Zero.setAllBits();
     Known.One.setAllBits();
     EVT SubVT = Op.getOperand(0).getValueType();
@@ -1289,6 +1295,7 @@
     break;
   }
   case ISD::VECTOR_SHUFFLE: {
+    assert(!VT.isScalableVector());
     ArrayRef<int> ShuffleMask = cast<ShuffleVectorSDNode>(Op)->getMask();
 
     // Collect demanded elements from shuffle operands..
@@ -1366,7 +1373,7 @@
 
     // AND(INSERT_SUBVECTOR(C,X,I),M) -> INSERT_SUBVECTOR(AND(C,M),X,I)
     // iff 'C' is Undef/Constant and AND(X,M) == X (for DemandedBits).
-    if (Op0.getOpcode() == ISD::INSERT_SUBVECTOR &&
+    if (Op0.getOpcode() == ISD::INSERT_SUBVECTOR && !VT.isScalableVector() &&
        (Op0.getOperand(0).isUndef() ||
         ISD::isBuildVectorOfConstantSDNodes(Op0.getOperand(0).getNode())) &&
        Op0->hasOneUse()) {
@@ -2226,12 +2233,15 @@
     Known = KnownHi.concat(KnownLo);
     break;
   }
-  case ISD::ZERO_EXTEND:
-  case ISD::ZERO_EXTEND_VECTOR_INREG: {
+  case ISD::ZERO_EXTEND_VECTOR_INREG:
+    if (VT.isScalableVector())
+      return false;
+    [[fallthrough]];
+  case ISD::ZERO_EXTEND: {
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned InBits = SrcVT.getScalarSizeInBits();
-    unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
+    unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
     bool IsVecInReg = Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG;
 
     // If none of the top bits are demanded, convert this into an any_extend.
@@ -2263,12 +2273,15 @@
       return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
     break;
   }
-  case ISD::SIGN_EXTEND:
-  case ISD::SIGN_EXTEND_VECTOR_INREG: {
+  case ISD::SIGN_EXTEND_VECTOR_INREG:
+    if (VT.isScalableVector())
+      return false;
+    [[fallthrough]];
+  case ISD::SIGN_EXTEND: {
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned InBits = SrcVT.getScalarSizeInBits();
-    unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
+    unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
     bool IsVecInReg = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG;
 
     // If none of the top bits are demanded, convert this into an any_extend.
@@ -2315,12 +2328,15 @@
       return TLO.CombineTo(Op, TLO.DAG.getNode(Op.getOpcode(), dl, VT, NewSrc));
     break;
   }
-  case ISD::ANY_EXTEND:
-  case ISD::ANY_EXTEND_VECTOR_INREG: {
+  case ISD::ANY_EXTEND_VECTOR_INREG:
+    if (VT.isScalableVector())
+      return false;
+    [[fallthrough]];
+  case ISD::ANY_EXTEND: {
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned InBits = SrcVT.getScalarSizeInBits();
-    unsigned InElts = SrcVT.isVector() ? SrcVT.getVectorNumElements() : 1;
+    unsigned InElts = SrcVT.isFixedLengthVector() ? SrcVT.getVectorNumElements() : 1;
     bool IsVecInReg = Op.getOpcode() == ISD::ANY_EXTEND_VECTOR_INREG;
 
     // If we only need the bottom element then we can just bitcast.
@@ -2459,6 +2475,8 @@
     break;
   }
   case ISD::BITCAST: {
+    if (VT.isScalableVector())
+      return false;
     SDValue Src = Op.getOperand(0);
     EVT SrcVT = Src.getValueType();
     unsigned NumSrcEltBits = SrcVT.getScalarSizeInBits();
@@ -2680,6 +2698,10 @@
     // We also ask the target about intrinsics (which could be specific to it).
     if (Op.getOpcode() >= ISD::BUILTIN_OP_END ||
         Op.getOpcode() == ISD::INTRINSIC_WO_CHAIN) {
+      // TODO: Probably okay to remove after audit; here to reduce change size
+      // in initial enablement patch for scalable vectors
+      if (Op.getValueType().isScalableVector())
+        break;
       if (SimplifyDemandedBitsForTargetNode(Op, DemandedBits, DemandedElts,
                                             Known, TLO, Depth))
         return true;
@@ -2749,7 +2771,7 @@
          "Vector binop only");
 
   EVT EltVT = VT.getVectorElementType();
-  unsigned NumElts = VT.getVectorNumElements();
+  unsigned NumElts = VT.isFixedLengthVector() ? VT.getVectorNumElements() : 1;
   assert(UndefOp0.getBitWidth() == NumElts &&
          UndefOp1.getBitWidth() == NumElts && "Bad type for undef analysis");
 
diff --git a/llvm/test/CodeGen/AArch64/active_lane_mask.ll b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
--- a/llvm/test/CodeGen/AArch64/active_lane_mask.ll
+++ b/llvm/test/CodeGen/AArch64/active_lane_mask.ll
@@ -113,14 +113,13 @@
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    and w8, w0, #0xff
 ; CHECK-NEXT:    index z0.s, #0, #1
+; CHECK-NEXT:    and w9, w1, #0xff
 ; CHECK-NEXT:    and z0.s, z0.s, #0xff
 ; CHECK-NEXT:    ptrue p0.s
 ; CHECK-NEXT:    mov z1.s, w8
-; CHECK-NEXT:    and w8, w1, #0xff
 ; CHECK-NEXT:    add z0.s, z0.s, z1.s
+; CHECK-NEXT:    mov z1.s, w9
 ; CHECK-NEXT:    umin z0.s, z0.s, #255
-; CHECK-NEXT:    and z0.s, z0.s, #0xff
-; CHECK-NEXT:    mov z1.s, w8
 ; CHECK-NEXT:    cmphi p0.s, p0/z, z1.s, z0.s
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i8(i8 %index, i8 %TC)
@@ -132,17 +131,16 @@
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
 ; CHECK-NEXT:    and x8, x0, #0xff
-; CHECK-NEXT:    index z0.d, #0, #1
 ; CHECK-NEXT:    // kill: def $w1 killed $w1 def $x1
 ; CHECK-NEXT:    and x9, x1, #0xff
-; CHECK-NEXT:    and z0.d, z0.d, #0xff
+; CHECK-NEXT:    index z0.d, #0, #1
 ; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    and z0.d, z0.d, #0xff
 ; CHECK-NEXT:    mov z1.d, x8
+; CHECK-NEXT:    mov z2.d, x9
 ; CHECK-NEXT:    add z0.d, z0.d, z1.d
-; CHECK-NEXT:    mov z1.d, x9
 ; CHECK-NEXT:    umin z0.d, z0.d, #255
-; CHECK-NEXT:    and z0.d, z0.d, #0xff
-; CHECK-NEXT:    cmphi p0.d, p0/z, z1.d, z0.d
+; CHECK-NEXT:    cmphi p0.d, p0/z, z2.d, z0.d
 ; CHECK-NEXT:    ret
   %active.lane.mask = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i8(i8 %index, i8 %TC)
   ret <vscale x 2 x i1> %active.lane.mask
diff --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
--- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -224,11 +224,15 @@
 
   SDValue Op = DAG->getNode(ISD::AND, Loc, InVecVT, N0, Mask2V);
 
+  // N0 = ?000?0?0
+  // Mask2V = 01010101
+  //  =>
+  // Known.Zero = 10101010 (0xAA)
   KnownBits Known;
   APInt DemandedBits = APInt(8, 0xFF);
   TargetLowering::TargetLoweringOpt TLO(*DAG, false, false);
-  EXPECT_FALSE(TL.SimplifyDemandedBits(Op, DemandedBits, Known, TLO));
-  EXPECT_EQ(Known.Zero, APInt(8, 0));
+  EXPECT_TRUE(TL.SimplifyDemandedBits(Op, DemandedBits, Known, TLO));
+  EXPECT_EQ(Known.Zero, APInt(8, 0xAA));
 }
 
 // Piggy-backing on the AArch64 tests to verify SelectionDAG::computeKnownBits.
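
Note on the DemandedElts convention this patch relies on: because the lane count of a scalable vector is unknown at compile time, a single demanded-elements bit stands in for every lane (all lanes are treated as demanded), while fixed-length vectors keep one bit per lane. The sketch below is illustrative only and is not part of the patch; the helper name `getDemandedEltsForSimplify` is hypothetical, and it simply restates the two `APInt` shapes the code above selects between.

  // Illustrative sketch only -- not part of the patch.
  #include "llvm/ADT/APInt.h"
  #include "llvm/CodeGen/ValueTypes.h"

  using namespace llvm;

  static APInt getDemandedEltsForSimplify(EVT VT) {
    // Fixed-length vector: one demanded bit per lane.
    if (VT.isFixedLengthVector())
      return APInt::getAllOnes(VT.getVectorNumElements());
    // Scalable vector (or scalar): a single bit that is implicitly broadcast
    // to every lane, since the lane count is unknown at compile time.
    return APInt(1, 1);
  }

With this convention, purely bitwise reasoning applies uniformly to all lanes, while per-element transforms (shuffles, insert/extract of elements and subvectors, and unaudited target nodes) still bail out or assert for scalable vectors, as seen in the hunks above.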