diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -504,8 +504,9 @@ bool tryShiftAmountMod(SDNode *N); bool tryShrinkShlLogicImm(SDNode *N); bool tryVPTERNLOG(SDNode *N); - bool matchVPTERNLOG(SDNode *Root, SDNode *ParentA, SDNode *ParentBC, - SDValue A, SDValue B, SDValue C, uint8_t Imm); + bool matchVPTERNLOG(SDNode *Root, SDNode *ParentA, SDNode *ParentB, + SDNode *ParentC, SDValue A, SDValue B, SDValue C, + uint8_t Imm); bool tryVPTESTM(SDNode *Root, SDValue Setcc, SDValue Mask); bool tryMatchBitSelect(SDNode *N); @@ -4043,11 +4044,11 @@ } bool X86DAGToDAGISel::matchVPTERNLOG(SDNode *Root, SDNode *ParentA, - SDNode *ParentBC, SDValue A, SDValue B, - SDValue C, uint8_t Imm) { - assert(A.isOperandOf(ParentA)); - assert(B.isOperandOf(ParentBC)); - assert(C.isOperandOf(ParentBC)); + SDNode *ParentB, SDNode *ParentC, + SDValue A, SDValue B, SDValue C, + uint8_t Imm) { + assert(A.isOperandOf(ParentA) && B.isOperandOf(ParentB) && + C.isOperandOf(ParentC) && "Incorrect parent node"); auto tryFoldLoadOrBCast = [this](SDNode *Root, SDNode *P, SDValue &L, SDValue &Base, SDValue &Scale, @@ -4075,7 +4076,7 @@ bool FoldedLoad = false; SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4; - if (tryFoldLoadOrBCast(Root, ParentBC, C, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4)) { + if (tryFoldLoadOrBCast(Root, ParentC, C, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4)) { FoldedLoad = true; } else if (tryFoldLoadOrBCast(Root, ParentA, A, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4)) { @@ -4088,7 +4089,7 @@ if (OldImm & 0x10) Imm |= 0x02; if (OldImm & 0x08) Imm |= 0x40; if (OldImm & 0x40) Imm |= 0x08; - } else if (tryFoldLoadOrBCast(Root, ParentBC, B, Tmp0, Tmp1, Tmp2, Tmp3, + } else if (tryFoldLoadOrBCast(Root, ParentB, B, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4)) { FoldedLoad = true; std::swap(B, C); @@ -4166,7 +4167,6 @@ } // Try to match two logic ops to a VPTERNLOG. -// FIXME: Handle inverted inputs? 
// FIXME: Handle more complex patterns that use an operand more than once? bool X86DAGToDAGISel::tryVPTERNLOG(SDNode *N) { MVT NVT = N->getSimpleValueType(0); @@ -4209,12 +4209,31 @@ SDValue B = FoldableOp.getOperand(0); SDValue C = FoldableOp.getOperand(1); + SDNode *ParentA = N; + SDNode *ParentB = FoldableOp.getNode(); + SDNode *ParentC = FoldableOp.getNode(); // We can build the appropriate control immediate by performing the logic // operation we're matching using these constants for A, B, and C. - const uint8_t TernlogMagicA = 0xf0; - const uint8_t TernlogMagicB = 0xcc; - const uint8_t TernlogMagicC = 0xaa; + uint8_t TernlogMagicA = 0xf0; + uint8_t TernlogMagicB = 0xcc; + uint8_t TernlogMagicC = 0xaa; + + // Some of the inputs may be inverted; peek through them and invert the + // magic values accordingly. + // TODO: There may be a bitcast before the xor that we should peek through. + auto PeekThroughNot = [](SDValue &Op, SDNode *&Parent, uint8_t &Magic) { + if (Op.getOpcode() == ISD::XOR && Op.hasOneUse() && + ISD::isBuildVectorAllOnes(Op.getOperand(1).getNode())) { + Magic = ~Magic; + Parent = Op.getNode(); + Op = Op.getOperand(0); + } + }; + + PeekThroughNot(A, ParentA, TernlogMagicA); + PeekThroughNot(B, ParentB, TernlogMagicB); + PeekThroughNot(C, ParentC, TernlogMagicC); uint8_t Imm; switch (FoldableOp.getOpcode()) { @@ -4238,7 +4257,7 @@ case ISD::XOR: Imm ^= TernlogMagicA; break; } - return matchVPTERNLOG(N, N, FoldableOp.getNode(), A, B, C, Imm); + return matchVPTERNLOG(N, ParentA, ParentB, ParentC, A, B, C, Imm); } /// If the high bits of an 'and' operand are known zero, try setting the @@ -4575,7 +4594,7 @@ ReplaceNode(N, Ternlog.getNode()); return matchVPTERNLOG(Ternlog.getNode(), Ternlog.getNode(), Ternlog.getNode(), - A, B, C, 0xCA); + Ternlog.getNode(), A, B, C, 0xCA); } void X86DAGToDAGISel::Select(SDNode *Node) { @@ -4810,7 +4829,7 @@ case X86ISD::VPTERNLOG: { uint8_t Imm = cast<ConstantSDNode>(Node->getOperand(3))->getZExtValue(); - if (matchVPTERNLOG(Node, 
Node, Node, Node->getOperand(0), + if (matchVPTERNLOG(Node, Node, Node, Node, Node->getOperand(0), Node->getOperand(1), Node->getOperand(2), Imm)) return; break; diff --git a/llvm/test/CodeGen/X86/avx512vl-logic.ll b/llvm/test/CodeGen/X86/avx512vl-logic.ll --- a/llvm/test/CodeGen/X86/avx512vl-logic.ll +++ b/llvm/test/CodeGen/X86/avx512vl-logic.ll @@ -980,8 +980,7 @@ define <4 x i32> @ternlog_and_orn(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) { ; CHECK-LABEL: ternlog_and_orn: ; CHECK: ## %bb.0: -; CHECK-NEXT: vpternlogq $15, %xmm2, %xmm2, %xmm2 -; CHECK-NEXT: vpternlogd $224, %xmm1, %xmm2, %xmm0 +; CHECK-NEXT: vpternlogd $176, %xmm1, %xmm2, %xmm0 ; CHECK-NEXT: retq %a = xor <4 x i32> %z, <i32 -1, i32 -1, i32 -1, i32 -1> %b = or <4 x i32> %a, %y @@ -992,8 +991,7 @@ define <4 x i32> @ternlog_and_orn_2(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) { ; CHECK-LABEL: ternlog_and_orn_2: ; CHECK: ## %bb.0: -; CHECK-NEXT: vpternlogq $15, %xmm2, %xmm2, %xmm2 -; CHECK-NEXT: vpternlogd $224, %xmm2, %xmm1, %xmm0 +; CHECK-NEXT: vpternlogd $208, %xmm2, %xmm1, %xmm0 ; CHECK-NEXT: retq %a = xor <4 x i32> %z, <i32 -1, i32 -1, i32 -1, i32 -1> %b = or <4 x i32> %y, %a @@ -1001,6 +999,8 @@ ret <4 x i32> %c } +; FIXME: This should be a single vpternlog, but we accidentally match the xor -1 +; as the second binary op instead of the and. define <4 x i32> @ternlog_orn_and(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) { ; CHECK-LABEL: ternlog_orn_and: ; CHECK: ## %bb.0: @@ -1017,8 +1017,7 @@ define <4 x i32> @ternlog_orn_and_2(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) { ; CHECK-LABEL: ternlog_orn_and_2: ; CHECK: ## %bb.0: -; CHECK-NEXT: vpternlogq $15, %xmm0, %xmm0, %xmm0 -; CHECK-NEXT: vpternlogd $248, %xmm2, %xmm1, %xmm0 +; CHECK-NEXT: vpternlogd $143, %xmm2, %xmm1, %xmm0 ; CHECK-NEXT: retq %a = xor <4 x i32> %x, <i32 -1, i32 -1, i32 -1, i32 -1> %b = and <4 x i32> %y, %z