diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -10206,7 +10206,7 @@
 // Check if a vector is built from one vector via extracted elements of
 // another together with an AND mask, ensuring that all elements fit
 // within range. This can be reconstructed using AND and NEON's TBL1.
-SDValue ReconstructShuffleWithConstantAndMask(SDValue Op, SelectionDAG &DAG) {
+SDValue ReconstructShuffleWithRuntimeMask(SDValue Op, SelectionDAG &DAG) {
   assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unknown opcode!");
   SDLoc dl(Op);
   EVT VT = Op.getValueType();
@@ -10215,7 +10215,7 @@
 
   // Can only recreate a shuffle with 16xi8 or 8xi8 elements, as they map
   // directly to TBL1.
-  if (VT.getSimpleVT() != MVT::v16i8 && VT.getSimpleVT() != MVT::v8i8)
+  if (VT != MVT::v16i8 && VT != MVT::v8i8)
     return SDValue();
 
   unsigned NumElts = VT.getVectorNumElements();
@@ -10231,24 +10231,30 @@
     if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
       return SDValue();
 
-    // This only looks at shuffles with elements that are truncated by a
-    // constant AND mask extracted from a mask vector.
-    SDValue Operand = V.getOperand(1);
-    if (Operand.getOpcode() != ISD::AND ||
-        !isa<ConstantSDNode>(Operand.getOperand(1)))
-      return SDValue();
-
-    ConstantSDNode *ConstantNode = cast<ConstantSDNode>(Operand.getOperand(1));
-    AndMaskConstants.push_back(SDValue(ConstantNode, 0));
-
     SDValue OperandSourceVec = V.getOperand(0);
     if (!SourceVec)
       SourceVec = OperandSourceVec;
     else if (SourceVec != OperandSourceVec)
       return SDValue();
 
-    // Find source vector of mask to use later in TBL.
-    SDValue MaskSource = Operand.getOperand(0);
+    // This only looks at shuffles with elements that are
+    // a) truncated by a constant AND mask extracted from a mask vector, or
+    // b) extracted directly from a mask vector.
+    SDValue MaskSource = V.getOperand(1);
+    if (MaskSource.getOpcode() == ISD::AND) {
+      // Either all or no operands should have an AND mask. Every lane seen so
+      // far must have contributed a constant, otherwise the constant vector
+      // built below would have fewer than NumElts operands.
+      if (AndMaskConstants.size() != i ||
+          !isa<ConstantSDNode>(MaskSource.getOperand(1)))
+        return SDValue();
+
+      AndMaskConstants.push_back(MaskSource.getOperand(1));
+      MaskSource = MaskSource.getOperand(0);
+    } else if (!AndMaskConstants.empty()) {
+      // Either all or no operands should have an AND mask.
+      return SDValue();
+    }
 
     // An ANY_EXTEND may be inserted between the AND and the source vector
     // extraction. We don't care about that, so we can just skip it.
@@ -10258,6 +10264,12 @@
     if (MaskSource.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
       return SDValue();
 
+    // Output lane i must be driven by lane i of the mask vector, since TBL
+    // pairs each result lane with the index lane at the same position.
+    auto *MaskIdx = dyn_cast<ConstantSDNode>(MaskSource.getOperand(1));
+    if (!MaskIdx || !MaskIdx->getConstantIntValue()->equalsInt(i))
+      return SDValue();
+
     // We only apply this if all elements come from the same vector with the
     // same vector type.
     if (!MaskSourceVec) {
@@ -10272,19 +10284,20 @@
   // We need a v16i8 for TBL, so we extend the source with a placeholder vector
   // for v8i8 to get a v16i8. As the pattern we are replacing is extract +
   // insert, we know that the index in the mask must be smaller than the number
-  // of elements in the source, or we would have an out-of-bounds access. So we
-  // can simply duplicate the source vector.
+  // of elements in the source, or we would have an out-of-bounds access.
   if (NumElts == 8)
-    SourceVec =
-        DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i8, SourceVec, SourceVec);
+    SourceVec = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i8, SourceVec,
+                            DAG.getUNDEF(VT));
+
+  // Preconditions met, so we can use a vector (AND +) TBL to build this vector.
+  if (!AndMaskConstants.empty())
+    MaskSourceVec = DAG.getNode(ISD::AND, dl, VT, MaskSourceVec,
+                                DAG.getBuildVector(VT, dl, AndMaskConstants));
 
-  // Preconditions met, so we can use a vector AND + TBL to build this vector.
-  SDValue AndMask = DAG.getBuildVector(VT, dl, AndMaskConstants);
-  SDValue MaskedVec = DAG.getNode(ISD::AND, dl, VT, MaskSourceVec, AndMask);
   return DAG.getNode(
       ISD::INTRINSIC_WO_CHAIN, dl, VT,
       DAG.getConstant(Intrinsic::aarch64_neon_tbl1, dl, MVT::i32), SourceVec,
-      MaskedVec);
+      MaskSourceVec);
 }
 
 // Gather data to see if the operation can be modelled as a
@@ -12431,7 +12444,7 @@
       if (SDValue Shuffle = ReconstructShuffle(Op, DAG))
         return Shuffle;
 
-      if (SDValue Shuffle = ReconstructShuffleWithConstantAndMask(Op, DAG))
+      if (SDValue Shuffle = ReconstructShuffleWithRuntimeMask(Op, DAG))
         return Shuffle;
     }
 
diff --git a/llvm/test/CodeGen/AArch64/neon-shuffle-vector-tbl.ll b/llvm/test/CodeGen/AArch64/neon-shuffle-vector-tbl.ll
--- a/llvm/test/CodeGen/AArch64/neon-shuffle-vector-tbl.ll
+++ b/llvm/test/CodeGen/AArch64/neon-shuffle-vector-tbl.ll
@@ -66,9 +66,8 @@
 define <8 x i8> @shuffle8_with_and_mask(<8 x i8> %src, <8 x i8> %mask) {
 ; CHECK-LABEL: shuffle8_with_and_mask:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
 ; CHECK-NEXT: movi.8b v2, #7
-; CHECK-NEXT: mov.d v0[1], v0[0]
+; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
 ; CHECK-NEXT: and.8b v1, v1, v2
 ; CHECK-NEXT: tbl.8b v0, { v0 }, v1
 ; CHECK-NEXT: ret
@@ -101,6 +100,123 @@
   ret <8 x i8> %24
 }
 
+define <8 x i8> @shuffle8_with_and_mask_different_constants(<8 x i8> %src, <8 x i8> %mask) {
+; CHECK-LABEL: LCPI2_0:
+; CHECK-NEXT: .byte 3
+; CHECK-NEXT: .byte 1
+; CHECK-NEXT: .byte 7
+; CHECK-NEXT: .byte 1
+; CHECK-NEXT: .byte 7
+; CHECK-NEXT: .byte 3
+; CHECK-NEXT: .byte 7
+; CHECK-NEXT: .byte 7
+
+; CHECK-LABEL: shuffle8_with_and_mask_different_constants:
+; CHECK: // %bb.0:
+; CHECK-NEXT: adrp x8, .LCPI2_0
+; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT: ldr d2, [x8, :lo12:.LCPI2_0]
+; CHECK-NEXT: and.8b v1, v1, v2
+; CHECK-NEXT: tbl.8b v0, { v0 }, v1
+; CHECK-NEXT: ret
+
+  %masked_mask = and <8 x i8> %mask, <i8 3, i8 1, i8 7, i8 1, i8 7, i8 3, i8 7, i8 7>
+  %1 = extractelement <8 x i8> %masked_mask, i64 0
+  %2 = extractelement <8 x i8> %src, i8 %1
+  %3 = insertelement <8 x i8> undef, i8 %2, i64 0
+  %4 = extractelement <8 x i8> %masked_mask, i64 1
+  %5 = extractelement <8 x i8> %src, i8 %4
+  %6 = insertelement <8 x i8> %3, i8 %5, i64 1
+  %7 = extractelement <8 x i8> %masked_mask, i64 2
+  %8 = extractelement <8 x i8> %src, i8 %7
+  %9 = insertelement <8 x i8> %6, i8 %8, i64 2
+  %10 = extractelement <8 x i8> %masked_mask, i64 3
+  %11 = extractelement <8 x i8> %src, i8 %10
+  %12 = insertelement <8 x i8> %9, i8 %11, i64 3
+  %13 = extractelement <8 x i8> %masked_mask, i64 4
+  %14 = extractelement <8 x i8> %src, i8 %13
+  %15 = insertelement <8 x i8> %12, i8 %14, i64 4
+  %16 = extractelement <8 x i8> %masked_mask, i64 5
+  %17 = extractelement <8 x i8> %src, i8 %16
+  %18 = insertelement <8 x i8> %15, i8 %17, i64 5
+  %19 = extractelement <8 x i8> %masked_mask, i64 6
+  %20 = extractelement <8 x i8> %src, i8 %19
+  %21 = insertelement <8 x i8> %18, i8 %20, i64 6
+  %22 = extractelement <8 x i8> %masked_mask, i64 7
+  %23 = extractelement <8 x i8> %src, i8 %22
+  %24 = insertelement <8 x i8> %21, i8 %23, i64 7
+  ret <8 x i8> %24
+}
+
+define <8 x i8> @shuffle8_with_mask(<8 x i8> %src, <8 x i8> %mask) {
+; CHECK-LABEL: shuffle8_with_mask:
+; CHECK: // %bb.0:
+; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT: tbl.8b v0, { v0 }, v1
+; CHECK-NEXT: ret
+
+  %1 = extractelement <8 x i8> %mask, i64 0
+  %2 = extractelement <8 x i8> %src, i8 %1
+  %3 = insertelement <8 x i8> undef, i8 %2, i64 0
+  %4 = extractelement <8 x i8> %mask, i64 1
+  %5 = extractelement <8 x i8> %src, i8 %4
+  %6 = insertelement <8 x i8> %3, i8 %5, i64 1
+  %7 = extractelement <8 x i8> %mask, i64 2
+  %8 = extractelement <8 x i8> %src, i8 %7
+  %9 = insertelement <8 x i8> %6, i8 %8, i64 2
+  %10 = extractelement <8 x i8> %mask, i64 3
+  %11 = extractelement <8 x i8> %src, i8 %10
+  %12 = insertelement <8 x i8> %9, i8 %11, i64 3
+  %13 = extractelement <8 x i8> %mask, i64 4
+  %14 = extractelement <8 x i8> %src, i8 %13
+  %15 = insertelement <8 x i8> %12, i8 %14, i64 4
+  %16 = extractelement <8 x i8> %mask, i64 5
+  %17 = extractelement <8 x i8> %src, i8 %16
+  %18 = insertelement <8 x i8> %15, i8 %17, i64 5
+  %19 = extractelement <8 x i8> %mask, i64 6
+  %20 = extractelement <8 x i8> %src, i8 %19
+  %21 = insertelement <8 x i8> %18, i8 %20, i64 6
+  %22 = extractelement <8 x i8> %mask, i64 7
+  %23 = extractelement <8 x i8> %src, i8 %22
+  %24 = insertelement <8 x i8> %21, i8 %23, i64 7
+  ret <8 x i8> %24
+}
+
+define <8 x i8> @no_shuffle_only_some_and_constants(<8 x i8> %src, <8 x i8> %mask) {
+; CHECK-LABEL: no_shuffle_only_some_and_constants:
+; CHECK: // %bb.0:
+; CHECK-NOT: tbl.8b
+
+  ; Element at 0 has an AND mask, element at 1 does not.
+  %1 = extractelement <8 x i8> %mask, i64 0
+  %masked_elt1 = and i8 %1, 7
+  %2 = extractelement <8 x i8> %src, i8 %masked_elt1
+  %3 = insertelement <8 x i8> undef, i8 %2, i64 0
+  %4 = extractelement <8 x i8> %mask, i64 1
+  %5 = extractelement <8 x i8> %src, i8 %4
+  %6 = insertelement <8 x i8> %3, i8 %5, i64 1
+
+  %7 = extractelement <8 x i8> %mask, i64 2
+  %8 = extractelement <8 x i8> %src, i8 %7
+  %9 = insertelement <8 x i8> %6, i8 %8, i64 2
+  %10 = extractelement <8 x i8> %mask, i64 3
+  %11 = extractelement <8 x i8> %src, i8 %10
+  %12 = insertelement <8 x i8> %9, i8 %11, i64 3
+  %13 = extractelement <8 x i8> %mask, i64 4
+  %14 = extractelement <8 x i8> %src, i8 %13
+  %15 = insertelement <8 x i8> %12, i8 %14, i64 4
+  %16 = extractelement <8 x i8> %mask, i64 5
+  %17 = extractelement <8 x i8> %src, i8 %16
+  %18 = insertelement <8 x i8> %15, i8 %17, i64 5
+  %19 = extractelement <8 x i8> %mask, i64 6
+  %20 = extractelement <8 x i8> %src, i8 %19
+  %21 = insertelement <8 x i8> %18, i8 %20, i64 6
+  %22 = extractelement <8 x i8> %mask, i64 7
+  %23 = extractelement <8 x i8> %src, i8 %22
+  %24 = insertelement <8 x i8> %21, i8 %23, i64 7
+  ret <8 x i8> %24
+}
+
 ; Takes alternating entries from two mask source vectors. Currently not supported.
 define <16 x i8> @no_shuffle_with_two_mask_sources(<16 x i8> %src, <16 x i8> %mask1, <16 x i8> %mask2) {
 ; CHECK-LABEL: shuffle_with_two_mask_sources:
@@ -235,3 +351,73 @@
   %24 = insertelement <8 x i8> %21, i8 %23, i64 7
   ret <8 x i8> %24
 }
+
+define <8 x i8> @no_shuffle_bad_mask_index(<8 x i8> %src, <8 x i8> %mask) {
+; CHECK-LABEL: no_shuffle_bad_mask_index:
+; CHECK: // %bb.0:
+; CHECK-NOT: tbl.8b
+
+  %masked_mask = and <8 x i8> %mask, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
+
+  ; This should extract at 0, but because it extracts at 1, the pattern does not match.
+  %1 = extractelement <8 x i8> %masked_mask, i64 1
+
+  %2 = extractelement <8 x i8> %src, i8 %1
+  %3 = insertelement <8 x i8> undef, i8 %2, i64 0
+  %4 = extractelement <8 x i8> %masked_mask, i64 1
+  %5 = extractelement <8 x i8> %src, i8 %4
+  %6 = insertelement <8 x i8> %3, i8 %5, i64 1
+  %7 = extractelement <8 x i8> %masked_mask, i64 2
+  %8 = extractelement <8 x i8> %src, i8 %7
+  %9 = insertelement <8 x i8> %6, i8 %8, i64 2
+  %10 = extractelement <8 x i8> %masked_mask, i64 3
+  %11 = extractelement <8 x i8> %src, i8 %10
+  %12 = insertelement <8 x i8> %9, i8 %11, i64 3
+  %13 = extractelement <8 x i8> %masked_mask, i64 4
+  %14 = extractelement <8 x i8> %src, i8 %13
+  %15 = insertelement <8 x i8> %12, i8 %14, i64 4
+  %16 = extractelement <8 x i8> %masked_mask, i64 5
+  %17 = extractelement <8 x i8> %src, i8 %16
+  %18 = insertelement <8 x i8> %15, i8 %17, i64 5
+  %19 = extractelement <8 x i8> %masked_mask, i64 6
+  %20 = extractelement <8 x i8> %src, i8 %19
+  %21 = insertelement <8 x i8> %18, i8 %20, i64 6
+  %22 = extractelement <8 x i8> %masked_mask, i64 7
+  %23 = extractelement <8 x i8> %src, i8 %22
+  %24 = insertelement <8 x i8> %21, i8 %23, i64 7
+  ret <8 x i8> %24
+}
+
+define <8 x i8> @no_shuffle_and_constant_only_on_last_elt(<8 x i8> %src, <8 x i8> %mask) {
+; CHECK-LABEL: no_shuffle_and_constant_only_on_last_elt:
+; CHECK: // %bb.0:
+; CHECK-NOT: tbl.8b
+
+  ; Elements 0 to 6 have no AND mask, only element 7 does.
+  %1 = extractelement <8 x i8> %mask, i64 0
+  %2 = extractelement <8 x i8> %src, i8 %1
+  %3 = insertelement <8 x i8> undef, i8 %2, i64 0
+  %4 = extractelement <8 x i8> %mask, i64 1
+  %5 = extractelement <8 x i8> %src, i8 %4
+  %6 = insertelement <8 x i8> %3, i8 %5, i64 1
+  %7 = extractelement <8 x i8> %mask, i64 2
+  %8 = extractelement <8 x i8> %src, i8 %7
+  %9 = insertelement <8 x i8> %6, i8 %8, i64 2
+  %10 = extractelement <8 x i8> %mask, i64 3
+  %11 = extractelement <8 x i8> %src, i8 %10
+  %12 = insertelement <8 x i8> %9, i8 %11, i64 3
+  %13 = extractelement <8 x i8> %mask, i64 4
+  %14 = extractelement <8 x i8> %src, i8 %13
+  %15 = insertelement <8 x i8> %12, i8 %14, i64 4
+  %16 = extractelement <8 x i8> %mask, i64 5
+  %17 = extractelement <8 x i8> %src, i8 %16
+  %18 = insertelement <8 x i8> %15, i8 %17, i64 5
+  %19 = extractelement <8 x i8> %mask, i64 6
+  %20 = extractelement <8 x i8> %src, i8 %19
+  %21 = insertelement <8 x i8> %18, i8 %20, i64 6
+  %22 = extractelement <8 x i8> %mask, i64 7
+  %masked_elt22 = and i8 %22, 7
+  %23 = extractelement <8 x i8> %src, i8 %masked_elt22
+  %24 = insertelement <8 x i8> %21, i8 %23, i64 7
+  ret <8 x i8> %24
+}