[X86] Rework the shuffle/binop folds in X86ISelLowering.cpp to iterate over
fixed-size operand arrays instead of hand-unrolled Op00/Op01/Op10/Op11 locals.

* IsMergeableWithShuffle gains leading `Opcode` and `OpNo` parameters. Its
  body is not modified by this patch, so the new parameters are not read yet
  (NOTE(review): presumably groundwork for opcode/operand-specific checks -
  confirm with the follow-up change).
* New lambdas AreAllOperandMergeableWithShuffle and
  AreAnyOperandMergeableWithShuffle apply all_of / any_of of
  IsMergeableWithShuffle over an operand list, using one fixed OpNo.
  IsAnyOperandMergeableWithShuffle applies any_of and passes each operand's
  own position in the list as OpNo.
* The Unary/Unary+Permute shuffle fold (FoldLoad disabled for PSHUFB) and the
  UNPCKL-case fold are rewritten on top of these helpers. The UNPCKL
  profitability test is unchanged:
    (operand 0 of both binops mergeable || operand 1 of both binops mergeable)
    || (operand 0 of either binop mergeable && operand 1 of either binop
    mergeable).

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -40611,7 +40611,8 @@
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   EVT ShuffleVT = N.getValueType();
 
-  auto IsMergeableWithShuffle = [&DAG](SDValue Op, bool FoldLoad = false) {
+  auto IsMergeableWithShuffle = [&DAG](unsigned Opcode, unsigned OpNo,
+                                       SDValue Op, bool FoldLoad = false) {
     // AllZeros/AllOnes constants are freely shuffled and will peek through
     // bitcasts. Other constant build vectors do not peek through bitcasts. Only
     // merge with target shuffles if it has one use so shuffle combining is
@@ -40632,6 +40633,32 @@
            (Op.getScalarValueSizeInBits() <= ShuffleVT.getScalarSizeInBits());
   };
 
+  auto AreAllOperandMergeableWithShuffle =
+      [IsMergeableWithShuffle](unsigned Opcode, unsigned OpNo,
+                               ArrayRef<SDValue> Ops, bool FoldLoad = false) {
+        return all_of(
+            Ops, [IsMergeableWithShuffle, Opcode, OpNo, FoldLoad](SDValue Op) {
+              return IsMergeableWithShuffle(Opcode, OpNo, Op, FoldLoad);
+            });
+      };
+  auto AreAnyOperandMergeableWithShuffle =
+      [IsMergeableWithShuffle](unsigned Opcode, unsigned OpNo,
+                               ArrayRef<SDValue> Ops, bool FoldLoad = false) {
+        return any_of(
+            Ops, [IsMergeableWithShuffle, Opcode, OpNo, FoldLoad](SDValue Op) {
+              return IsMergeableWithShuffle(Opcode, OpNo, Op, FoldLoad);
+            });
+      };
+
+  auto IsAnyOperandMergeableWithShuffle =
+      [IsMergeableWithShuffle](unsigned Opcode, ArrayRef<SDValue> Ops,
+                               bool FoldLoad = false) {
+        return any_of(enumerate(Ops), [IsMergeableWithShuffle, Opcode,
+                                       FoldLoad](auto I) {
+          return IsMergeableWithShuffle(Opcode, I.index(), I.value(), FoldLoad);
+        });
+      };
+
   unsigned Opc = N.getOpcode();
   switch (Opc) {
   // Unary and Unary+Permute Shuffles.
@@ -40656,25 +40683,25 @@
       SDValue N0 = peekThroughOneUseBitcasts(N.getOperand(0));
       unsigned SrcOpcode = N0.getOpcode();
       if (TLI.isBinOp(SrcOpcode) && IsSafeToMoveShuffle(N0, SrcOpcode)) {
-        SDValue Op00 = peekThroughOneUseBitcasts(N0.getOperand(0));
-        SDValue Op01 = peekThroughOneUseBitcasts(N0.getOperand(1));
-        if (IsMergeableWithShuffle(Op00, Opc != X86ISD::PSHUFB) ||
-            IsMergeableWithShuffle(Op01, Opc != X86ISD::PSHUFB)) {
-          SDValue LHS, RHS;
-          Op00 = DAG.getBitcast(ShuffleVT, Op00);
-          Op01 = DAG.getBitcast(ShuffleVT, Op01);
-          if (N.getNumOperands() == 2) {
-            LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00, N.getOperand(1));
-            RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01, N.getOperand(1));
-          } else {
-            LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00);
-            RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01);
-          }
+        assert(N0.getNumOperands() == 2 && "Not a binop?");
+        std::array<SDValue, 2> N0Ops;
+        for (unsigned OpIdx : seq(0U, 2U))
+          N0Ops[OpIdx] = peekThroughOneUseBitcasts(N0.getOperand(OpIdx));
+        if (IsAnyOperandMergeableWithShuffle(SrcOpcode, N0Ops,
+                                             /*FoldLoad=*/Opc !=
+                                                 X86ISD::PSHUFB)) {
           EVT OpVT = N0.getValueType();
+          for (unsigned OpIdx : seq(0U, 2U)) {
+            SDValue &Op = N0Ops[OpIdx];
+            Op = DAG.getBitcast(ShuffleVT, Op);
+            if (N.getNumOperands() == 2)
+              Op = DAG.getNode(Opc, DL, ShuffleVT, Op, N.getOperand(1));
+            else
+              Op = DAG.getNode(Opc, DL, ShuffleVT, Op);
+            Op = DAG.getBitcast(OpVT, Op);
+          }
           return DAG.getBitcast(ShuffleVT,
-                                DAG.getNode(SrcOpcode, DL, OpVT,
-                                            DAG.getBitcast(OpVT, LHS),
-                                            DAG.getBitcast(OpVT, RHS)));
+                                DAG.getNode(SrcOpcode, DL, OpVT, N0Ops));
         }
       }
     }
@@ -40697,39 +40724,51 @@
   case X86ISD::UNPCKL: {
     if (N->isOnlyUserOf(N.getOperand(0).getNode()) &&
         N->isOnlyUserOf(N.getOperand(1).getNode())) {
-      SDValue N0 = peekThroughOneUseBitcasts(N.getOperand(0));
-      SDValue N1 = peekThroughOneUseBitcasts(N.getOperand(1));
-      unsigned SrcOpcode = N0.getOpcode();
-      if (TLI.isBinOp(SrcOpcode) && N1.getOpcode() == SrcOpcode &&
-          IsSafeToMoveShuffle(N0, SrcOpcode) &&
-          IsSafeToMoveShuffle(N1, SrcOpcode)) {
-        SDValue Op00 = peekThroughOneUseBitcasts(N0.getOperand(0));
-        SDValue Op10 = peekThroughOneUseBitcasts(N1.getOperand(0));
-        SDValue Op01 = peekThroughOneUseBitcasts(N0.getOperand(1));
-        SDValue Op11 = peekThroughOneUseBitcasts(N1.getOperand(1));
+      std::array<SDValue, 2> OpsOfShuf;
+      for (unsigned OpIdx : seq(0U, 2U))
+        OpsOfShuf[OpIdx] = peekThroughOneUseBitcasts(N.getOperand(OpIdx));
+      unsigned SrcOpcode = OpsOfShuf[0].getOpcode();
+      if (TLI.isBinOp(SrcOpcode) && OpsOfShuf[1].getOpcode() == SrcOpcode &&
+          IsSafeToMoveShuffle(OpsOfShuf[0], SrcOpcode) &&
+          IsSafeToMoveShuffle(OpsOfShuf[1], SrcOpcode)) {
+        assert(OpsOfShuf[0].getNumOperands() == 2 &&
+               OpsOfShuf[1].getNumOperands() == 2 && "Not binops?");
+        std::array<std::array<SDValue, 2>, 2> NthOpsOfShufOps;
+        for (unsigned OpIdx : seq(0U, 2U)) {
+          for (unsigned ShufOpIdx : seq(0U, 2U)) {
+            NthOpsOfShufOps[OpIdx][ShufOpIdx] = peekThroughOneUseBitcasts(
+                OpsOfShuf[ShufOpIdx].getOperand(OpIdx));
+          }
+        }
         // Ensure the total number of shuffles doesn't increase by folding this
         // shuffle through to the source ops.
-        if (((IsMergeableWithShuffle(Op00) && IsMergeableWithShuffle(Op10)) ||
-             (IsMergeableWithShuffle(Op01) && IsMergeableWithShuffle(Op11))) ||
-            ((IsMergeableWithShuffle(Op00) || IsMergeableWithShuffle(Op10)) &&
-             (IsMergeableWithShuffle(Op01) || IsMergeableWithShuffle(Op11)))) {
-          SDValue LHS, RHS;
-          Op00 = DAG.getBitcast(ShuffleVT, Op00);
-          Op10 = DAG.getBitcast(ShuffleVT, Op10);
-          Op01 = DAG.getBitcast(ShuffleVT, Op01);
-          Op11 = DAG.getBitcast(ShuffleVT, Op11);
-          if (N.getNumOperands() == 3) {
-            LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00, Op10, N.getOperand(2));
-            RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01, Op11, N.getOperand(2));
-          } else {
-            LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00, Op10);
-            RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01, Op11);
+        if (any_of(enumerate(NthOpsOfShufOps),
+                   [&](auto I) {
+                     return AreAllOperandMergeableWithShuffle(
+                         SrcOpcode, I.index(), I.value());
+                   }) ||
+            all_of(enumerate(NthOpsOfShufOps), [&](auto I) {
+              return AreAnyOperandMergeableWithShuffle(SrcOpcode, I.index(),
+                                                       I.value());
+            })) {
+          EVT OpVT = OpsOfShuf[0].getValueType();
+          std::array<SDValue, 2> NewOps;
+          for (unsigned OpIdx : seq(0U, 2U)) {
+            MutableArrayRef<SDValue> NthOpsOfOpsOfShufOp =
+                NthOpsOfShufOps[OpIdx];
+            SDValue &NewOp = NewOps[OpIdx];
+            for (SDValue &V : NthOpsOfOpsOfShufOp)
+              V = DAG.getBitcast(ShuffleVT, V);
+            if (N.getNumOperands() == 3)
+              NewOp = DAG.getNode(Opc, DL, ShuffleVT, NthOpsOfOpsOfShufOp[0],
+                                  NthOpsOfOpsOfShufOp[1], N.getOperand(2));
+            else
+              NewOp = DAG.getNode(Opc, DL, ShuffleVT, NthOpsOfOpsOfShufOp[0],
+                                  NthOpsOfOpsOfShufOp[1]);
+            NewOp = DAG.getBitcast(OpVT, NewOp);
           }
-          EVT OpVT = N0.getValueType();
           return DAG.getBitcast(ShuffleVT,
-                                DAG.getNode(SrcOpcode, DL, OpVT,
-                                            DAG.getBitcast(OpVT, LHS),
-                                            DAG.getBitcast(OpVT, RHS)));
+                                DAG.getNode(SrcOpcode, DL, OpVT, NewOps));
         }
       }
     }