Skip to content

Commit

Permalink
[X86] Split out absdiff detection from SAD combine. NFC.
Browse files Browse the repository at this point in the history
Preparation for supporting PSADBW emission for straight-line code.

llvm-svn: 276798
  • Loading branch information
mkuperst committed Jul 26, 2016
1 parent 4482b2a commit 2dc08f7
Showing 1 changed file with 64 additions and 59 deletions.
123 changes: 64 additions & 59 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
@@ -30680,8 +30680,64 @@ static SDValue OptimizeConditionalInDecrement(SDNode *N, SelectionDAG &DAG) {
DAG.getConstant(0, DL, OtherVal.getValueType()), NewCmp);
}

static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
// Given a select, detect the following pattern:
// 1: %2 = zext <N x i8> %0 to <N x i32>
// 2: %3 = zext <N x i8> %1 to <N x i32>
// 3: %4 = sub nsw <N x i32> %2, %3
// 4: %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N]
// 5: %6 = sub nsw <N x i32> zeroinitializer, %4
// 6: %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6
// This is useful as it is the input into a SAD pattern.
static bool detectZextAbsDiff(const SDValue &Select, SDValue &Op0,
SDValue &Op1) {
// Check the condition of the select instruction is greater-than.
SDValue SetCC = Select->getOperand(0);
if (SetCC.getOpcode() != ISD::SETCC)
return false;
ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
if (CC != ISD::SETGT)
return false;

SDValue SelectOp1 = Select->getOperand(1);
SDValue SelectOp2 = Select->getOperand(2);

// The second operand of the select should be the negation of the first
// operand, which is implemented as 0 - SelectOp1.
if (!(SelectOp2.getOpcode() == ISD::SUB &&
ISD::isBuildVectorAllZeros(SelectOp2.getOperand(0).getNode()) &&
SelectOp2.getOperand(1) == SelectOp1))
return false;

// The first operand of SetCC is the first operand of the select, which is the
// difference between the two input vectors.
if (SetCC.getOperand(0) != SelectOp1)
return false;

// The second operand of the comparison can be either -1 or 0.
if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) ||
ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode())))
return false;

// The first operand of the select is the difference between the two input
// vectors.
if (SelectOp1.getOpcode() != ISD::SUB)
return false;

Op0 = SelectOp1.getOperand(0);
Op1 = SelectOp1.getOperand(1);

// Check if the operands of the sub are zero-extended from vectors of i8.
if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
Op1.getOpcode() != ISD::ZERO_EXTEND ||
Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8)
return false;

return true;
}

static SDValue combineLoopSADPattern(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
SDLoc DL(N);
EVT VT = N->getValueType(0);
SDValue Op0 = N->getOperand(0);
@@ -30701,21 +30757,8 @@ static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG,
if (VT.getSizeInBits() / 4 > RegSize)
return SDValue();

// Detect the following pattern:
//
// 1: %2 = zext <N x i8> %0 to <N x i32>
// 2: %3 = zext <N x i8> %1 to <N x i32>
// 3: %4 = sub nsw <N x i32> %2, %3
// 4: %5 = icmp sgt <N x i32> %4, [0 x N] or [-1 x N]
// 5: %6 = sub nsw <N x i32> zeroinitializer, %4
// 6: %7 = select <N x i1> %5, <N x i32> %4, <N x i32> %6
// 7: %8 = add nsw <N x i32> %7, %vec.phi
//
// The last instruction must be a reduction add. The instructions 3-6 forms an
// ABSDIFF pattern.

// The two operands of reduction add are from PHI and a select-op as in line 7
// above.
// We know N is a reduction add, which means one of its operands is a phi.
// To match SAD, we need the other operand to be a vector select.
SDValue SelectOp, Phi;
if (Op0.getOpcode() == ISD::VSELECT) {
SelectOp = Op0;
@@ -30726,50 +30769,12 @@ static SDValue detectSADPattern(SDNode *N, SelectionDAG &DAG,
} else
return SDValue();

// Check the condition of the select instruction is greater-than.
SDValue SetCC = SelectOp->getOperand(0);
if (SetCC.getOpcode() != ISD::SETCC)
return SDValue();
ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
if (CC != ISD::SETGT)
return SDValue();

Op0 = SelectOp->getOperand(1);
Op1 = SelectOp->getOperand(2);

// The second operand of SelectOp Op1 is the negation of the first operand
// Op0, which is implemented as 0 - Op0.
if (!(Op1.getOpcode() == ISD::SUB &&
ISD::isBuildVectorAllZeros(Op1.getOperand(0).getNode()) &&
Op1.getOperand(1) == Op0))
return SDValue();

// The first operand of SetCC is the first operand of SelectOp, which is the
// difference between two input vectors.
if (SetCC.getOperand(0) != Op0)
return SDValue();

// The second operand of > comparison can be either -1 or 0.
if (!(ISD::isBuildVectorAllZeros(SetCC.getOperand(1).getNode()) ||
ISD::isBuildVectorAllOnes(SetCC.getOperand(1).getNode())))
return SDValue();

// The first operand of SelectOp is the difference between two input vectors.
if (Op0.getOpcode() != ISD::SUB)
return SDValue();

Op1 = Op0.getOperand(1);
Op0 = Op0.getOperand(0);

// Check if the operands of the diff are zero-extended from vectors of i8.
if (Op0.getOpcode() != ISD::ZERO_EXTEND ||
Op0.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
Op1.getOpcode() != ISD::ZERO_EXTEND ||
Op1.getOperand(0).getValueType().getVectorElementType() != MVT::i8)
// Check whether we have an abs-diff pattern feeding into the select.
if(!detectZextAbsDiff(SelectOp, Op0, Op1))
return SDValue();

// SAD pattern detected. Now build a SAD instruction and an addition for
// reduction. Note that the number of elments of the result of SAD is less
// reduction. Note that the number of elements of the result of SAD is less
// than the number of elements of its input. Therefore, we could only update
// part of elements in the reduction vector.

@@ -30819,7 +30824,7 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
const SDNodeFlags *Flags = &cast<BinaryWithFlagsSDNode>(N)->Flags;
if (Flags->hasVectorReduction()) {
if (SDValue Sad = detectSADPattern(N, DAG, Subtarget))
if (SDValue Sad = combineLoopSADPattern(N, DAG, Subtarget))
return Sad;
}
EVT VT = N->getValueType(0);

0 comments on commit 2dc08f7

Please sign in to comment.