diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -3940,30 +3940,39 @@
   if (!(Subtarget->hasVLX() || NVT.is512BitVector()))
     return false;
 
-  unsigned Opc1 = N->getOpcode();
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
-  auto isLogicOp = [](unsigned Opc) {
-    return Opc == ISD::AND || Opc == ISD::OR || Opc == ISD::XOR ||
-           Opc == X86ISD::ANDNP;
+  auto getFoldableLogicOp = [](SDValue Op) {
+    // Peek through single use bitcast.
+    if (Op.getOpcode() == ISD::BITCAST && Op.hasOneUse())
+      Op = Op.getOperand(0);
+
+    if (!Op.hasOneUse())
+      return SDValue();
+
+    unsigned Opc = Op.getOpcode();
+    if (Opc == ISD::AND || Opc == ISD::OR || Opc == ISD::XOR ||
+        Opc == X86ISD::ANDNP)
+      return Op;
+
+    return SDValue();
   };
 
-  SDValue A, B, C;
-  unsigned Opc2;
-  if (isLogicOp(N1.getOpcode()) && N1.hasOneUse()) {
-    Opc2 = N1.getOpcode();
+  SDValue A, FoldableOp;
+  if ((FoldableOp = getFoldableLogicOp(N1))) {
     A = N0;
-    B = N1.getOperand(0);
-    C = N1.getOperand(1);
-  } else if (isLogicOp(N0.getOpcode()) && N0.hasOneUse()) {
-    Opc2 = N0.getOpcode();
+  } else if ((FoldableOp = getFoldableLogicOp(N0))) {
     A = N1;
-    B = N0.getOperand(0);
-    C = N0.getOperand(1);
   } else
     return false;
 
+  SDValue B = FoldableOp.getOperand(0);
+  SDValue C = FoldableOp.getOperand(1);
+
+  unsigned Opc1 = N->getOpcode();
+  unsigned Opc2 = FoldableOp.getOpcode();
+
   uint64_t Imm;
   switch (Opc1) {
   default: llvm_unreachable("Unexpected opcode!");
@@ -3996,11 +4005,117 @@
     break;
   }
 
+  auto tryFoldLoadOrBCast =
+      [this](SDNode *Root, SDNode *P, SDValue &L, SDValue &Base, SDValue &Scale,
+             SDValue &Index, SDValue &Disp, SDValue &Segment) {
+        if (tryFoldLoad(Root, P, L, Base, Scale, Index, Disp, Segment))
+          return true;
+
+        // Not a load, check for broadcast which may be behind a bitcast.
+        if (L.getOpcode() == ISD::BITCAST && L.hasOneUse()) {
+          P = L.getNode();
+          L = L.getOperand(0);
+        }
+
+        if (L.getOpcode() != X86ISD::VBROADCAST_LOAD)
+          return false;
+
+        // Only 32 and 64 bit broadcasts are supported.
+        auto *MemIntr = cast<MemIntrinsicSDNode>(L);
+        unsigned Size = MemIntr->getMemoryVT().getSizeInBits();
+        if (Size != 32 && Size != 64)
+          return false;
+
+        return tryFoldBroadcast(Root, P, L, Base, Scale, Index, Disp, Segment);
+      };
+
+  bool FoldedLoad = false;
+  SDValue Tmp0, Tmp1, Tmp2, Tmp3, Tmp4;
+  if (tryFoldLoadOrBCast(N, FoldableOp.getNode(), C, Tmp0, Tmp1, Tmp2, Tmp3,
+                         Tmp4)) {
+    FoldedLoad = true;
+  } else if (tryFoldLoadOrBCast(N, N, A, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4)) {
+    FoldedLoad = true;
+    std::swap(A, C);
+    // Swap bits 1/4 and 3/6.
+    uint8_t OldImm = Imm;
+    Imm = OldImm & 0xa5;
+    if (OldImm & 0x02) Imm |= 0x10;
+    if (OldImm & 0x10) Imm |= 0x02;
+    if (OldImm & 0x08) Imm |= 0x40;
+    if (OldImm & 0x40) Imm |= 0x08;
+  } else if (tryFoldLoadOrBCast(N, FoldableOp.getNode(), B, Tmp0, Tmp1, Tmp2,
+                                Tmp3, Tmp4)) {
+    FoldedLoad = true;
+    std::swap(B, C);
+    // Swap bits 1/2 and 5/6.
+    uint8_t OldImm = Imm;
+    Imm = OldImm & 0x99;
+    if (OldImm & 0x02) Imm |= 0x04;
+    if (OldImm & 0x04) Imm |= 0x02;
+    if (OldImm & 0x20) Imm |= 0x40;
+    if (OldImm & 0x40) Imm |= 0x20;
+  }
+
   SDLoc DL(N);
-  SDValue New = CurDAG->getNode(X86ISD::VPTERNLOG, DL, NVT, A, B, C,
-                                CurDAG->getTargetConstant(Imm, DL, MVT::i8));
-  ReplaceNode(N, New.getNode());
-  SelectCode(New.getNode());
+
+  SDValue TImm = CurDAG->getTargetConstant(Imm, DL, MVT::i8);
+
+  MachineSDNode *MNode;
+  if (FoldedLoad) {
+    SDVTList VTs = CurDAG->getVTList(NVT, MVT::Other);
+
+    unsigned Opc;
+    if (C.getOpcode() == X86ISD::VBROADCAST_LOAD) {
+      auto *MemIntr = cast<MemIntrinsicSDNode>(C);
+      unsigned EltSize = MemIntr->getMemoryVT().getSizeInBits();
+      assert((EltSize == 32 || EltSize == 64) && "Unexpected broadcast size!");
+
+      bool UseD = EltSize == 32;
+      if (NVT.is128BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ128rmbi : X86::VPTERNLOGQZ128rmbi;
+      else if (NVT.is256BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ256rmbi : X86::VPTERNLOGQZ256rmbi;
+      else if (NVT.is512BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZrmbi : X86::VPTERNLOGQZrmbi;
+      else
+        llvm_unreachable("Unexpected vector size!");
+    } else {
+      bool UseD = NVT.getVectorElementType() == MVT::i32;
+      if (NVT.is128BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ128rmi : X86::VPTERNLOGQZ128rmi;
+      else if (NVT.is256BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZ256rmi : X86::VPTERNLOGQZ256rmi;
+      else if (NVT.is512BitVector())
+        Opc = UseD ? X86::VPTERNLOGDZrmi : X86::VPTERNLOGQZrmi;
+      else
+        llvm_unreachable("Unexpected vector size!");
+    }
+
+    SDValue Ops[] = {A, B, Tmp0, Tmp1, Tmp2, Tmp3, Tmp4, TImm, C.getOperand(0)};
+    MNode = CurDAG->getMachineNode(Opc, DL, VTs, Ops);
+
+    // Update the chain.
+    ReplaceUses(C.getValue(1), SDValue(MNode, 1));
+    // Record the mem-refs
+    CurDAG->setNodeMemRefs(MNode, {cast<MemSDNode>(C)->getMemOperand()});
+  } else {
+    bool UseD = NVT.getVectorElementType() == MVT::i32;
+    unsigned Opc;
+    if (NVT.is128BitVector())
+      Opc = UseD ? X86::VPTERNLOGDZ128rri : X86::VPTERNLOGQZ128rri;
+    else if (NVT.is256BitVector())
+      Opc = UseD ? X86::VPTERNLOGDZ256rri : X86::VPTERNLOGQZ256rri;
+    else if (NVT.is512BitVector())
+      Opc = UseD ? X86::VPTERNLOGDZrri : X86::VPTERNLOGQZrri;
+    else
+      llvm_unreachable("Unexpected vector size!");
+
+    MNode = CurDAG->getMachineNode(Opc, DL, NVT, {A, B, C, TImm});
+  }
+
+  ReplaceUses(SDValue(N, 0), SDValue(MNode, 0));
+  CurDAG->RemoveDeadNode(N);
   return true;
 }
diff --git a/llvm/test/CodeGen/X86/avx512-logic.ll b/llvm/test/CodeGen/X86/avx512-logic.ll
--- a/llvm/test/CodeGen/X86/avx512-logic.ll
+++ b/llvm/test/CodeGen/X86/avx512-logic.ll
@@ -887,34 +887,20 @@
 }
 
 define <16 x i32> @ternlog_or_and_mask(<16 x i32> %x, <16 x i32> %y) {
-; KNL-LABEL: ternlog_or_and_mask:
-; KNL:       ## %bb.0:
-; KNL-NEXT:    vpandq {{.*}}(%rip), %zmm0, %zmm0
-; KNL-NEXT:    vpord %zmm1, %zmm0, %zmm0
-; KNL-NEXT:    retq
-;
-; SKX-LABEL: ternlog_or_and_mask:
-; SKX:       ## %bb.0:
-; SKX-NEXT:    vandps {{.*}}(%rip), %zmm0, %zmm0
-; SKX-NEXT:    vorps %zmm1, %zmm0, %zmm0
-; SKX-NEXT:    retq
+; ALL-LABEL: ternlog_or_and_mask:
+; ALL:       ## %bb.0:
+; ALL-NEXT:    vpternlogd $236, {{.*}}(%rip), %zmm1, %zmm0
+; ALL-NEXT:    retq
   %a = and <16 x i32> %x, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
   %b = or <16 x i32> %a, %y
   ret <16 x i32> %b
 }
 
 define <8 x i64> @ternlog_xor_and_mask(<8 x i64> %x, <8 x i64> %y) {
-; KNL-LABEL: ternlog_xor_and_mask:
-; KNL:       ## %bb.0:
-; KNL-NEXT:    vpandd {{.*}}(%rip), %zmm0, %zmm0
-; KNL-NEXT:    vpxorq %zmm1, %zmm0, %zmm0
-; KNL-NEXT:    retq
-;
-; SKX-LABEL: ternlog_xor_and_mask:
-; SKX:       ## %bb.0:
-; SKX-NEXT:    vandps {{.*}}(%rip), %zmm0, %zmm0
-; SKX-NEXT:    vxorps %zmm1, %zmm0, %zmm0
-; SKX-NEXT:    retq
+; ALL-LABEL: ternlog_xor_and_mask:
+; ALL:       ## %bb.0:
+; ALL-NEXT:    vpternlogq $108, {{.*}}(%rip), %zmm1, %zmm0
+; ALL-NEXT:    retq
   %a = and <8 x i64> %x, <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
   %b = xor <8 x i64> %a, %y
   ret <8 x i64> %b
diff --git a/llvm/test/CodeGen/X86/avx512vl-logic.ll b/llvm/test/CodeGen/X86/avx512vl-logic.ll
--- a/llvm/test/CodeGen/X86/avx512vl-logic.ll
+++ b/llvm/test/CodeGen/X86/avx512vl-logic.ll
@@ -991,8 +991,7 @@
 define <4 x i32> @ternlog_or_and_mask(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: ternlog_or_and_mask:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %xmm0, %xmm0
-; CHECK-NEXT:    vorps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vpternlogd $236, {{.*}}(%rip), %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %a = and <4 x i32> %x, <i32 255, i32 255, i32 255, i32 255>
   %b = or <4 x i32> %a, %y
@@ -1002,8 +1001,7 @@
 define <8 x i32> @ternlog_or_and_mask_ymm(<8 x i32> %x, <8 x i32> %y) {
 ; CHECK-LABEL: ternlog_or_and_mask_ymm:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %ymm0, %ymm0
-; CHECK-NEXT:    vorps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vpternlogd $236, {{.*}}(%rip), %ymm1, %ymm0
 ; CHECK-NEXT:    retq
   %a = and <8 x i32> %x, <i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255, i32 255>
   %b = or <8 x i32> %a, %y
@@ -1013,8 +1011,7 @@
 define <2 x i64> @ternlog_xor_and_mask(<2 x i64> %x, <2 x i64> %y) {
 ; CHECK-LABEL: ternlog_xor_and_mask:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %xmm0, %xmm0
-; CHECK-NEXT:    vxorps %xmm1, %xmm0, %xmm0
+; CHECK-NEXT:    vpternlogq $108, {{.*}}(%rip), %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %a = and <2 x i64> %x, <i64 4294967295, i64 4294967295>
   %b = xor <2 x i64> %a, %y
@@ -1024,8 +1021,7 @@
 define <4 x i64> @ternlog_xor_and_mask_ymm(<4 x i64> %x, <4 x i64> %y) {
 ; CHECK-LABEL: ternlog_xor_and_mask_ymm:
 ; CHECK:       ## %bb.0:
-; CHECK-NEXT:    vandps {{.*}}(%rip), %ymm0, %ymm0
-; CHECK-NEXT:    vxorps %ymm1, %ymm0, %ymm0
+; CHECK-NEXT:    vpternlogq $108, {{.*}}(%rip), %ymm1, %ymm0
 ; CHECK-NEXT:    retq
   %a = and <4 x i64> %x, <i64 4294967295, i64 4294967295, i64 4294967295, i64 4294967295>
   %b = xor <4 x i64> %a, %y
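
Reviewer note on the immediate swizzles (not part of the patch): VPTERNLOG's 8-bit immediate is the truth table of f(A, B, C), where bit (A<<2 | B<<1 | C) holds the result for that input combination. Swapping two operands just permutes table bits, which is why the A<->C swap keeps bits 0/2/5/7 (mask 0xa5) and trades pairs 1/4 and 3/6, while the B<->C swap keeps bits 0/3/4/7 (mask 0x99) and trades pairs 1/2 and 5/6. As a cross-check against the tests: or(and(x, m), y) with the mask folded as operand C is f = (A & C) | B, whose table is 0xEC = 236, matching the vpternlogd $236 above. Below is a minimal standalone sketch verifying the bit-twiddling against a direct truth-table recomputation; swapAC/swapBC are hypothetical helper names, and nothing here is LLVM API:

```cpp
// Standalone sanity check for the immediate swizzles (illustrative only).
#include <cassert>
#include <cstdint>

// Recompute the immediate for f(C, B, A), i.e. operands A and C exchanged:
// truth-table bit (a<<2 | b<<1 | c) moves to (c<<2 | b<<1 | a).
static uint8_t swapAC(uint8_t Imm) {
  uint8_t New = 0;
  for (unsigned Bit = 0; Bit < 8; ++Bit)
    if (Imm & (1 << Bit)) {
      unsigned A = (Bit >> 2) & 1, B = (Bit >> 1) & 1, C = Bit & 1;
      New |= 1 << (C << 2 | B << 1 | A); // pairs 1/4 and 3/6 trade places
    }
  return New;
}

// Recompute the immediate for f(A, C, B), i.e. operands B and C exchanged:
// truth-table bit (a<<2 | b<<1 | c) moves to (a<<2 | c<<1 | b).
static uint8_t swapBC(uint8_t Imm) {
  uint8_t New = 0;
  for (unsigned Bit = 0; Bit < 8; ++Bit)
    if (Imm & (1 << Bit)) {
      unsigned A = (Bit >> 2) & 1, B = (Bit >> 1) & 1, C = Bit & 1;
      New |= 1 << (A << 2 | C << 1 | B); // pairs 1/2 and 5/6 trade places
    }
  return New;
}

int main() {
  for (unsigned Imm = 0; Imm < 256; ++Imm) {
    // Bit-twiddling form used in the patch for the A<->C swap.
    uint8_t AC = Imm & 0xa5;
    if (Imm & 0x02) AC |= 0x10;
    if (Imm & 0x10) AC |= 0x02;
    if (Imm & 0x08) AC |= 0x40;
    if (Imm & 0x40) AC |= 0x08;
    assert(AC == swapAC(Imm));

    // Bit-twiddling form used in the patch for the B<->C swap.
    uint8_t BC = Imm & 0x99;
    if (Imm & 0x02) BC |= 0x04;
    if (Imm & 0x04) BC |= 0x02;
    if (Imm & 0x20) BC |= 0x40;
    if (Imm & 0x40) BC |= 0x20;
    assert(BC == swapBC(Imm));
  }
  return 0;
}
```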