diff --git a/llvm/lib/Target/X86/X86EvexToVex.cpp b/llvm/lib/Target/X86/X86EvexToVex.cpp --- a/llvm/lib/Target/X86/X86EvexToVex.cpp +++ b/llvm/lib/Target/X86/X86EvexToVex.cpp @@ -151,24 +151,6 @@ (void)NewOpc; unsigned Opc = MI.getOpcode(); switch (Opc) { - case X86::VPDPBUSDSZ256m: - case X86::VPDPBUSDSZ256r: - case X86::VPDPBUSDSZ128m: - case X86::VPDPBUSDSZ128r: - case X86::VPDPBUSDZ256m: - case X86::VPDPBUSDZ256r: - case X86::VPDPBUSDZ128m: - case X86::VPDPBUSDZ128r: - case X86::VPDPWSSDSZ256m: - case X86::VPDPWSSDSZ256r: - case X86::VPDPWSSDSZ128m: - case X86::VPDPWSSDSZ128r: - case X86::VPDPWSSDZ256m: - case X86::VPDPWSSDZ256r: - case X86::VPDPWSSDZ128m: - case X86::VPDPWSSDZ128r: - // These can only VEX convert if AVXVNNI is enabled. - return ST->hasAVXVNNI(); case X86::VALIGNDZ128rri: case X86::VALIGNDZ128rmi: case X86::VALIGNQZ128rri: @@ -268,8 +250,9 @@ // Use the VEX.L bit to select the 128 or 256-bit table. ArrayRef Table = - (Desc.TSFlags & X86II::VEX_L) ? makeArrayRef(X86EvexToVex256CompressTable) - : makeArrayRef(X86EvexToVex128CompressTable); + (Desc.TSFlags & X86II::VEX_L) + ? 
makeArrayRef(X86EvexToVex256CompressTable) + : makeArrayRef(X86EvexToVex128CompressTable); const auto *I = llvm::lower_bound(Table, MI.getOpcode()); if (I == Table.end() || I->EvexOpcode != MI.getOpcode()) @@ -280,6 +263,9 @@ if (usesExtendedRegister(MI)) return false; + if (!CheckVEXInstPredicate(MI, ST)) + return false; + if (!performCustomAdjustments(MI, NewOpc, ST)) return false; diff --git a/llvm/lib/Target/X86/X86InstrFormats.td b/llvm/lib/Target/X86/X86InstrFormats.td --- a/llvm/lib/Target/X86/X86InstrFormats.td +++ b/llvm/lib/Target/X86/X86InstrFormats.td @@ -236,6 +236,7 @@ class EVEX_V128 { bit hasEVEX_L2 = 0; bit hasVEX_L = 0; } class NOTRACK { bit hasNoTrackPrefix = 1; } class SIMD_EXC { list Uses = [MXCSR]; bit mayRaiseFPException = 1; } +class CheckPredicate { bit checkPredicate = 1; } // Specify AVX512 8-bit compressed displacement encoding based on the vector // element size in bits (8, 16, 32, 64) and the CDisp8 form. @@ -352,6 +353,7 @@ bit isMemoryFoldable = 1; // Is it allowed to memory fold/unfold this instruction? bit notEVEX2VEXConvertible = 0; // Prevent EVEX->VEX conversion. bit ExplicitVEXPrefix = 0; // Force the instruction to use VEX encoding. + bit checkPredicate = 0; // Should this VEX inst's predicates be checked? // TSFlags layout should be kept in sync with X86BaseInfo.h. 
let TSFlags{6-0} = FormBits; diff --git a/llvm/lib/Target/X86/X86InstrSSE.td b/llvm/lib/Target/X86/X86InstrSSE.td --- a/llvm/lib/Target/X86/X86InstrSSE.td +++ b/llvm/lib/Target/X86/X86InstrSSE.td @@ -7200,10 +7200,10 @@ VEX_4V, VEX_L, Sched<[SchedWriteVecIMul.XMM]>; } -defm VPDPBUSD : avx_vnni_rm<0x50, "vpdpbusd", X86Vpdpbusd, 0>, ExplicitVEXPrefix; -defm VPDPBUSDS : avx_vnni_rm<0x51, "vpdpbusds", X86Vpdpbusds, 0>, ExplicitVEXPrefix; -defm VPDPWSSD : avx_vnni_rm<0x52, "vpdpwssd", X86Vpdpwssd, 1>, ExplicitVEXPrefix; -defm VPDPWSSDS : avx_vnni_rm<0x53, "vpdpwssds", X86Vpdpwssds, 1>, ExplicitVEXPrefix; +defm VPDPBUSD : avx_vnni_rm<0x50, "vpdpbusd", X86Vpdpbusd, 0>, ExplicitVEXPrefix, CheckPredicate; +defm VPDPBUSDS : avx_vnni_rm<0x51, "vpdpbusds", X86Vpdpbusds, 0>, ExplicitVEXPrefix, CheckPredicate; +defm VPDPWSSD : avx_vnni_rm<0x52, "vpdpwssd", X86Vpdpwssd, 1>, ExplicitVEXPrefix, CheckPredicate; +defm VPDPWSSDS : avx_vnni_rm<0x53, "vpdpwssds", X86Vpdpwssds, 1>, ExplicitVEXPrefix, CheckPredicate; def X86vpmaddwd_su : PatFrag<(ops node:$lhs, node:$rhs), (X86vpmaddwd node:$lhs, node:$rhs), [{ diff --git a/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp b/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp --- a/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp +++ b/llvm/utils/TableGen/X86EVEX2VEXTablesEmitter.cpp @@ -30,10 +30,13 @@ std::map> VEXInsts; typedef std::pair Entry; + typedef std::pair Predicate; // Represent both compress tables std::vector EVEX2VEX128; std::vector EVEX2VEX256; + // Represent predicates of VEX instructions. 
+ std::vector EVEX2VEXPredicates; public: X86EVEX2VEXTablesEmitter(RecordKeeper &R) : Records(R), Target(R) {} @@ -45,6 +48,9 @@ // Prints the given table as a C++ array of type // X86EvexToVexCompressTableEntry void printTable(const std::vector &Table, raw_ostream &OS); + + void printCheckPredicate(const std::vector &Predicates, + raw_ostream &OS); }; void X86EVEX2VEXTablesEmitter::printTable(const std::vector &Table, @@ -67,6 +73,21 @@ OS << "};\n\n"; } +// Prints the CheckVEXInstPredicate function: a switch over the opcode of MI +// that returns each listed opcode's predicate expression, or true by default. +void X86EVEX2VEXTablesEmitter::printCheckPredicate( + const std::vector &Predicates, raw_ostream &OS) { + OS << "static bool CheckVEXInstPredicate" + << "(MachineInstr &MI, const X86Subtarget *Subtarget) {\n" + << " unsigned Opc = MI.getOpcode();\n" + << " switch (Opc) {\n" + << " default: return true;\n"; + for (const auto &Pair : Predicates) + OS << " case X86::" << Pair.first << ": return " << Pair.second << ";\n"; + OS << " }\n" + << "}\n\n"; +} + // Return true if the 2 BitsInits are equal // Calculates the integer value residing BitsInit object static inline uint64_t getValueFromBitsInit(const BitsInit *B) { @@ -169,6 +190,26 @@ }; void X86EVEX2VEXTablesEmitter::run(raw_ostream &OS) { + // Joins the CondString of every "HasAVX*" predicate record of Inst with + // " && ", each wrapped in parentheses; yields "true" when none is present. + auto getPredicates = [&](const CodeGenInstruction *Inst) { + std::string Predicates; + std::vector PredicatesRecords = + Inst->TheDef->getValueAsListOfDefs("Predicates"); + for (const auto *Pred : PredicatesRecords) { + StringRef PredicatesRecordsName = Pred->getName(); + // Only AVX-related predicates (named HasAVX*) are checked for now. 
+ if (PredicatesRecordsName.startswith("HasAVX")) { + if (!Predicates.empty()) + Predicates += " && "; + Predicates += '(' + Pred->getValueAsString("CondString").str() + ')'; + } + } + if (Predicates.empty()) + Predicates = "true"; + return Predicates; + }; + emitSourceFileHeader("X86 EVEX2VEX tables", OS); ArrayRef NumberedInstructions = @@ -222,11 +263,17 @@ EVEX2VEX256.push_back(std::make_pair(EVEXInst, VEXInst)); // {0,1} else EVEX2VEX128.push_back(std::make_pair(EVEXInst, VEXInst)); // {0,0} + + // Record the predicates of VEX instructions flagged with checkPredicate. + if (VEXInst->TheDef->getValueAsBit("checkPredicate")) + EVEX2VEXPredicates.push_back( + std::make_pair(EVEXInst->TheDef->getName(), getPredicates(VEXInst))); } // Print both tables printTable(EVEX2VEX128, OS); printTable(EVEX2VEX256, OS); + printCheckPredicate(EVEX2VEXPredicates, OS); } }