Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
@@ -34024,6 +34024,53 @@
   return SDValue();
 }
 
+// Helper for splitting operands of a binary operation to legal target size and
+// applying a function to each part.
+// Useful for operations that are available on SSE2 in 128-bit, on AVX2 in
+// 256-bit and on AVX512BW in 512-bit.
+// The argument VT is the type used for deciding if/how to split the operands
+// Op0 and Op1. Op0 and Op1 must have the same type, which need not be VT.
+// The argument Builder is a function that will be applied on each split part:
+// SDValue Builder(SelectionDAG &G, const SDLoc &DL, SDValue Op0, SDValue Op1)
+template <typename F>
+SDValue SplitBinaryOpsAndApply(SelectionDAG &DAG, const X86Subtarget &Subtarget,
+                               const SDLoc &DL, EVT VT, SDValue Op0,
+                               SDValue Op1, F Builder) {
+  assert(Subtarget.hasSSE2() && "Target assumed to support at least SSE2");
+  unsigned NumSubs = 1;
+  if (Subtarget.hasBWI()) {
+    if (VT.getSizeInBits() > 512) {
+      NumSubs = VT.getSizeInBits() / 512;
+      assert((VT.getSizeInBits() % 512) == 0 && "Illegal vector size");
+    }
+  } else if (Subtarget.hasAVX2()) {
+    if (VT.getSizeInBits() > 256) {
+      NumSubs = VT.getSizeInBits() / 256;
+      assert((VT.getSizeInBits() % 256) == 0 && "Illegal vector size");
+    }
+  } else {
+    if (VT.getSizeInBits() > 128) {
+      NumSubs = VT.getSizeInBits() / 128;
+      assert((VT.getSizeInBits() % 128) == 0 && "Illegal vector size");
+    }
+  }
+
+  if (NumSubs == 1)
+    return Builder(DAG, DL, Op0, Op1);
+
+  SmallVector<SDValue, 4> Subs;
+  EVT InVT = Op0.getValueType();
+  EVT SubVT = EVT::getVectorVT(*DAG.getContext(), InVT.getScalarType(),
+                               InVT.getVectorNumElements() / NumSubs);
+  for (unsigned i = 0; i != NumSubs; ++i) {
+    unsigned Idx = i * SubVT.getVectorNumElements();
+    SDValue LHS = extractSubVector(Op0, Idx, DAG, DL, SubVT.getSizeInBits());
+    SDValue RHS = extractSubVector(Op1, Idx, DAG, DL, SubVT.getSizeInBits());
+    Subs.push_back(Builder(DAG, DL, LHS, RHS));
+  }
+  return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
+}
+
 /// This function detects the AVG pattern between vectors of unsigned i8/i16,
 /// which is c = (a + b + 1) / 2, and replace this operation with the efficient
 /// X86ISD::AVG instruction.
@@ -34079,35 +34126,6 @@
     return true;
   };
 
-  // Split vectors to legal target size and apply AVG.
-  auto LowerToAVG = [&](SDValue Op0, SDValue Op1) {
-    unsigned NumSubs = 1;
-    if (Subtarget.hasBWI()) {
-      if (VT.getSizeInBits() > 512)
-        NumSubs = VT.getSizeInBits() / 512;
-    } else if (Subtarget.hasAVX2()) {
-      if (VT.getSizeInBits() > 256)
-        NumSubs = VT.getSizeInBits() / 256;
-    } else {
-      if (VT.getSizeInBits() > 128)
-        NumSubs = VT.getSizeInBits() / 128;
-    }
-
-    if (NumSubs == 1)
-      return DAG.getNode(X86ISD::AVG, DL, VT, Op0, Op1);
-
-    SmallVector<SDValue, 4> Subs;
-    EVT SubVT = EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(),
-                                 VT.getVectorNumElements() / NumSubs);
-    for (unsigned i = 0; i != NumSubs; ++i) {
-      unsigned Idx = i * SubVT.getVectorNumElements();
-      SDValue LHS = extractSubVector(Op0, Idx, DAG, DL, SubVT.getSizeInBits());
-      SDValue RHS = extractSubVector(Op1, Idx, DAG, DL, SubVT.getSizeInBits());
-      Subs.push_back(DAG.getNode(X86ISD::AVG, DL, SubVT, LHS, RHS));
-    }
-    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Subs);
-  };
-
   // Check if each element of the vector is left-shifted by one.
   auto LHS = In.getOperand(0);
   auto RHS = In.getOperand(1);
@@ -34121,6 +34139,11 @@
   Operands[0] = LHS.getOperand(0);
   Operands[1] = LHS.getOperand(1);
 
+  auto AVGBuilder = [](SelectionDAG &DAG, const SDLoc &DL, SDValue Op0,
+                       SDValue Op1) {
+    return DAG.getNode(X86ISD::AVG, DL, Op0.getValueType(), Op0, Op1);
+  };
+
   // Take care of the case when one of the operands is a constant vector whose
   // element is in the range [1, 256].
   if (IsConstVectorInRange(Operands[1], 1, ScalarVT == MVT::i8 ? 256 : 65536) &&
@@ -34131,7 +34154,9 @@
     SDValue VecOnes = DAG.getConstant(1, DL, InVT);
     Operands[1] = DAG.getNode(ISD::SUB, DL, InVT, Operands[1], VecOnes);
     Operands[1] = DAG.getNode(ISD::TRUNCATE, DL, VT, Operands[1]);
-    return LowerToAVG(Operands[0].getOperand(0), Operands[1]);
+    return SplitBinaryOpsAndApply(DAG, Subtarget, DL, VT,
+                                  Operands[0].getOperand(0), Operands[1],
+                                  AVGBuilder);
   }
 
   if (Operands[0].getOpcode() == ISD::ADD)
@@ -34154,8 +34179,10 @@
           Operands[j].getOperand(0).getValueType() != VT)
         return SDValue();
 
-    // The pattern is detected, emit X86ISD::AVG instruction.
-    return LowerToAVG(Operands[0].getOperand(0), Operands[1].getOperand(0));
+    // The pattern is detected, emit X86ISD::AVG instruction(s).
+    return SplitBinaryOpsAndApply(DAG, Subtarget, DL, VT,
+                                  Operands[0].getOperand(0),
+                                  Operands[1].getOperand(0), AVGBuilder);
   }
 
   return SDValue();