diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -562,6 +562,7 @@
     SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
     SDValue MatchLoadCombine(SDNode *N);
     SDValue MatchStoreCombine(StoreSDNode *N);
+    SDValue MatchABS(SDNode *N);
     SDValue ReduceLoadWidth(SDNode *N);
     SDValue ReduceLoadOpStoreWidth(SDNode *N);
     SDValue splitMergedValStore(StoreSDNode *ST);
@@ -3228,18 +3229,9 @@
   }
 
   // fold Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
-  if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
-    if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
-      SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
-      SDValue S0 = N1.getOperand(0);
-      if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0)) {
-        unsigned OpSizeInBits = VT.getScalarSizeInBits();
-        if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
-          if (C->getAPIntValue() == (OpSizeInBits - 1))
-            return DAG.getNode(ISD::ABS, SDLoc(N), VT, S0);
-      }
-    }
-  }
+  if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
+    if (SDValue AbsSrc = MatchABS(N))
+      return DAG.getNode(ISD::ABS, DL, VT, AbsSrc);
 
   // If the relocation model supports it, consider symbol offsets.
   if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(N0))
@@ -3532,6 +3524,11 @@
                                      getShiftAmountTy(N0.getValueType()))));
   }
 
+  // fold mul(abs(x),abs(x)) -> mul(x,x)
+  if (N0 == N1)
+    if (SDValue AbsSrc = MatchABS(N0.getNode()))
+      return DAG.getNode(ISD::MUL, SDLoc(N), VT, AbsSrc, AbsSrc);
+
   // Try to transform multiply-by-(power-of-2 +/- 1) into shift and add/sub.
   // mul x, (2^N + 1) --> add (shl x, N), x
   // mul x, (2^N - 1) --> sub (shl x, N), x
@@ -5057,6 +5054,50 @@
   return false;
 }
 
+// Attempts to match an integer absolute pattern and returns the source operand
+// on success.
+SDValue DAGCombiner::MatchABS(SDNode *N) {
+  if (N->getOpcode() == ISD::ABS)
+    return N->getOperand(0);
+
+  // Y = sra (X, size(X)-1); sub (xor (X, Y), Y) -> (abs X)
+  if (N->getOpcode() == ISD::SUB) {
+    SDValue N0 = N->getOperand(0);
+    SDValue N1 = N->getOperand(1);
+    if (N0.getOpcode() == ISD::XOR && N1.getOpcode() == ISD::SRA) {
+      SDValue X0 = N0.getOperand(0), X1 = N0.getOperand(1);
+      SDValue S0 = N1.getOperand(0);
+      if ((X0 == S0 && X1 == N1) || (X0 == N1 && X1 == S0)) {
+        unsigned OpSizeInBits = N->getValueType(0).getScalarSizeInBits();
+        if (ConstantSDNode *C = isConstOrConstSplat(N1.getOperand(1)))
+          if (C->getAPIntValue() == (OpSizeInBits - 1))
+            return S0;
+      }
+    }
+  }
+
+  // Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
+  if (N->getOpcode() == ISD::XOR) {
+    SDValue N0 = N->getOperand(0);
+    SDValue N1 = N->getOperand(1);
+    unsigned N0Opcode = N0.getOpcode();
+    SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
+    SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
+    if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
+      SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
+      SDValue S0 = S.getOperand(0);
+      if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0)) {
+        unsigned OpSizeInBits = N->getValueType(0).getScalarSizeInBits();
+        if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
+          if (C->getAPIntValue() == (OpSizeInBits - 1))
+            return S0;
+      }
+    }
+  }
+
+  return SDValue();
+}
+
 // Unfold
 // x & (-1 'logical shift' y)
 // To
@@ -7355,20 +7396,9 @@
   }
 
   // fold Y = sra (X, size(X)-1); xor (add (X, Y), Y) -> (abs X)
-  if (TLI.isOperationLegalOrCustom(ISD::ABS, VT)) {
-    SDValue A = N0Opcode == ISD::ADD ? N0 : N1;
-    SDValue S = N0Opcode == ISD::SRA ? N0 : N1;
-    if (A.getOpcode() == ISD::ADD && S.getOpcode() == ISD::SRA) {
-      SDValue A0 = A.getOperand(0), A1 = A.getOperand(1);
-      SDValue S0 = S.getOperand(0);
-      if ((A0 == S && A1 == S0) || (A1 == S && A0 == S0)) {
-        unsigned OpSizeInBits = VT.getScalarSizeInBits();
-        if (ConstantSDNode *C = isConstOrConstSplat(S.getOperand(1)))
-          if (C->getAPIntValue() == (OpSizeInBits - 1))
-            return DAG.getNode(ISD::ABS, DL, VT, S0);
-      }
-    }
-  }
+  if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
+    if (SDValue AbsSrc = MatchABS(N))
+      return DAG.getNode(ISD::ABS, DL, VT, AbsSrc);
 
   // fold (xor x, x) -> 0
   if (N0 == N1)
diff --git a/llvm/test/CodeGen/X86/combine-mul.ll b/llvm/test/CodeGen/X86/combine-mul.ll
--- a/llvm/test/CodeGen/X86/combine-mul.ll
+++ b/llvm/test/CodeGen/X86/combine-mul.ll
@@ -283,27 +283,19 @@
   ret <4 x i32> %2
 }
 
-; TODO fold mul(abs(x),abs(x)) -> mul(x,x)
+; fold mul(abs(x),abs(x)) -> mul(x,x)
 define i31 @combine_mul_abs_i31(i31 %0) {
 ; SSE-LABEL: combine_mul_abs_i31:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    addl %edi, %edi
-; SSE-NEXT:    sarl %edi
 ; SSE-NEXT:    movl %edi, %eax
-; SSE-NEXT:    negl %eax
-; SSE-NEXT:    cmovll %edi, %eax
-; SSE-NEXT:    imull %eax, %eax
+; SSE-NEXT:    imull %edi, %eax
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_mul_abs_i31:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    addl %edi, %edi
-; AVX-NEXT:    sarl %edi
 ; AVX-NEXT:    movl %edi, %eax
-; AVX-NEXT:    negl %eax
-; AVX-NEXT:    cmovll %edi, %eax
-; AVX-NEXT:    imull %eax, %eax
+; AVX-NEXT:    imull %edi, %eax
 ; AVX-NEXT:    retq
   %c = icmp slt i31 %0, 0
   %s = sub nsw i31 0, %0
   %a = select i1 %c, i31 %s, i31 %0
@@ -316,17 +308,13 @@
 ; SSE-LABEL: combine_mul_abs_i32:
 ; SSE:       # %bb.0:
 ; SSE-NEXT:    movl %edi, %eax
-; SSE-NEXT:    negl %eax
-; SSE-NEXT:    cmovll %edi, %eax
-; SSE-NEXT:    imull %eax, %eax
+; SSE-NEXT:    imull %edi, %eax
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_mul_abs_i32:
 ; AVX:       # %bb.0:
 ; AVX-NEXT:    movl %edi, %eax
-; AVX-NEXT:    negl %eax
-; AVX-NEXT:    cmovll %edi, %eax
-; AVX-NEXT:    imull %eax, %eax
+; AVX-NEXT:    imull %edi, %eax
 ; AVX-NEXT:    retq
   %c = icmp slt i32 %0, 0
   %s = sub nsw i32 0, %0
@@ -338,13 +326,11 @@
 define <4 x i32> @combine_mul_abs_v4i32(<4 x i32> %0) {
 ; SSE-LABEL: combine_mul_abs_v4i32:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    pabsd %xmm0, %xmm0
 ; SSE-NEXT:    pmulld %xmm0, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_mul_abs_v4i32:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vpabsd %xmm0, %xmm0
 ; AVX-NEXT:    vpmulld %xmm0, %xmm0, %xmm0
 ; AVX-NEXT:    retq
   %c = icmp slt <4 x i32> %0, zeroinitializer