Index: include/llvm/Target/TargetRegisterInfo.h =================================================================== --- include/llvm/Target/TargetRegisterInfo.h +++ include/llvm/Target/TargetRegisterInfo.h @@ -469,6 +469,11 @@ return nullptr; } + /// Return all the call-preserved register masks defined for this target. + virtual ArrayRef<const uint32_t *> getRegMasks() const = 0; + /// Return the names of the register masks, ordered as in getRegMasks(). + virtual ArrayRef<const char *> getRegMaskNames() const = 0; + /// getReservedRegs - Returns a bitset indexed by physical register number /// indicating if a register is a special register that has particular uses /// and should be considered unavailable at all times, e.g. SP, RA. This is Index: lib/CodeGen/MIRParser/MIParser.cpp =================================================================== --- lib/CodeGen/MIRParser/MIParser.cpp +++ lib/CodeGen/MIRParser/MIParser.cpp @@ -42,6 +42,8 @@ StringMap<unsigned> Names2InstrOpCodes; /// Maps from register names to registers. StringMap<unsigned> Names2Regs; + /// Maps from register mask names to register masks. + StringMap<const uint32_t *> Names2RegMasks; public: MIParser(SourceMgr &SM, MachineFunction &MF, SMDiagnostic &Error, @@ -89,6 +91,13 @@ /// Try to convert a register name to a register number. Return true if the /// register name is invalid. bool getRegisterByName(StringRef RegName, unsigned &Reg); + + void initNames2RegMasks(); + + /// Check if the given identifier is a name of a register mask. + /// + /// Return null if the identifier isn't a register mask. + const uint32_t *getRegMask(StringRef Identifier); }; } // end anonymous namespace @@ -168,7 +177,8 @@ // Mark this register as implicit to prevent an assertion when it's added // to an instruction. This is a temporary workaround until the implicit // register flag can be parsed.
- Operands[I].setImplicit(); + if (Operands[I].isReg()) + Operands[I].setImplicit(); } } @@ -301,6 +311,13 @@ return parseGlobalAddressOperand(Dest); case MIToken::Error: return true; + case MIToken::Identifier: + if (const auto *RegMask = getRegMask(Token.stringValue())) { + Dest = MachineOperand::CreateRegMask(RegMask); + lex(); + break; + } + // fallthrough default: // TODO: parse the other machine operands. return error("expected a machine operand"); @@ -351,6 +368,27 @@ return false; } +void MIParser::initNames2RegMasks() { + if (!Names2RegMasks.empty()) + return; + const auto *TRI = MF.getSubtarget().getRegisterInfo(); + assert(TRI && "Expected target register info"); + ArrayRef<const uint32_t *> RegMasks = TRI->getRegMasks(); + ArrayRef<const char *> RegMaskNames = TRI->getRegMaskNames(); + assert(RegMasks.size() == RegMaskNames.size()); + for (size_t I = 0, E = RegMasks.size(); I < E; ++I) + Names2RegMasks.insert( + std::make_pair(StringRef(RegMaskNames[I]).lower(), RegMasks[I])); +} + +const uint32_t *MIParser::getRegMask(StringRef Identifier) { + initNames2RegMasks(); + auto RegMaskInfo = Names2RegMasks.find(Identifier); + if (RegMaskInfo == Names2RegMasks.end()) + return nullptr; + return RegMaskInfo->getValue(); +} + MachineInstr *llvm::parseMachineInstr( SourceMgr &SM, MachineFunction &MF, StringRef Src, const DenseMap<unsigned, MachineBasicBlock *> &MBBMapping, Index: lib/CodeGen/MIRPrinter.cpp =================================================================== --- lib/CodeGen/MIRPrinter.cpp +++ lib/CodeGen/MIRPrinter.cpp @@ -32,6 +32,7 @@ /// format.
class MIRPrinter { raw_ostream &OS; + DenseMap<const uint32_t *, unsigned> RegisterMaskIds; public: MIRPrinter(raw_ostream &OS) : OS(OS) {} @@ -40,6 +41,9 @@ void convert(const Module &M, yaml::MachineBasicBlock &YamlMBB, const MachineBasicBlock &MBB); + +private: + void initRegisterMaskIds(const MachineFunction &MF); }; /// This class prints out the machine instructions using the MIR serialization @@ -47,9 +51,12 @@ class MIPrinter { const Module &M; raw_ostream &OS; + const DenseMap<const uint32_t *, unsigned> &RegisterMaskIds; public: - MIPrinter(const Module &M, raw_ostream &OS) : M(M), OS(OS) {} + MIPrinter(const Module &M, raw_ostream &OS, + const DenseMap<const uint32_t *, unsigned> &RegisterMaskIds) + : M(M), OS(OS), RegisterMaskIds(RegisterMaskIds) {} void print(const MachineInstr &MI); void print(const MachineOperand &Op, const TargetRegisterInfo *TRI); @@ -75,6 +82,8 @@ } // end namespace llvm void MIRPrinter::print(const MachineFunction &MF) { + initRegisterMaskIds(MF); + yaml::MachineFunction YamlMF; YamlMF.Name = MF.getName(); YamlMF.Alignment = MF.getAlignment(); @@ -116,12 +125,19 @@ std::string Str; for (const auto &MI : MBB) { raw_string_ostream StrOS(Str); - MIPrinter(M, StrOS).print(MI); + MIPrinter(M, StrOS, RegisterMaskIds).print(MI); YamlMBB.Instructions.push_back(StrOS.str()); Str.clear(); } } + +void MIRPrinter::initRegisterMaskIds(const MachineFunction &MF) { + const auto *TRI = MF.getSubtarget().getRegisterInfo(); + unsigned I = 0; + for (const uint32_t *Mask : TRI->getRegMasks()) + RegisterMaskIds.insert(std::make_pair(Mask, I++)); +} + void MIPrinter::print(const MachineInstr &MI) { const auto &SubTarget = MI.getParent()->getParent()->getSubtarget(); const auto *TRI = SubTarget.getRegisterInfo(); @@ -190,6 +206,14 @@ Op.getGlobal()->printAsOperand(OS, /*PrintType=*/false, &M); // TODO: Print offset and target flags.
break; + case MachineOperand::MO_RegisterMask: { + auto RegMaskInfo = RegisterMaskIds.find(Op.getRegMask()); + if (RegMaskInfo != RegisterMaskIds.end()) + OS << StringRef(TRI->getRegMaskNames()[RegMaskInfo->second]).lower(); + else + llvm_unreachable("Can't print this machine register mask yet."); + break; + } default: // TODO: Print the other machine operands. llvm_unreachable("Can't print this machine operand at the moment"); Index: test/CodeGen/MIR/X86/register-mask-operands.mir =================================================================== --- /dev/null +++ test/CodeGen/MIR/X86/register-mask-operands.mir @@ -0,0 +1,43 @@ +# RUN: llc -march=x86-64 -start-after branch-folder -stop-after branch-folder -o /dev/null %s | FileCheck %s +# This test ensures that the MIR parser parses register mask operands correctly. + +--- | + + define i32 @compute(i32 %a) #0 { + body: + %c = mul i32 %a, 11 + ret i32 %c + } + + define i32 @foo(i32 %a) #0 { + entry: + %b = call i32 @compute(i32 %a) + ret i32 %b + } + + attributes #0 = { "no-frame-pointer-elim"="false" } + +... +--- +name: compute +body: + - id: 0 + name: body + instructions: + - '%eax = IMUL32rri8 %edi, 11' + - 'RETQ %eax' +... +--- +# CHECK: name: foo +name: foo +body: + - id: 0 + name: entry + instructions: + # CHECK: - 'PUSH64r %rax + # CHECK-NEXT: - 'CALL64pcrel32 @compute, csr_64, %rsp, %edi, %rsp, %eax' + - 'PUSH64r %rax' + - 'CALL64pcrel32 @compute, csr_64, %rsp, %edi, %rsp, %eax' + - '%rdx = POP64r' + - 'RETQ %eax' +...
Index: utils/TableGen/RegisterInfoEmitter.cpp =================================================================== --- utils/TableGen/RegisterInfoEmitter.cpp +++ utils/TableGen/RegisterInfoEmitter.cpp @@ -1094,6 +1094,8 @@ << "const TargetRegisterClass *RC) const override;\n" << " const int *getRegUnitPressureSets(" << "unsigned RegUnit) const override;\n" + << " ArrayRef<const char *> getRegMaskNames() const override;\n" + << " ArrayRef<const uint32_t *> getRegMasks() const override;\n" << "};\n\n"; const auto &RegisterClasses = RegBank.getRegClasses(); @@ -1445,6 +1447,28 @@ } OS << "\n\n"; + // A trailing nullptr keeps each emitted array non-empty even when the + // target defines no CSR sets; it is not counted in the returned ArrayRef. + OS << "ArrayRef<const uint32_t *> " << ClassName + << "::getRegMasks() const {\n"; + OS << " static const uint32_t *const Masks[] = {\n"; + for (Record *CSRSet : CSRSets) + OS << " " << CSRSet->getName() << "_RegMask, \n"; + OS << " nullptr\n };\n"; + OS << " return ArrayRef<const uint32_t *>(Masks, (size_t)" << CSRSets.size() + << ");\n"; + OS << "}\n\n"; + + OS << "ArrayRef<const char *> " << ClassName + << "::getRegMaskNames() const {\n"; + OS << " static const char *const Names[] = {\n"; + for (Record *CSRSet : CSRSets) + OS << " " << '"' << CSRSet->getName() << '"' << ",\n"; + OS << " nullptr\n };\n"; + OS << " return ArrayRef<const char *>(Names, (size_t)" << CSRSets.size() + << ");\n"; + OS << "}\n\n"; + OS << "} // End llvm namespace\n"; OS << "#endif // GET_REGINFO_TARGET_DESC\n\n"; }