diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt --- a/llvm/lib/Target/SPIRV/CMakeLists.txt +++ b/llvm/lib/Target/SPIRV/CMakeLists.txt @@ -15,6 +15,7 @@ add_llvm_target(SPIRVCodeGen SPIRVAsmPrinter.cpp SPIRVCallLowering.cpp + SPIRVDuplicatesTracker.cpp SPIRVEmitIntrinsics.cpp SPIRVGlobalRegistry.cpp SPIRVInstrInfo.cpp diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp --- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp @@ -68,6 +68,7 @@ ArrayRef> VRegs, FunctionLoweringInfo &FLI) const { assert(GR && "Must initialize the SPIRV type registry before lowering args."); + GR->setCurrentFunc(MIRBuilder.getMF()); // Assign types and names to all args, and store their types for later. SmallVector ArgTypeVRegs; @@ -114,6 +115,8 @@ auto MRI = MIRBuilder.getMRI(); Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); + if (F.isDeclaration()) + GR->add(&F, &MIRBuilder.getMF(), FuncVReg); auto *FTy = F.getFunctionType(); auto FuncTy = GR->assignTypeToVReg(FTy, FuncVReg, MIRBuilder); @@ -136,6 +139,8 @@ MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) .addDef(VRegs[i][0]) .addUse(ArgTypeVRegs[i]); + if (F.isDeclaration()) + GR->add(F.getArg(i), &MIRBuilder.getMF(), VRegs[i][0]); } // Name the function. if (F.hasName()) @@ -165,6 +170,7 @@ if (Info.OrigRet.Regs.size() > 1) return false; + GR->setCurrentFunc(MIRBuilder.getMF()); Register ResVReg = Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; // Emit a regular OpFunctionCall.
If it's an externally declared function, diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h new file mode 100644 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h @@ -0,0 +1,175 @@ +//===-- SPIRVDuplicatesTracker.h - SPIR-V Duplicates Tracker ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// General infrastructure for keeping track of the values that according to +// the SPIR-V binary layout should be global to the whole module. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H +#define LLVM_LIB_TARGET_SPIRV_SPIRVDUPLICATESTRACKER_H + +#include "MCTargetDesc/SPIRVBaseInfo.h" +#include "MCTargetDesc/SPIRVMCTargetDesc.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" +#include "llvm/CodeGen/MachineModuleInfo.h" + +#include + +namespace llvm { +namespace SPIRV { +// NOTE: using MapVector instead of DenseMap because it helps getting +// everything ordered in a stable manner for a price of extra (NumKeys)*PtrSize +// memory and expensive removals which do not happen anyway. +class DTSortableEntry : public MapVector { + SmallVector Deps; + + struct FlagsTy { + unsigned IsFunc : 1; + unsigned IsGV : 1; + // NOTE: bit-field default init is a C++20 feature. + FlagsTy() : IsFunc(0), IsGV(0) {} + }; + FlagsTy Flags; + +public: + // Common hoisting utility doesn't support functions, because their + // hoisting requires hoisting of params as well.
+ bool getIsFunc() const { return Flags.IsFunc; } + bool getIsGV() const { return Flags.IsGV; } + void setIsFunc(bool V) { Flags.IsFunc = V; } + void setIsGV(bool V) { Flags.IsGV = V; } + + const SmallVector &getDeps() const { return Deps; } + void addDep(DTSortableEntry *E) { Deps.push_back(E); } +}; +} // namespace SPIRV + +template class SPIRVDuplicatesTrackerBase { +public: + // NOTE: using MapVector instead of DenseMap helps getting everything ordered + // in a stable manner for a price of extra (NumKeys)*PtrSize memory and + // expensive removals which don't happen anyway. + using StorageTy = MapVector; + +private: + StorageTy Storage; + +public: + void add(KeyTy V, const MachineFunction *MF, Register R) { + if (find(V, MF).isValid()) + return; + + Storage[V][MF] = R; + if (std::is_same::type>::type>() || + std::is_same::type>::type>()) + Storage[V].setIsFunc(true); + if (std::is_same::type>::type>()) + Storage[V].setIsGV(true); + } + + Register find(KeyTy V, const MachineFunction *MF) const { + auto iter = Storage.find(V); + if (iter != Storage.end()) { + // Bind by reference: a copy would clone the entry's map and deps list. + const auto &Map = iter->second; + auto iter2 = Map.find(MF); + if (iter2 != Map.end()) + return iter2->second; + } + return Register(); + } + + const StorageTy &getAllUses() const { return Storage; } + +private: + StorageTy &getAllUses() { return Storage; } + + // The friend class needs to have access to the internal storage + // to be able to build dependency graph, can't declare only one + // function a 'friend' due to the incomplete declaration at this point + // and mutual dependency problems.
+ friend class SPIRVGeneralDuplicatesTracker; +}; + +template +class SPIRVDuplicatesTracker : public SPIRVDuplicatesTrackerBase {}; + +class SPIRVGeneralDuplicatesTracker { + SPIRVDuplicatesTracker TT; + SPIRVDuplicatesTracker CT; + SPIRVDuplicatesTracker GT; + SPIRVDuplicatesTracker FT; + SPIRVDuplicatesTracker AT; + + // NOTE: using MOs instead of regs to get rid of MF dependency to be able + // to use flat data structure. + // NOTE: replacing DenseMap with MapVector doesn't affect overall correctness + // but makes LITs more stable, should prefer DenseMap still due to + // significant perf difference. + using SPIRVReg2EntryTy = + MapVector; + + template + void prebuildReg2Entry(SPIRVDuplicatesTracker &DT, + SPIRVReg2EntryTy &Reg2Entry); + +public: + void buildDepsGraph(std::vector &Graph, + MachineModuleInfo *MMI = nullptr); + + void add(const Type *T, const MachineFunction *MF, Register R) { + TT.add(T, MF, R); + } + + void add(const Constant *C, const MachineFunction *MF, Register R) { + CT.add(C, MF, R); + } + + void add(const GlobalVariable *GV, const MachineFunction *MF, Register R) { + GT.add(GV, MF, R); + } + + void add(const Function *F, const MachineFunction *MF, Register R) { + FT.add(F, MF, R); + } + + void add(const Argument *Arg, const MachineFunction *MF, Register R) { + AT.add(Arg, MF, R); + } + + Register find(const Type *T, const MachineFunction *MF) { + return TT.find(const_cast(T), MF); + } + + Register find(const Constant *C, const MachineFunction *MF) { + return CT.find(const_cast(C), MF); + } + + Register find(const GlobalVariable *GV, const MachineFunction *MF) { + return GT.find(const_cast(GV), MF); + } + + Register find(const Function *F, const MachineFunction *MF) { + return FT.find(const_cast(F), MF); + } + + Register find(const Argument *Arg, const MachineFunction *MF) { + return AT.find(const_cast(Arg), MF); + } +}; +} // namespace llvm +#endif diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp @@ -0,0 +1,98 @@ +//===-- SPIRVDuplicatesTracker.cpp - SPIR-V Duplicates Tracker --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// General infrastructure for keeping track of the values that according to +// the SPIR-V binary layout should be global to the whole module. +// +//===----------------------------------------------------------------------===// + +#include "SPIRVDuplicatesTracker.h" + +using namespace llvm; + +template +void SPIRVGeneralDuplicatesTracker::prebuildReg2Entry( + SPIRVDuplicatesTracker &DT, SPIRVReg2EntryTy &Reg2Entry) { + for (auto &TPair : DT.getAllUses()) { + for (auto &RegPair : TPair.second) { + const MachineFunction *MF = RegPair.first; + Register R = RegPair.second; + MachineInstr *MI = MF->getRegInfo().getVRegDef(R); + if (!MI) + continue; + Reg2Entry[&MI->getOperand(0)] = &TPair.second; + } + } +} + +void SPIRVGeneralDuplicatesTracker::buildDepsGraph( + std::vector &Graph, + MachineModuleInfo *MMI) { + SPIRVReg2EntryTy Reg2Entry; + prebuildReg2Entry(TT, Reg2Entry); + prebuildReg2Entry(CT, Reg2Entry); + prebuildReg2Entry(GT, Reg2Entry); + prebuildReg2Entry(FT, Reg2Entry); + prebuildReg2Entry(AT, Reg2Entry); + + for (auto &Op2E : Reg2Entry) { + SPIRV::DTSortableEntry *E = Op2E.second; + Graph.push_back(E); + for (auto &U : *E) { + const MachineRegisterInfo &MRI = U.first->getRegInfo(); + MachineInstr *MI = MRI.getUniqueVRegDef(U.second); + if (!MI) + continue; + assert(MI->getParent() && "No MachineInstr created yet"); + for (auto i = MI->getNumDefs(); i < MI->getNumOperands(); i++) { + MachineOperand &Op =
MI->getOperand(i); + if (!Op.isReg()) + continue; + MachineOperand *RegOp = &MRI.getVRegDef(Op.getReg())->getOperand(0); + assert((MI->getOpcode() == SPIRV::OpVariable && i == 3) || + Reg2Entry.count(RegOp)); + if (auto *Dep = Reg2Entry.lookup(RegOp)) + E->addDep(Dep); + } + + if (E->getIsFunc()) { + MachineInstr *Next = MI->getNextNode(); + if (Next && (Next->getOpcode() == SPIRV::OpFunction || + Next->getOpcode() == SPIRV::OpFunctionParameter)) { + // Unlike operator[], lookup() never inserts into Reg2Entry. + if (auto *NextE = Reg2Entry.lookup(&Next->getOperand(0))) + E->addDep(NextE); + } + } + } + } + + if (MMI) { + const Module *M = MMI->getModule(); + for (auto F = M->begin(), E = M->end(); F != E; ++F) { + const MachineFunction *MF = MMI->getMachineFunction(*F); + if (!MF) + continue; + for (const MachineBasicBlock &MBB : *MF) { + for (const MachineInstr &CMI : MBB) { + MachineInstr &MI = const_cast(CMI); + // dump() needs !NDEBUG or LLVM_ENABLE_DUMP; print() is always built. + MI.print(dbgs()); + if (MI.getNumExplicitDefs() > 0 && + Reg2Entry.count(&MI.getOperand(0))) { + dbgs() << "\t["; + for (SPIRV::DTSortableEntry *D : + Reg2Entry.lookup(&MI.getOperand(0))->getDeps()) + dbgs() << Register::virtReg2Index(D->lookup(MF)) << ", "; + dbgs() << "]\n"; + } + } + } + } + } +} diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h @@ -17,6 +17,7 @@ #define LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H #include "MCTargetDesc/SPIRVBaseInfo.h" +#include "SPIRVDuplicatesTracker.h" #include "SPIRVInstrInfo.h" #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" @@ -30,7 +31,10 @@ // where Reg = OpType... // while VRegToTypeMap tracks SPIR-V type assigned to other regs (i.e.
not // type-declaring ones) - DenseMap> VRegToTypeMap; + DenseMap> + VRegToTypeMap; + + SPIRVGeneralDuplicatesTracker DT; DenseMap SPIRVToLLVMType; @@ -48,6 +52,39 @@ MachineFunction *CurMF; + void add(const Constant *C, MachineFunction *MF, Register R) { + DT.add(C, MF, R); + } + + void add(const GlobalVariable *GV, MachineFunction *MF, Register R) { + DT.add(GV, MF, R); + } + + void add(const Function *F, MachineFunction *MF, Register R) { + DT.add(F, MF, R); + } + + void add(const Argument *Arg, MachineFunction *MF, Register R) { + DT.add(Arg, MF, R); + } + + Register find(const Constant *C, MachineFunction *MF) { + return DT.find(C, MF); + } + + Register find(const GlobalVariable *GV, MachineFunction *MF) { + return DT.find(GV, MF); + } + + Register find(const Function *F, MachineFunction *MF) { + return DT.find(F, MF); + } + + void buildDepsGraph(std::vector &Graph, + MachineModuleInfo *MMI = nullptr) { + DT.buildDepsGraph(Graph, MMI); + } + // Get or create a SPIR-V type corresponding the given LLVM IR type, // and map it to the given VReg by creating an ASSIGN_TYPE instruction. SPIRVType *assignTypeToVReg( @@ -136,7 +173,7 @@ SPIRVType *getOpTypeFunction(SPIRVType *RetType, const SmallVectorImpl &ArgTypes, MachineIRBuilder &MIRBuilder); - SPIRVType *restOfCreateSPIRVType(Type *LLVMTy, MachineInstrBuilder MIB); + SPIRVType *restOfCreateSPIRVType(const Type *LLVMTy, SPIRVType *SpirvType); public: Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder, diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp --- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp @@ -101,7 +101,6 @@ SPIRVType *SpvType, bool EmitIR) { auto &MF = MIRBuilder.getMF(); - Register Res; const IntegerType *LLVMIntTy; if (SpvType) LLVMIntTy = cast(getTypeForSPIRVType(SpvType)); @@ -110,15 +109,19 @@ // Find a constant in DT or build a new one.
const auto ConstInt = ConstantInt::get(const_cast(LLVMIntTy), Val); - unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; - Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); - assignTypeToVReg(LLVMIntTy, Res, MIRBuilder); - if (EmitIR) - MIRBuilder.buildConstant(Res, *ConstInt); - else - MIRBuilder.buildInstr(SPIRV::OpConstantI) - .addDef(Res) - .addImm(ConstInt->getSExtValue()); + Register Res = DT.find(ConstInt, &MF); + if (!Res.isValid()) { + unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; + Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); + assignTypeToVReg(LLVMIntTy, Res, MIRBuilder); + DT.add(ConstInt, &MF, Res); + if (EmitIR) + MIRBuilder.buildConstant(Res, *ConstInt); + else + MIRBuilder.buildInstr(SPIRV::OpConstantI) + .addDef(Res) + .addImm(ConstInt->getSExtValue()); + } return Res; } @@ -126,7 +129,6 @@ MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) { auto &MF = MIRBuilder.getMF(); - Register Res; const Type *LLVMFPTy; if (SpvType) { LLVMFPTy = getTypeForSPIRVType(SpvType); @@ -136,10 +138,14 @@ } // Find a constant in DT or build a new one. const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val); - unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; - Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); - assignTypeToVReg(LLVMFPTy, Res, MIRBuilder); - MIRBuilder.buildFConstant(Res, *ConstFP); + Register Res = DT.find(ConstFP, &MF); + if (!Res.isValid()) { + unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; + Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); + assignTypeToVReg(LLVMFPTy, Res, MIRBuilder); + DT.add(ConstFP, &MF, Res); + MIRBuilder.buildFConstant(Res, *ConstFP); + } return Res; } @@ -184,6 +190,7 @@ *Subtarget.getRegBankInfo()); } Reg = MIB->getOperand(0).getReg(); + DT.add(GVar, &MIRBuilder.getMF(), Reg); // Set to Reg the same type as ResVReg has.
auto MRI = MIRBuilder.getMRI(); @@ -318,10 +325,11 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( const Type *Type, MachineIRBuilder &MIRBuilder, SPIRV::AccessQualifier AccessQual, bool EmitIR) { + Register Reg = DT.find(Type, &MIRBuilder.getMF()); + if (Reg.isValid()) + return getSPIRVTypeForVReg(Reg); SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); - VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; - SPIRVToLLVMType[SpirvType] = Type; - return SpirvType; + return restOfCreateSPIRVType(Type, SpirvType); } bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, @@ -387,17 +395,21 @@ MIRBuilder); } -SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(Type *LLVMTy, - MachineInstrBuilder MIB) { - SPIRVType *SpirvType = MIB; +SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(const Type *LLVMTy, + SPIRVType *SpirvType) { + assert(CurMF == SpirvType->getMF()); VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; SPIRVToLLVMType[SpirvType] = LLVMTy; + DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType)); return SpirvType; } SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType( unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth); + Register Reg = DT.find(LLVMTy, CurMF); + if (Reg.isValid()) + return getSPIRVTypeForVReg(Reg); MachineBasicBlock &BB = *I.getParent(); auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt)) .addDef(createTypeVReg(CurMF->getRegInfo())) diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -807,23 +807,29 @@ Register SPIRVInstructionSelector::buildI32Constant(uint32_t Val, MachineInstr &I, const SPIRVType *ResType) const { + Type *LLVMTy =
IntegerType::get(GR.CurMF->getFunction().getContext(), 32); const SPIRVType *SpvI32Ty = ResType ? ResType : GR.getOrCreateSPIRVIntegerType(32, I, TII); - Register NewReg; - NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); - MachineInstr *MI; - MachineBasicBlock &BB = *I.getParent(); - if (Val == 0) { - MI = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) - .addDef(NewReg) - .addUse(GR.getSPIRVTypeID(SpvI32Ty)); - } else { - MI = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI)) - .addDef(NewReg) - .addUse(GR.getSPIRVTypeID(SpvI32Ty)) - .addImm(APInt(32, Val).getZExtValue()); + // Find a constant in DT or build a new one. + auto ConstInt = ConstantInt::get(LLVMTy, Val); + Register NewReg = GR.find(ConstInt, GR.CurMF); + if (!NewReg.isValid()) { + NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); + GR.add(ConstInt, GR.CurMF, NewReg); + MachineInstr *MI; + MachineBasicBlock &BB = *I.getParent(); + if (Val == 0) { + MI = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) + .addDef(NewReg) + .addUse(GR.getSPIRVTypeID(SpvI32Ty)); + } else { + MI = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI)) + .addDef(NewReg) + .addUse(GR.getSPIRVTypeID(SpvI32Ty)) + .addImm(APInt(32, Val).getZExtValue()); + } + constrainSelectedInstRegOperands(*MI, TII, TRI, RBI); } - constrainSelectedInstRegOperands(*MI, TII, TRI, RBI); return NewReg; } diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.h @@ -15,6 +15,7 @@ #define LLVM_LIB_TARGET_SPIRV_SPIRVMODULEANALYSIS_H #include "MCTargetDesc/SPIRVBaseInfo.h" +#include "SPIRVDuplicatesTracker.h" #include "SPIRVSubtarget.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" @@ -123,6 +124,11 @@ private: void setBaseInfo(const Module &M); template void collectTypesConstsVars(); + void collectGlobalEntities(
const std::vector &DepsGraph, + SPIRV::ModuleSectionType MSType, + std::function Pred, + bool UsePreOrder); void processDefInstrs(const Module &M); void collectFuncNames(MachineInstr &MI, const Function &F); void processOtherInstrs(const Module &M); diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp --- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp @@ -28,6 +28,11 @@ #define DEBUG_TYPE "spirv-module-analysis" +static cl::opt + SPVDumpDeps("spv-dump-deps", + cl::desc("Dump MIR with SPIR-V dependencies info"), + cl::Optional, cl::init(false)); + char llvm::SPIRVModuleAnalysis::ID = 0; namespace llvm { @@ -113,6 +118,82 @@ return false; } +// Collect MI which defines the register in the given machine function. +static void collectDefInstr(Register Reg, const MachineFunction *MF, + SPIRV::ModuleAnalysisInfo *MAI, + SPIRV::ModuleSectionType MSType, + bool DoInsert = true) { + assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias"); + MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg); + assert(MI && "There should be an instruction that defines the register"); + MAI->setSkipEmission(MI); + if (DoInsert) + MAI->MS[MSType].push_back(MI); +} + +void SPIRVModuleAnalysis::collectGlobalEntities( + const std::vector &DepsGraph, + SPIRV::ModuleSectionType MSType, + std::function Pred, + bool UsePreOrder) { + DenseSet Visited; + std::function RecHoistUtil; + // NOTE: here we prefer recursive approach over iterative because + // we don't expect depchains long enough to cause SO. + RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred, + &RecHoistUtil](const SPIRV::DTSortableEntry *E) { + if (Visited.count(E) || !Pred(E)) + return; + Visited.insert(E); + + // Traversing deps graph in post-order allows us to get rid of + // register aliases preprocessing.
+ // But pre-order is required for correct processing of function + // declarations and their arguments. + if (!UsePreOrder) + for (auto *S : E->getDeps()) + RecHoistUtil(S); + + Register GlobalReg = Register::index2VirtReg(MAI.getNextID()); + bool IsFirst = true; + for (auto &U : *E) { + const MachineFunction *MF = U.first; + Register Reg = U.second; + MAI.setRegisterAlias(MF, Reg, GlobalReg); + if (!MF->getRegInfo().getUniqueVRegDef(Reg)) + continue; + collectDefInstr(Reg, MF, &MAI, MSType, IsFirst); + IsFirst = false; + if (E->getIsGV()) + MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg)); + } + + if (UsePreOrder) + for (auto *S : E->getDeps()) + RecHoistUtil(S); + }; + for (const auto *E : DepsGraph) + RecHoistUtil(E); +} + +// The function initializes global register alias table for types, consts, +// global vars and func decls and collects these instructions for output +// at module level: function entities go to MB_ExtFuncDecls, everything +// else goes to MB_TypeConstVars. +void SPIRVModuleAnalysis::processDefInstrs(const Module &M) { + std::vector DepsGraph; + + GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr); + + collectGlobalEntities( + DepsGraph, SPIRV::MB_TypeConstVars, + [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); }, false); + + collectGlobalEntities( + DepsGraph, SPIRV::MB_ExtFuncDecls, + [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true); +} + // Look for IDs declared with Import linkage, and map the imported name string // to the register defining that variable (which will usually be the result of // an OpFunction). This lets us call externally imported functions using @@ -146,10 +227,9 @@ // numbering has already occurred by this point. We can directly compare reg // arguments when detecting duplicates.
static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI, - SPIRV::ModuleSectionType MSType, - bool IsConstOrType = false) { + SPIRV::ModuleSectionType MSType) { MAI.setSkipEmission(&MI); - if (findSameInstrInMS(MI, MSType, MAI, IsConstOrType, IsConstOrType ? 1 : 0)) + if (findSameInstrInMS(MI, MSType, MAI, false)) return; // Found a duplicate, so don't add it. // No duplicates, so add it. MAI.MS[MSType].push_back(&MI); @@ -163,18 +243,11 @@ continue; MachineFunction *MF = MMI->getMachineFunction(*F); assert(MF); - unsigned FCounter = 0; for (MachineBasicBlock &MBB : *MF) for (MachineInstr &MI : MBB) { - if (MI.getOpcode() == SPIRV::OpFunction) - FCounter++; if (MAI.getSkipEmission(&MI)) continue; const unsigned OpCode = MI.getOpcode(); - const bool IsFuncOrParm = - OpCode == SPIRV::OpFunction || OpCode == SPIRV::OpFunctionParameter; - const bool IsConstOrType = - TII->isConstantInstr(MI) || TII->isTypeDeclInstr(MI); if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) { collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames); } else if (OpCode == SPIRV::OpEntryPoint) { @@ -182,12 +255,10 @@ } else if (TII->isDecorationInstr(MI)) { collectOtherInstr(MI, MAI, SPIRV::MB_Annotations); collectFuncNames(MI, *F); - } else if (IsConstOrType || (FCounter > 1 && IsFuncOrParm)) { - // Now OpSpecConstant*s are not in DT, - // but they need to be collected anyway. - enum SPIRV::ModuleSectionType Type = - IsFuncOrParm ? SPIRV::MB_ExtFuncDecls : SPIRV::MB_TypeConstVars; - collectOtherInstr(MI, MAI, Type, IsConstOrType); + } else if (TII->isConstantInstr(MI)) { + // Constants not tracked by DT (e.g. OpSpecConstant*) still have to + // be hoisted to the module-level constants section. + collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars); } else if (OpCode == SPIRV::OpFunction) { collectFuncNames(MI, *F); } @@ -239,6 +310,7 @@ // TODO: Process type/const/global var/func decl instructions, number their // destination registers from 0 to N, collect Extensions and Capabilities. 
+ processDefInstrs(M); // Number rest of registers from N+1 onwards. numberRegistersGlobally(M);