diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -504,8 +504,9 @@
     bool tryShiftAmountMod(SDNode *N);
     bool tryShrinkShlLogicImm(SDNode *N);
     bool tryVPTERNLOG(SDNode *N);
-    bool matchVPTERNLOG(SDNode *Root, SDNode *ParentA, SDNode *ParentBC,
-                        SDValue A, SDValue B, SDValue C, uint8_t Imm);
+    bool matchVPTERNLOG(SDNode *Root, SDNode *ParentA, SDNode *ParentB,
+                        SDNode *ParentC, SDValue A, SDValue B, SDValue C,
+                        uint8_t Imm);
     bool tryVPTESTM(SDNode *Root, SDValue Setcc, SDValue Mask);
     bool tryMatchBitSelect(SDNode *N);
 
@@ -4043,11 +4044,12 @@
 }
 
 bool X86DAGToDAGISel::matchVPTERNLOG(SDNode *Root, SDNode *ParentA,
-                                     SDNode *ParentBC, SDValue A, SDValue B,
-                                     SDValue C, uint8_t Imm) {
+                                     SDNode *ParentB, SDNode *ParentC,
+                                     SDValue A, SDValue B, SDValue C,
+                                     uint8_t Imm) {
   assert(A.isOperandOf(ParentA));
-  assert(B.isOperandOf(ParentBC));
-  assert(C.isOperandOf(ParentBC));
+  assert(B.isOperandOf(ParentB));
+  assert(C.isOperandOf(ParentC));
 
   auto tryFoldLoadOrBCast = [this](SDNode *Root, SDNode *P, SDValue &L,
                                    SDValue &Base, SDValue &Scale,
@@ -4075,7 +4077,7 @@
 
   bool FoldedLoad = false;
   SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4;
-  if (tryFoldLoadOrBCast(Root, ParentBC, C, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4)) {
+  if (tryFoldLoadOrBCast(Root, ParentC, C, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4)) {
     FoldedLoad = true;
   } else if (tryFoldLoadOrBCast(Root, ParentA, A, Tmp0, Tmp1, Tmp2, Tmp3,
                                 Tmp4)) {
@@ -4088,7 +4090,7 @@
     if (OldImm & 0x10) Imm |= 0x02;
     if (OldImm & 0x08) Imm |= 0x40;
     if (OldImm & 0x40) Imm |= 0x08;
-  } else if (tryFoldLoadOrBCast(Root, ParentBC, B, Tmp0, Tmp1, Tmp2, Tmp3,
+  } else if (tryFoldLoadOrBCast(Root, ParentB, B, Tmp0, Tmp1, Tmp2, Tmp3,
                                 Tmp4)) {
     FoldedLoad = true;
     std::swap(B, C);
@@ -4166,7 +4168,6 @@
 }
 
 // Try to match two logic ops to a VPTERNLOG.
-// FIXME: Handle inverted inputs?
 // FIXME: Handle more complex patterns that use an operand more than once?
 bool X86DAGToDAGISel::tryVPTERNLOG(SDNode *N) {
   MVT NVT = N->getSimpleValueType(0);
@@ -4209,12 +4210,41 @@
 
   SDValue B = FoldableOp.getOperand(0);
   SDValue C = FoldableOp.getOperand(1);
+  SDNode *ParentA = N;
+  SDNode *ParentB = FoldableOp.getNode();
+  SDNode *ParentC = FoldableOp.getNode();
 
   // We can build the appropriate control immediate by performing the logic
   // operation we're matching using these constants for A, B, and C.
-  const uint8_t TernlogMagicA = 0xf0;
-  const uint8_t TernlogMagicB = 0xcc;
-  const uint8_t TernlogMagicC = 0xaa;
+  uint8_t TernlogMagicA = 0xf0;
+  uint8_t TernlogMagicB = 0xcc;
+  uint8_t TernlogMagicC = 0xaa;
+
+  auto IsNot = [](SDValue Op) {
+    return Op.getOpcode() == ISD::XOR && Op.hasOneUse() &&
+           ISD::isBuildVectorAllOnes(Op.getOperand(1).getNode());
+  };
+
+  // Some of the inputs may be inverted, peek through them and invert the
+  // magic values accordingly.
+  // TODO: There may be a bitcast before the xor that we should peek through.
+  if (IsNot(A)) {
+    TernlogMagicA = ~TernlogMagicA;
+    ParentA = A.getNode();
+    A = A.getOperand(0);
+  }
+
+  if (IsNot(B)) {
+    TernlogMagicB = ~TernlogMagicB;
+    ParentB = B.getNode();
+    B = B.getOperand(0);
+  }
+
+  if (IsNot(C)) {
+    TernlogMagicC = ~TernlogMagicC;
+    ParentC = C.getNode();
+    C = C.getOperand(0);
+  }
 
   uint8_t Imm;
   switch (FoldableOp.getOpcode()) {
@@ -4238,7 +4268,7 @@
   case ISD::XOR: Imm ^= TernlogMagicA; break;
   }
 
-  return matchVPTERNLOG(N, N, FoldableOp.getNode(), A, B, C, Imm);
+  return matchVPTERNLOG(N, ParentA, ParentB, ParentC, A, B, C, Imm);
 }
 
 /// If the high bits of an 'and' operand are known zero, try setting the
@@ -4575,7 +4605,7 @@
   ReplaceNode(N, Ternlog.getNode());
 
   return matchVPTERNLOG(Ternlog.getNode(), Ternlog.getNode(), Ternlog.getNode(),
-                        A, B, C, 0xCA);
+                        Ternlog.getNode(), A, B, C, 0xCA);
 }
 
 void X86DAGToDAGISel::Select(SDNode *Node) {
@@ -4810,7 +4840,7 @@
 
   case X86ISD::VPTERNLOG: {
     uint8_t Imm = cast<ConstantSDNode>(Node->getOperand(3))->getZExtValue();
-    if (matchVPTERNLOG(Node, Node, Node, Node->getOperand(0),
+    if (matchVPTERNLOG(Node, Node, Node, Node, Node->getOperand(0),
                        Node->getOperand(1), Node->getOperand(2), Imm))
       return;
     break;
diff --git a/llvm/test/CodeGen/X86/avx512vl-logic.ll b/llvm/test/CodeGen/X86/avx512vl-logic.ll
--- a/llvm/test/CodeGen/X86/avx512vl-logic.ll
+++ b/llvm/test/CodeGen/X86/avx512vl-logic.ll
@@ -977,6 +977,17 @@
   ret <4 x i32> %c
 }
 
+define <4 x i32> @ternlog_and_orn(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) {
+; CHECK-LABEL: ternlog_and_orn:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    vpternlogd $176, %xmm1, %xmm2, %xmm0
+; CHECK-NEXT:    retq
+  %a = xor <4 x i32> %z, <i32 -1, i32 -1, i32 -1, i32 -1>
+  %b = or <4 x i32> %a, %y
+  %c = and <4 x i32> %b, %x
+  ret <4 x i32> %c
+}
+
 define <4 x i32> @ternlog_xor_andn(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) {
 ; CHECK-LABEL: ternlog_xor_andn:
 ; CHECK:       ## %bb.0: