[NVPTX] Only emit .attribute(.managed) for PTX >= 4.0 and sm_30

The .managed variable attribute requires PTX version >= 4.0 and sm_30.
Pass the NVPTXSubtarget down to printModuleLevelGV() and
emitPTXGlobalVariable() so both places that print the attribute can check
the target, and call report_fatal_error() instead of printing the attribute
when the PTX version or SM version is too old.

managed.ll now runs the positive check on sm_30 with +ptx40 and adds a
negative run on sm_20 that expects the fatal error.

diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
@@ -218,7 +218,7 @@
   void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &O,
                        const char *Modifier = nullptr);
   void printModuleLevelGV(const GlobalVariable *GVar, raw_ostream &O,
-                          bool = false);
+                          bool processDemoted, const NVPTXSubtarget &STI);
   void printParamName(Function::const_arg_iterator I, int paramIndex,
                       raw_ostream &O);
   void emitGlobals(const Module &M);
@@ -258,7 +258,8 @@
   // List of variables demoted to a function scope.
   std::map<const Function *, std::vector<const GlobalVariable *>> localDecls;
 
-  void emitPTXGlobalVariable(const GlobalVariable *GVar, raw_ostream &O);
+  void emitPTXGlobalVariable(const GlobalVariable *GVar, raw_ostream &O,
+                             const NVPTXSubtarget &STI);
   void emitPTXAddressSpace(unsigned int AddressSpace, raw_ostream &O) const;
   std::string getPTXFundamentalTypeStr(Type *Ty, bool = true) const;
   void printScalarConstant(const Constant *CPV, raw_ostream &O);
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -818,9 +818,13 @@
          "Missed a global variable");
   assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
 
+  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
+  const NVPTXSubtarget &STI =
+      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
+
   // Print out module-level global variables in proper order
   for (unsigned i = 0, e = Globals.size(); i != e; ++i)
-    printModuleLevelGV(Globals[i], OS2);
+    printModuleLevelGV(Globals[i], OS2, /*processDemoted=*/false, STI);
 
   OS2 << '\n';
 
@@ -957,8 +961,8 @@
 }
 
 void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
-                                         raw_ostream &O,
-                                         bool processDemoted) {
+                                         raw_ostream &O, bool processDemoted,
+                                         const NVPTXSubtarget &STI) {
   // Skip meta data
   if (GVar->hasSection()) {
     if (GVar->getSection() == "llvm.metadata")
@@ -1001,7 +1005,7 @@
     // (extern) declarations, no definition or initializer
     // Currently the only known declaration is for an automatic __local
     // (.shared) promoted to global.
-    emitPTXGlobalVariable(GVar, O);
+    emitPTXGlobalVariable(GVar, O, STI);
     O << ";\n";
     return;
   }
@@ -1095,6 +1099,10 @@
   emitPTXAddressSpace(PTy->getAddressSpace(), O);
 
   if (isManaged(*GVar)) {
+    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
+      report_fatal_error(
+          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
+    }
     O << " .attribute(.managed)";
   }
 
@@ -1214,9 +1222,13 @@
 
   std::vector<const GlobalVariable *> &gvars = localDecls[f];
 
+  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
+  const NVPTXSubtarget &STI =
+      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
+
   for (const GlobalVariable *GV : gvars) {
     O << "\t// demoted variable\n\t";
-    printModuleLevelGV(GV, O, true);
+    printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
   }
 }
 
@@ -1282,7 +1294,8 @@
 }
 
 void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
-                                            raw_ostream &O) {
+                                            raw_ostream &O,
+                                            const NVPTXSubtarget &STI) {
   const DataLayout &DL = getDataLayout();
 
   // GlobalVariables are always constant pointers themselves.
@@ -1290,8 +1303,13 @@
 
   O << ".";
   emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
-  if (isManaged(*GVar))
+  if (isManaged(*GVar)) {
+    if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
+      report_fatal_error(
+          ".attribute(.managed) requires PTX version >= 4.0 and sm_30");
+    }
     O << " .attribute(.managed)";
+  }
   if (MaybeAlign A = GVar->getAlign())
     O << " .align " << A->value();
   else
diff --git a/llvm/test/CodeGen/NVPTX/managed.ll b/llvm/test/CodeGen/NVPTX/managed.ll
--- a/llvm/test/CodeGen/NVPTX/managed.ll
+++ b/llvm/test/CodeGen/NVPTX/managed.ll
@@ -1,5 +1,7 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
+; RUN: llc < %s -march=nvptx -mcpu=sm_30 -mattr=+ptx40 | FileCheck %s
+; RUN: not --crash llc < %s -march=nvptx -mcpu=sm_20 2>&1 | FileCheck %s --check-prefix ERROR
 
+; ERROR: LLVM ERROR: .attribute(.managed) requires PTX version >= 4.0 and sm_30
 
 ; CHECK: .visible .global .align 4 .u32 device_g;
 @device_g = addrspace(1) global i32 zeroinitializer