diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -3940,30 +3940,39 @@
   if (!(Subtarget->hasVLX() || NVT.is512BitVector()))
     return false;
 
-  unsigned Opc1 = N->getOpcode();
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
-  auto isLogicOp = [](unsigned Opc) {
-    return Opc == ISD::AND || Opc == ISD::OR || Opc == ISD::XOR ||
-           Opc == X86ISD::ANDNP;
+  auto getFoldableLogicOp = [](SDValue Op) {
+    // Peek through single use bitcast.
+    if (Op.getOpcode() == ISD::BITCAST && Op.hasOneUse())
+      Op = Op.getOperand(0);
+
+    if (!Op.hasOneUse())
+      return SDValue();
+
+    unsigned Opc = Op.getOpcode();
+    if (Opc == ISD::AND || Opc == ISD::OR || Opc == ISD::XOR ||
+        Opc == X86ISD::ANDNP)
+      return Op;
+
+    return SDValue();
   };
 
-  SDValue A, B, C;
-  unsigned Opc2;
-  if (isLogicOp(N1.getOpcode()) && N1.hasOneUse()) {
-    Opc2 = N1.getOpcode();
+  SDValue A, FoldableOp;
+  if ((FoldableOp = getFoldableLogicOp(N1))) {
     A = N0;
-    B = N1.getOperand(0);
-    C = N1.getOperand(1);
-  } else if (isLogicOp(N0.getOpcode()) && N0.hasOneUse()) {
-    Opc2 = N0.getOpcode();
+  } else if ((FoldableOp = getFoldableLogicOp(N0))) {
     A = N1;
-    B = N0.getOperand(0);
-    C = N0.getOperand(1);
   } else
     return false;
 
+  SDValue B = FoldableOp.getOperand(0);
+  SDValue C = FoldableOp.getOperand(1);
+
+  unsigned Opc1 = N->getOpcode();
+  unsigned Opc2 = FoldableOp.getOpcode();
+
   uint64_t Imm;
   switch (Opc1) {
   default: llvm_unreachable("Unexpected opcode!");
@@ -3996,11 +4005,127 @@
     break;
   }
 
+  // Try to fold operand L of node P into Root as a memory operand: either a
+  // foldable load, or a 32/64-bit VBROADCAST_LOAD optionally behind a
+  // single-use bitcast. On success the address is returned in
+  // Base/Scale/Index/Disp/Segment and L refers to the folded memory node.
+  // NOTE: L is rewritten to the bitcast's source even when this returns
+  // false; presumably benign as VPTERNLOG is purely bitwise (TODO confirm).
+  auto tryFoldLoadOrBCast =
+      [this](SDNode *Root, SDNode *P, SDValue &L, SDValue &Base, SDValue &Scale,
+             SDValue &Index, SDValue &Disp, SDValue &Segment) {
+        if (tryFoldLoad(Root, P, L, Base, Scale, Index, Disp, Segment))
+          return true;
+
+        // Not a load, check for broadcast which may be behind a bitcast.
+        if (L.getOpcode() == ISD::BITCAST && L.hasOneUse()) {
+          P = L.getNode();
+          L = L.getOperand(0);
+        }
+
+        if (L.getOpcode() != X86ISD::VBROADCAST_LOAD)
+          return false;
+
+        // Only 32 and 64 bit broadcasts are supported.
+        auto *MemIntr = cast<MemIntrinsicSDNode>(L);
+        unsigned Size = MemIntr->getMemoryVT().getSizeInBits();
+        if (Size != 32 && Size != 64)
+          return false;
+
+        return tryFoldBroadcast(Root, P, L, Base, Scale, Index, Disp, Segment);
+      };
+
+  // Only operand C can be taken from memory, so if A or B is the foldable
+  // operand, swap it into C and permute Imm to compute the same function.
+  bool FoldedLoad = false;
+  SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4;
+  if (tryFoldLoadOrBCast(N, FoldableOp.getNode(), C, Tmp0, Tmp1, Tmp2, Tmp3,
+                         Tmp4)) {
+    FoldedLoad = true;
+  } else if (tryFoldLoadOrBCast(N, N, A, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4)) {
+    FoldedLoad = true;
+    std::swap(A, C);
+    // Imm is a truth table indexed by (A << 2) | (B << 1) | C. Exchanging A
+    // and C swaps bits 1/4 and 3/6; bits 0, 2, 5 and 7 are unchanged.
+    uint8_t OldImm = Imm;
+    Imm = OldImm & 0xa5;
+    if (OldImm & 0x02) Imm |= 0x10;
+    if (OldImm & 0x10) Imm |= 0x02;
+    if (OldImm & 0x08) Imm |= 0x40;
+    if (OldImm & 0x40) Imm |= 0x08;
+  } else if (tryFoldLoadOrBCast(N, FoldableOp.getNode(), B, Tmp0, Tmp1, Tmp2,
+                                Tmp3, Tmp4)) {
+    FoldedLoad = true;
+    std::swap(B, C);
+    // Exchanging B and C swaps bits 1/2 and 5/6; bits 0, 3, 4 and 7 are
+    // unchanged.
+    uint8_t OldImm = Imm;
+    Imm = OldImm & 0x99;
+    if (OldImm & 0x02) Imm |= 0x04;
+    if (OldImm & 0x04) Imm |= 0x02;
+    if (OldImm & 0x20) Imm |= 0x40;
+    if (OldImm & 0x40) Imm |= 0x20;
+  }
+
   SDLoc DL(N);
-  SDValue New = CurDAG->getNode(X86ISD::VPTERNLOG, DL, NVT, A, B, C,
-                                CurDAG->getTargetConstant(Imm, DL, MVT::i8));
-  ReplaceNode(N, New.getNode());
-  SelectCode(New.getNode());
+
+  SDValue TImm = CurDAG->getTargetConstant(Imm, DL, MVT::i8);
+
+  MachineSDNode *MNode;
+  if (FoldedLoad) {
+    SDVTList VTs = CurDAG->getVTList(NVT, MVT::Other);
+
+    unsigned Opc;
+    if (C.getOpcode() == X86ISD::VBROADCAST_LOAD) {
+      auto *MemIntr = cast<MemIntrinsicSDNode>(C);
+      unsigned EltSize = MemIntr->getMemoryVT().getSizeInBits();
+      assert((EltSize == 32 || EltSize == 64) && "Unexpected broadcast size!");
+
+      bool UseD = EltSize == 32;
+      if (NVT.is128BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ128rmbi : X86::VPTERNLOGQZ128rmbi;
+      else if (NVT.is256BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ256rmbi : X86::VPTERNLOGQZ256rmbi;
+      else if (NVT.is512BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZrmbi : X86::VPTERNLOGQZrmbi;
+      else
+        llvm_unreachable("Unexpected vector size!");
+    } else {
+      bool UseD = NVT.getVectorElementType() == MVT::i32;
+      if (NVT.is128BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ128rmi : X86::VPTERNLOGQZ128rmi;
+      else if (NVT.is256BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ256rmi : X86::VPTERNLOGQZ256rmi;
+      else if (NVT.is512BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZrmi : X86::VPTERNLOGQZrmi;
+      else
+        llvm_unreachable("Unexpected vector size!");
+    }
+
+    SDValue Ops[] = {A, B, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4, TImm, C.getOperand(0)};
+    MNode = CurDAG->getMachineNode(Opc, DL, VTs, Ops);
+
+    // Update the chain.
+    ReplaceUses(C.getValue(1), SDValue(MNode, 1));
+    // Record the mem-refs
+    CurDAG->setNodeMemRefs(MNode, {cast<MemSDNode>(C)->getMemOperand()});
+  } else {
+    bool UseD = NVT.getVectorElementType() == MVT::i32;
+    unsigned Opc;
+    if (NVT.is128BitVector())
+      Opc = UseD ? X86::VPTERNLOGDZ128rri : X86::VPTERNLOGQZ128rri;
+    else if (NVT.is256BitVector())
+      Opc = UseD ? X86::VPTERNLOGDZ256rri : X86::VPTERNLOGQZ256rri;
+    else if (NVT.is512BitVector())
+      Opc = UseD ? X86::VPTERNLOGDZrri : X86::VPTERNLOGQZrri;
+    else
+      llvm_unreachable("Unexpected vector size!");
+
+    MNode = CurDAG->getMachineNode(Opc, DL, NVT, {A, B, C, TImm});
+  }
+
+  ReplaceUses(SDValue(N, 0), SDValue(MNode, 0));
+  CurDAG->RemoveDeadNode(N);
   return true;
 }
 
diff --git a/llvm/test/CodeGen/X86/avx512-logic.ll b/llvm/test/CodeGen/X86/avx512-logic.ll
--- a/llvm/test/CodeGen/X86/avx512-logic.ll
+++ b/llvm/test/CodeGen/X86/avx512-logic.ll
@@ -887,34 +887,20 @@
 }
 
 define <16 x i32> @ternlog_or_and_mask(<16 x i32> %x, <16 x i32> %y) {
-; KNL-LABEL: ternlog_or_and_mask:
-; KNL:       ## %bb.0:
-; KNL-NEXT:    vpandq {{.*}}(%rip), %zmm0, %zmm0
-; KNL-NEXT:    vpord %zmm1, %zmm0, %zmm0
-; KNL-NEXT:    retq
-;
-; SKX-LABEL: ternlog_or_and_mask:
-; SKX:       ## %bb.0:
-; SKX-NEXT:    vandps {{.*}}(%rip), %zmm0, %zmm0
-; SKX-NEXT:    vorps %zmm1, %zmm0, %zmm0
-; SKX-NEXT:    retq
+; ALL-LABEL: ternlog_or_and_mask:
+; ALL:       ## %bb.0:
+; ALL-NEXT:    vpternlogd $236, {{.*}}(%rip), %zmm1, %zmm0
+; ALL-NEXT:    retq
   %a = and <16 x i32> %x, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
   %b = or <16 x i32> %a, %y
   ret <16 x i32> %b
 }
 
 define <8 x i64> @ternlog_xor_and_mask(<8 x i64> %x, <8 x i64> %y) {
-; KNL-LABEL: ternlog_xor_and_mask:
-; KNL:       ## %bb.0:
-; KNL-NEXT:    vpandd {{.*}}(%rip), %zmm0, %zmm0
-; KNL-NEXT:    vpxorq %zmm1, %zmm0, %zmm0
-; KNL-NEXT:    retq
-;
-; SKX-LABEL: ternlog_xor_and_mask:
-; SKX:       ## %bb.0:
-; SKX-NEXT:    vandps {{.*}}(%rip), %zmm0, %zmm0
-; SKX-NEXT:    vxorps %zmm1, %zmm0, %zmm0
-; SKX-NEXT:    retq
+; ALL-LABEL: ternlog_xor_and_mask:
+; ALL:       ## %bb.0:
+; ALL-NEXT:    vpternlogq $108, {{.*}}(%rip), %zmm1, %zmm0
+; ALL-NEXT:    retq
   %a = and <8 x i64> %x, <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
   %b = xor <8 x i64> %a, %y
   ret <8 x i64> %b
diff --git a/llvm/test/CodeGen/X86/avx512vl-logic.ll b/llvm/test/CodeGen/X86/avx512vl-logic.ll
--- a/llvm/test/CodeGen/X86/avx512vl-logic.ll
+++ b/llvm/test/CodeGen/X86/avx512vl-logic.ll
@@ -991,8 +991,7 @@
 define <4 x i32> @ternlog_or_and_mask(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: ternlog_or_and_mask:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %xmm0, %xmm0
-; CHECK-NEXT:    vorps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vpternlogd $236, {{.*}}(%rip), %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %a = and <4 x i32> %x, <i32 255, i32 255, i32 255, i32 255>
   %b = or <4 x i32> %a, %y
@@ -1002,8 +1001,7 @@
 define <8 x i32> @ternlog_or_and_mask_ymm(<8 x i32> %x, <8 x i32> %y) {
 ; CHECK-LABEL: ternlog_or_and_mask_ymm:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %ymm0, %ymm0
-; CHECK-NEXT:    vorps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vpternlogd $236, {{.*}}(%rip), %ymm1, %ymm0
 ; CHECK-NEXT:    retq
   %a = and <8 x i32> %x, <i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216, i32 -16777216>
   %b = or <8 x i32> %a, %y
@@ -1013,8 +1011,7 @@
 define <2 x i64> @ternlog_xor_and_mask(<2 x i64> %x, <2 x i64> %y) {
 ; CHECK-LABEL: ternlog_xor_and_mask:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %xmm0, %xmm0
-; CHECK-NEXT:    vxorps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vpternlogq $108, {{.*}}(%rip), %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %a = and <2 x i64> %x, <i64 1099511627775, i64 1099511627775>
   %b = xor <2 x i64> %a, %y
@@ -1024,8 +1021,7 @@
 define <4 x i64> @ternlog_xor_and_mask_ymm(<4 x i64> %x, <4 x i64> %y) {
 ; CHECK-LABEL: ternlog_xor_and_mask_ymm:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %ymm0, %ymm0
-; CHECK-NEXT:    vxorps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vpternlogq $108, {{.*}}(%rip), %ymm1, %ymm0
 ; CHECK-NEXT:    retq
   %a = and <4 x i64> %x, <i64 72057594037927935, i64 72057594037927935, i64 72057594037927935, i64 72057594037927935>
   %b = xor <4 x i64> %a, %y