diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -45007,6 +45007,23 @@
   return SDValue();
 }
 
+static SDValue GetNot(SDValue V, SelectionDAG &DAG, MVT VT) {
+  // Basic X = NOT(Y) detection. The returned Y may not have V's type.
+  if (SDValue Not = IsNOT(V, DAG))
+    return Not;
+  // Fold BROADCAST(NOT(Y)) -> BROADCAST(Y).
+  if (V.getOpcode() == X86ISD::VBROADCAST) {
+    SDValue Src = V.getOperand(0);
+    EVT SrcVT = Src.getValueType();
+    if (!SrcVT.isVector())
+      return SDValue();
+    if (SDValue Not = IsNOT(Src, DAG))
+      return DAG.getNode(X86ISD::VBROADCAST, SDLoc(V), VT,
+                         DAG.getBitcast(SrcVT, Not));
+  }
+  return SDValue();
+}
+
 /// Try to fold: (and (xor X, -1), Y) -> (andnp X, Y).
 static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) {
   assert(N->getOpcode() == ISD::AND);
@@ -45019,27 +45036,10 @@
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
-  auto GetNot = [&VT, &DAG](SDValue V) {
-    // Basic X = NOT(Y) detection.
-    if (SDValue Not = IsNOT(V, DAG))
-      return Not;
-    // Fold BROADCAST(NOT(Y)) -> BROADCAST(Y).
-    if (V.getOpcode() == X86ISD::VBROADCAST) {
-      SDValue Src = V.getOperand(0);
-      EVT SrcVT = Src.getValueType();
-      if (!SrcVT.isVector())
-        return SDValue();
-      if (SDValue Not = IsNOT(Src, DAG))
-        return DAG.getNode(X86ISD::VBROADCAST, SDLoc(V), VT,
-                           DAG.getBitcast(SrcVT, Not));
-    }
-    return SDValue();
-  };
-
-  if (SDValue Not = GetNot(N0)) {
+  if (SDValue Not = GetNot(N0, DAG, VT)) {
     X = Not;
     Y = N1;
-  } else if (SDValue Not = GetNot(N1)) {
+  } else if (SDValue Not = GetNot(N1, DAG, VT)) {
     X = Not;
     Y = N0;
   } else
@@ -45050,6 +45050,52 @@
   return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y);
 }
 
+/// Try to fold: (and (or (xor X, -1), Y), Z) -> PTERNLOG Z, Y, X, 0xD0
+/// 0xD0 is the VPTERNLOG truth table of Z & (Y | ~X) for operands Z, Y, X.
+static SDValue
+combineANDORXORWithAllOnesIntoPTERNLOG(SDNode *N, SelectionDAG &DAG,
+                                       const X86Subtarget &Subtarget) {
+  assert(N->getOpcode() == ISD::AND);
+
+  MVT VT = N->getSimpleValueType(0);
+  if (!VT.is128BitVector() && !VT.is256BitVector() && !VT.is512BitVector())
+    return SDValue();
+  // 512-bit VPTERNLOG needs AVX512F; 128/256-bit forms also need AVX512VL.
+  if (!((Subtarget.hasAVX512() && VT.is512BitVector()) || Subtarget.hasVLX()))
+    return SDValue();
+
+  // Match V = (or (not A), B), in either operand order: NotOp = A, Other = B.
+  auto GetORNot = [&VT, &DAG](SDValue V, SDValue &NotOp, SDValue &Other) {
+    if (V.getOpcode() != ISD::OR)
+      return false;
+    if (SDValue Not = GetNot(V.getOperand(0), DAG, VT)) {
+      NotOp = Not;
+      Other = V.getOperand(1);
+    } else if (SDValue Not = GetNot(V.getOperand(1), DAG, VT)) {
+      NotOp = Not;
+      Other = V.getOperand(0);
+    } else
+      return false;
+
+    return true;
+  };
+
+  SDValue X, Y, Z;
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  if (GetORNot(N0, X, Y))
+    Z = N1;
+  else if (GetORNot(N1, X, Y))
+    Z = N0;
+  else
+    return SDValue();
+
+  // GetNot may look through bitcasts, so X need not have type VT.
+  X = DAG.getBitcast(VT, X);
+  SDValue Imm = DAG.getTargetConstant(0xD0, SDLoc(N), MVT::i8);
+  return DAG.getNode(X86ISD::VPTERNLOG, SDLoc(N), VT, Z, Y, X, Imm);
+}
+
 // Try to widen AND, OR and XOR nodes to VT in order to remove casts around
 // logical operations, like in the example below.
// or (and (truncate x, truncate y)), @@ -45505,6 +45551,9 @@ if (SDValue R = combineANDXORWithAllOnesIntoANDNP(N, DAG)) return R; + if (SDValue R = combineANDORXORWithAllOnesIntoPTERNLOG(N, DAG, Subtarget)) + return R; + if (SDValue ShiftRight = combineAndMaskToShift(N, DAG, Subtarget)) return ShiftRight; diff --git a/llvm/test/CodeGen/X86/avx512vl-logic.ll b/llvm/test/CodeGen/X86/avx512vl-logic.ll --- a/llvm/test/CodeGen/X86/avx512vl-logic.ll +++ b/llvm/test/CodeGen/X86/avx512vl-logic.ll @@ -977,6 +977,17 @@ ret <4 x i32> %c } +define <4 x i32> @ternlog_and_orn(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) { +; CHECK-LABEL: ternlog_and_orn: +; CHECK: ## %bb.0: +; CHECK-NEXT: vpternlogd $208, %xmm2, %xmm1, %xmm0 +; CHECK-NEXT: retq + %a = xor <4 x i32> %z, + %b = or <4 x i32> %a, %y + %c = and <4 x i32> %b, %x + ret <4 x i32> %c +} + define <4 x i32> @ternlog_xor_andn(<4 x i32> %x, <4 x i32> %y, <4 x i32> %z) { ; CHECK-LABEL: ternlog_xor_andn: ; CHECK: ## %bb.0: