diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -40611,11 +40611,57 @@ const TargetLowering &TLI = DAG.getTargetLoweringInfo(); EVT ShuffleVT = N.getValueType(); - auto IsMergeableWithShuffle = [&DAG](SDValue Op, bool FoldLoad = false) { + auto IsBinOp = [&TLI](unsigned Opcode) { + switch (Opcode) { + case X86ISD::VSHL: + case X86ISD::VSHLI: + case X86ISD::VSRL: + case X86ISD::VSRLI: + case X86ISD::VSRA: + case X86ISD::VSRAI: + return true; + default: + return TLI.isBinOp(Opcode); + } + }; + auto IsMatchingBinOp = [](SDValue X, SDValue Y) { + if (X.getOpcode() != Y.getOpcode()) + return false; + switch (X.getOpcode()) { + case X86ISD::VSHL: + case X86ISD::VSHLI: + case X86ISD::VSRL: + case X86ISD::VSRLI: + case X86ISD::VSRA: + case X86ISD::VSRAI: + // SSE vector shifts must have matching (scalar) shift amounts. + return X.getOperand(1) == Y.getOperand(1); + default: + return true; + } + }; + auto IgnoreOp = [](unsigned Opcode, unsigned OpNo) { + switch (Opcode) { + case X86ISD::VSHL: + case X86ISD::VSHLI: + case X86ISD::VSRL: + case X86ISD::VSRLI: + case X86ISD::VSRA: + case X86ISD::VSRAI: + return OpNo == 1; + default: + return false; + } + }; + auto IsMergeableWithShuffle = [&DAG, &IgnoreOp](SDValue Src, unsigned OpNo, + bool FoldLoad = false) { // AllZeros/AllOnes constants are freely shuffled and will peek through // bitcasts. Other constant build vectors do not peek through bitcasts. Only // merge with target shuffles if it has one use so shuffle combining is // likely to kick in. Shuffles of splats are expected to be removed. 
+ if (IgnoreOp(Src.getOpcode(), OpNo)) + return false; + SDValue Op = peekThroughOneUseBitcasts(Src.getOperand(OpNo)); return ISD::isBuildVectorAllOnes(Op.getNode()) || ISD::isBuildVectorAllZeros(Op.getNode()) || ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) || @@ -40655,26 +40701,23 @@ N->isOnlyUserOf(N.getOperand(0).getNode())) { SDValue N0 = peekThroughOneUseBitcasts(N.getOperand(0)); unsigned SrcOpcode = N0.getOpcode(); - if (TLI.isBinOp(SrcOpcode) && IsSafeToMoveShuffle(N0, SrcOpcode)) { - SDValue Op00 = peekThroughOneUseBitcasts(N0.getOperand(0)); - SDValue Op01 = peekThroughOneUseBitcasts(N0.getOperand(1)); - if (IsMergeableWithShuffle(Op00, Opc != X86ISD::PSHUFB) || - IsMergeableWithShuffle(Op01, Opc != X86ISD::PSHUFB)) { - SDValue LHS, RHS; - Op00 = DAG.getBitcast(ShuffleVT, Op00); - Op01 = DAG.getBitcast(ShuffleVT, Op01); - if (N.getNumOperands() == 2) { - LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00, N.getOperand(1)); - RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01, N.getOperand(1)); - } else { - LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00); - RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01); + if (IsBinOp(SrcOpcode) && IsSafeToMoveShuffle(N0, SrcOpcode)) { + EVT OpVT = N0.getValueType(); + if (IsMergeableWithShuffle(N0, 0, Opc != X86ISD::PSHUFB) || + IsMergeableWithShuffle(N0, 1, Opc != X86ISD::PSHUFB)) { + SmallVector<SDValue> Ops(N0->ops()); + for (int i = 0; i != 2; ++i) { + if (IgnoreOp(SrcOpcode, i)) + continue; + Ops[i] = DAG.getBitcast(ShuffleVT, Ops[i]); + Ops[i] = + N.getNumOperands() == 2 + ? 
DAG.getNode(Opc, DL, ShuffleVT, Ops[i], N.getOperand(1)) + : DAG.getNode(Opc, DL, ShuffleVT, Ops[i]); + Ops[i] = DAG.getBitcast(OpVT, Ops[i]); } - EVT OpVT = N0.getValueType(); return DAG.getBitcast(ShuffleVT, - DAG.getNode(SrcOpcode, DL, OpVT, - DAG.getBitcast(OpVT, LHS), - DAG.getBitcast(OpVT, RHS))); + DAG.getNode(SrcOpcode, DL, OpVT, Ops)); } } } @@ -40700,36 +40743,30 @@ SDValue N0 = peekThroughOneUseBitcasts(N.getOperand(0)); SDValue N1 = peekThroughOneUseBitcasts(N.getOperand(1)); unsigned SrcOpcode = N0.getOpcode(); - if (TLI.isBinOp(SrcOpcode) && N1.getOpcode() == SrcOpcode && + if (IsBinOp(SrcOpcode) && IsMatchingBinOp(N0, N1) && IsSafeToMoveShuffle(N0, SrcOpcode) && IsSafeToMoveShuffle(N1, SrcOpcode)) { - SDValue Op00 = peekThroughOneUseBitcasts(N0.getOperand(0)); - SDValue Op10 = peekThroughOneUseBitcasts(N1.getOperand(0)); - SDValue Op01 = peekThroughOneUseBitcasts(N0.getOperand(1)); - SDValue Op11 = peekThroughOneUseBitcasts(N1.getOperand(1)); + EVT OpVT = N0.getValueType(); // Ensure the total number of shuffles doesn't increase by folding this // shuffle through to the source ops. 
- if (((IsMergeableWithShuffle(Op00) && IsMergeableWithShuffle(Op10)) || - (IsMergeableWithShuffle(Op01) && IsMergeableWithShuffle(Op11))) || - ((IsMergeableWithShuffle(Op00) || IsMergeableWithShuffle(Op10)) && - (IsMergeableWithShuffle(Op01) || IsMergeableWithShuffle(Op11)))) { - SDValue LHS, RHS; - Op00 = DAG.getBitcast(ShuffleVT, Op00); - Op10 = DAG.getBitcast(ShuffleVT, Op10); - Op01 = DAG.getBitcast(ShuffleVT, Op01); - Op11 = DAG.getBitcast(ShuffleVT, Op11); - if (N.getNumOperands() == 3) { - LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00, Op10, N.getOperand(2)); - RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01, Op11, N.getOperand(2)); - } else { - LHS = DAG.getNode(Opc, DL, ShuffleVT, Op00, Op10); - RHS = DAG.getNode(Opc, DL, ShuffleVT, Op01, Op11); + if (((IsMergeableWithShuffle(N0, 0) && IsMergeableWithShuffle(N1, 0)) || + (IsMergeableWithShuffle(N0, 1) && IsMergeableWithShuffle(N1, 1))) || + ((IsMergeableWithShuffle(N0, 0) || IsMergeableWithShuffle(N1, 0)) && + (IsMergeableWithShuffle(N0, 1) || IsMergeableWithShuffle(N1, 1)))) { + SmallVector<SDValue> Ops(N0->ops()); + for (int i = 0; i != 2; ++i) { + if (IgnoreOp(SrcOpcode, i)) + continue; + SDValue LHS = DAG.getBitcast(ShuffleVT, N0.getOperand(i)); + SDValue RHS = DAG.getBitcast(ShuffleVT, N1.getOperand(i)); + Ops[i] = + N.getNumOperands() == 3 + ? 
DAG.getNode(Opc, DL, ShuffleVT, LHS, RHS, N.getOperand(2)) + : DAG.getNode(Opc, DL, ShuffleVT, LHS, RHS); + Ops[i] = DAG.getBitcast(OpVT, Ops[i]); } - EVT OpVT = N0.getValueType(); return DAG.getBitcast(ShuffleVT, - DAG.getNode(SrcOpcode, DL, OpVT, - DAG.getBitcast(OpVT, LHS), - DAG.getBitcast(OpVT, RHS))); + DAG.getNode(SrcOpcode, DL, OpVT, Ops)); } } } diff --git a/llvm/test/CodeGen/X86/shrink_vmul.ll b/llvm/test/CodeGen/X86/shrink_vmul.ll --- a/llvm/test/CodeGen/X86/shrink_vmul.ll +++ b/llvm/test/CodeGen/X86/shrink_vmul.ll @@ -1921,9 +1921,7 @@ ; X86-SSE-NEXT: movl {{[0-9]+}}(%esp), %ecx ; X86-SSE-NEXT: movl c, %edx ; X86-SSE-NEXT: movd {{.*#+}} xmm0 = mem[0],zero,zero,zero -; X86-SSE-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,0,0,0] ; X86-SSE-NEXT: psrad $16, %xmm0 -; X86-SSE-NEXT: pshufd {{.*#+}} xmm0 = xmm0[1,1,3,3] ; X86-SSE-NEXT: pmuludq {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0 ; X86-SSE-NEXT: psllq $32, %xmm0 ; X86-SSE-NEXT: movq %xmm0, (%edx,%eax,4) @@ -1944,9 +1942,7 @@ ; X64-SSE: # %bb.0: # %entry ; X64-SSE-NEXT: movq c(%rip), %rax ; X64-SSE-NEXT: movd {{.*#+}} xmm0 = mem[0],zero,zero,zero -; X64-SSE-NEXT: pshufd {{.*#+}} xmm0 = xmm0[0,0,0,0] ; X64-SSE-NEXT: psrad $16, %xmm0 -; X64-SSE-NEXT: pshufd {{.*#+}} xmm0 = xmm0[1,1,3,3] ; X64-SSE-NEXT: pmuludq {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 ; X64-SSE-NEXT: psllq $32, %xmm0 ; X64-SSE-NEXT: movq %xmm0, (%rax,%rsi,4)