Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -24692,6 +24692,89 @@
   return Shuf->getOperand(0);
 }
 
+/// Helper function for simplifyShuffleOfI1Concats. If the operand `OpIdx' of
+/// `SV' is a concatenation of an i1 vector and undef, see if we can rewrite it
+/// without the concatenation.
+static SDValue simplifyConcatOfI1AndUndef(SelectionDAG &DAG,
+                                          const ShuffleVectorSDNode *SV,
+                                          const int OpIdx,
+                                          SmallVectorImpl<int> &Mask) {
+  EVT VT = SV->getValueType(0);
+  unsigned NumElts = VT.getVectorNumElements();
+
+  SDValue N = SV->getOperand(OpIdx);
+  if (N.getOpcode() != ISD::CONCAT_VECTORS || N.getNumOperands() != 2)
+    return SDValue();
+
+  // NOTE: As long as the second operand is not used (which we check below), it
+  // doesn't really matter if the second operand is undef or not.
+  SDValue N0 = N.getOperand(0); // setcc|truncate
+  SDValue N1 = N.getOperand(1); // undef/unused
+  if (N0.getOpcode() != ISD::SETCC && N0.getOpcode() != ISD::TRUNCATE)
+    return SDValue();
+  if (!N1.isUndef())
+    return SDValue();
+
+  // We are restricting ourselves to shuffles of i1 vectors, but we could apply
+  // the same idea to other types.
+  if (N0.getValueType().getVectorElementType() != MVT::i1)
+    return SDValue();
+
+  // Check if elements of the second operand are accessed.
+  unsigned MaskBegin = OpIdx * NumElts;
+  unsigned MaskEnd = (OpIdx + 1) * NumElts;
+  unsigned NumEltsOfN0 = N0.getValueType().getVectorNumElements();
+  for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
+    if (Mask[i] >= (int)(MaskBegin + NumEltsOfN0) && Mask[i] < (int)MaskEnd)
+      return SDValue();
+  }
+
+  // Check if we can do a bitcast from NumSrcBits to VT.
+  unsigned NumSrcBits = N0.getOperand(0).getValueSizeInBits();
+  if (NumSrcBits % NumElts != 0 || NumElts % NumEltsOfN0 != 0)
+    return SDValue();
+
+  EVT NewST = EVT::getIntegerVT(*DAG.getContext(), NumSrcBits / NumElts);
+  EVT NewVT = EVT::getVectorVT(*DAG.getContext(), NewST, NumElts);
+
+  SDValue NewN = N0.getOperand(0);
+  if (N0.getOpcode() == ISD::SETCC)
+    NewN = DAG.getSetCC(SDLoc(SV), N0.getOperand(0).getValueType(),
+                        N0.getOperand(0), N0.getOperand(1),
+                        cast<CondCodeSDNode>(N0.getOperand(2))->get());
+
+  // The bitcast essentially interleaves new elements in the operand. Here we
+  // adjust the mask to account for the interleaved elements.
+  unsigned MaskFactor = NumElts / NumEltsOfN0;
+  for (int i = 0, e = (int)Mask.size(); i != e; ++i) {
+    if (Mask[i] >= (int)MaskBegin && Mask[i] < (int)MaskEnd)
+      Mask[i] = (Mask[i] - MaskBegin) * MaskFactor + MaskBegin;
+  }
+
+  return DAG.getAnyExtOrTrunc(DAG.getBitcast(NewVT, NewN), SDLoc(SV), VT);
+}
+
+/// If the operands of a shuffle are concatenations of i1 vectors and undef,
+/// see if we can rewrite the operands without the concatenation by doing a
+/// bitcast instead. This avoids BUILD_VECTORs during legalisation, which
+/// ultimately leads to better codegen.
+static SDValue simplifyShuffleOfI1Concats(SelectionDAG &DAG,
+                                          const ShuffleVectorSDNode *SV) {
+  EVT VT = SV->getValueType(0);
+  SmallVector<int> Mask(SV->getMask());
+
+  SDValue Op0 = simplifyConcatOfI1AndUndef(DAG, SV, 0, Mask);
+  SDValue Op1 = simplifyConcatOfI1AndUndef(DAG, SV, 1, Mask);
+
+  if (!Op0 && !Op1)
+    return SDValue();
+
+  if (!Op0)
+    Op0 = SV->getOperand(0);
+  if (!Op1)
+    Op1 = SV->getOperand(1);
+
+  return DAG.getVectorShuffle(VT, SDLoc(SV), Op0, Op1, Mask);
+}
+
 SDValue DAGCombiner::visitVECTOR_SHUFFLE(SDNode *N) {
   EVT VT = N->getValueType(0);
   unsigned NumElts = VT.getVectorNumElements();
@@ -25332,6 +25415,11 @@
     }
   }
 
+  // Try to replace the concats in shuffle (concat (i1, undef), concat (i1,
+  // undef), Mask) with bitcasts.
+  if (SDValue V = simplifyShuffleOfI1Concats(DAG, SVN))
+    return V;
+
   if (SDValue V = foldShuffleOfConcatUndefs(SVN, DAG))
     return V;
 
Index: llvm/test/CodeGen/AArch64/aarch64-shuffle-of-i1s.ll
===================================================================
--- llvm/test/CodeGen/AArch64/aarch64-shuffle-of-i1s.ll
+++ llvm/test/CodeGen/AArch64/aarch64-shuffle-of-i1s.ll
@@ -5,14 +5,9 @@
 define <4 x i1> @t11(<2 x i1> %a, <2 x i1> %b) {
 ; CHECK-LABEL: t11:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmov d2, d0
-; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov w8, v1.s[1]
-; CHECK-NEXT:    mov v0.16b, v2.16b
-; CHECK-NEXT:    mov v0.h[1], w8
-; CHECK-NEXT:    mov v0.h[2], w8
-; CHECK-NEXT:    mov v0.h[3], v2.h[0]
-; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    ext v0.8b, v1.8b, v0.8b, #2
+; CHECK-NEXT:    trn2 v0.4h, v0.4h, v0.4h
+; CHECK-NEXT:    ext v0.8b, v0.8b, v0.8b, #6
 ; CHECK-NEXT:    ret
   %r = shufflevector <2 x i1> %a, <2 x i1> %b, <4 x i32> <i32 0, i32 3, i32 3, i32 0>
   ret <4 x i1> %r
@@ -21,14 +16,9 @@
 define <4 x i1> @t12(<2 x i32> %a, <2 x i32> %b) {
 ; CHECK-LABEL: t12:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmov d2, d0
-; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov w8, v1.s[1]
-; CHECK-NEXT:    mov v0.16b, v2.16b
-; CHECK-NEXT:    mov v0.h[1], w8
-; CHECK-NEXT:    mov v0.h[2], w8
-; CHECK-NEXT:    mov v0.h[3], v2.h[0]
-; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    ext v0.8b, v1.8b, v0.8b, #2
+; CHECK-NEXT:    trn2 v0.4h, v0.4h, v0.4h
+; CHECK-NEXT:    ext v0.8b, v0.8b, v0.8b, #6
 ; CHECK-NEXT:    ret
   %a2 = trunc <2 x i32> %a to <2 x i1>
   %b2 = trunc <2 x i32> %b to <2 x i1>
@@ -42,14 +32,10 @@
 define <4 x i1> @t21(<2 x i1> %a, <2 x i1> %b) {
 ; CHECK-LABEL: t21:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmov d2, d0
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov w8, v1.s[1]
-; CHECK-NEXT:    mov v0.16b, v2.16b
-; CHECK-NEXT:    mov v0.h[1], w8
-; CHECK-NEXT:    mov v0.h[2], v2.h[0]
-; CHECK-NEXT:    mov v0.h[3], w8
-; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    mov v0.h[2], v1.h[2]
+; CHECK-NEXT:    uzp1 v0.4h, v0.4h, v0.4h
 ; CHECK-NEXT:    ret
   %r = shufflevector <2 x i1> %a, <2 x i1> %b, <4 x i32> <i32 0, i32 3, i32 0, i32 3>
   ret <4 x i1> %r
@@ -58,14 +44,10 @@
 define <4 x i1> @t22(<2 x i32> %a, <2 x i32> %b) {
 ; CHECK-LABEL: t22:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    fmov d2, d0
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 def $q1
-; CHECK-NEXT:    mov w8, v1.s[1]
-; CHECK-NEXT:    mov v0.16b, v2.16b
-; CHECK-NEXT:    mov v0.h[1], w8
-; CHECK-NEXT:    mov v0.h[2], v2.h[0]
-; CHECK-NEXT:    mov v0.h[3], w8
-; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
+; CHECK-NEXT:    mov v0.h[2], v1.h[2]
+; CHECK-NEXT:    uzp1 v0.4h, v0.4h, v0.4h
 ; CHECK-NEXT:    ret
   %a2 = trunc <2 x i32> %a to <2 x i1>
   %b2 = trunc <2 x i32> %b to <2 x i1>
@@ -82,18 +64,13 @@
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    cmtst v2.2s, v0.2s, v0.2s
 ; CHECK-NEXT:    cmeq v0.2s, v0.2s, #0
-; CHECK-NEXT:    cmtst v3.2s, v1.2s, v1.2s
+; CHECK-NEXT:    ext v0.8b, v0.8b, v2.8b, #2
+; CHECK-NEXT:    cmtst v2.2s, v1.2s, v1.2s
 ; CHECK-NEXT:    cmeq v1.2s, v1.2s, #0
-; CHECK-NEXT:    mov w8, v0.s[1]
-; CHECK-NEXT:    mov w9, v1.s[1]
-; CHECK-NEXT:    mov v0.16b, v2.16b
-; CHECK-NEXT:    mov v1.16b, v3.16b
-; CHECK-NEXT:    mov v0.h[1], w8
-; CHECK-NEXT:    mov v1.h[1], w9
-; CHECK-NEXT:    mov v0.h[2], w8
-; CHECK-NEXT:    mov v1.h[2], v3.h[0]
-; CHECK-NEXT:    mov v0.h[3], v2.h[0]
-; CHECK-NEXT:    mov v1.h[3], w9
+; CHECK-NEXT:    mov v2.h[2], v1.h[2]
+; CHECK-NEXT:    trn2 v0.4h, v0.4h, v0.4h
+; CHECK-NEXT:    uzp1 v1.4h, v2.4h, v2.4h
+; CHECK-NEXT:    ext v0.8b, v0.8b, v0.8b, #6
 ; CHECK-NEXT:    and v0.8b, v0.8b, v1.8b
 ; CHECK-NEXT:    ret
   %a1 = icmp ne <2 x i32> %a, zeroinitializer