diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3079,6 +3079,7 @@ def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>; def SPV_OC_OpMemberName : I32EnumAttrCase<"OpMemberName", 6>; def SPV_OC_OpString : I32EnumAttrCase<"OpString", 7>; +def SPV_OC_OpLine : I32EnumAttrCase<"OpLine", 8>; def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>; def SPV_OC_OpExtInstImport : I32EnumAttrCase<"OpExtInstImport", 11>; def SPV_OC_OpExtInst : I32EnumAttrCase<"OpExtInst", 12>; @@ -3223,7 +3224,7 @@ SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ SPV_OC_OpNop, SPV_OC_OpUndef, SPV_OC_OpSourceContinued, SPV_OC_OpSource, SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString, - SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, + SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, diff --git a/mlir/include/mlir/Dialect/SPIRV/Serialization.h b/mlir/include/mlir/Dialect/SPIRV/Serialization.h --- a/mlir/include/mlir/Dialect/SPIRV/Serialization.h +++ b/mlir/include/mlir/Dialect/SPIRV/Serialization.h @@ -26,7 +26,8 @@ /// Serializes the given SPIR-V `module` and writes to `binary`. On failure, /// reports errors to the error handler registered with the MLIR context for /// `module`. -LogicalResult serialize(ModuleOp module, SmallVectorImpl &binary); +LogicalResult serialize(ModuleOp module, SmallVectorImpl &binary, + bool emitDebugInfo = false); /// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp /// in the given `context`. Returns the ModuleOp on success; otherwise, reports diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -232,6 +232,16 @@ /// Processes a SPIR-V OpConstantNull instruction with the given `operands`. LogicalResult processConstantNull(ArrayRef operands); + //===--------------------------------------------------------------------===// + // Debug + //===--------------------------------------------------------------------===// + + /// Processes a SPIR-V OpLine instruction with the given `operands`. + LogicalResult processDebugLine(ArrayRef operands); + + /// Processes a SPIR-V OpString instruction with the given `operands`. + LogicalResult processDebugString(ArrayRef operands); + //===--------------------------------------------------------------------===// // Control flow //===--------------------------------------------------------------------===// @@ -376,6 +386,10 @@ /// The SPIR-V binary module. ArrayRef binary; + /// Contains the data of the OpLine instruction which precedes the current + /// processing instruction. + SmallVector debugLine; + /// The current word offset into the binary module. unsigned curOffset = 0; @@ -444,6 +458,9 @@ // Result to name mapping. DenseMap nameMap; + // Result to debug info mapping. + DenseMap debugInfoMap; + // Result to decorations mapping. DenseMap decorations; @@ -1994,6 +2011,36 @@ return success(); } +//===----------------------------------------------------------------------===// +// Debug +//===----------------------------------------------------------------------===// + +LogicalResult Deserializer::processDebugLine(ArrayRef operands) { + if (operands.size() != 3) + return emitError(unknownLoc, "OpLine must have 3 operands"); + debugLine = {operands.begin(), operands.end()}; + return success(); +} + +LogicalResult Deserializer::processDebugString(ArrayRef operands) { + if (operands.size() < 2) + return emitError(unknownLoc, "OpString needs at least 2 operands"); + + if (!debugInfoMap.lookup(operands[0]).empty()) + return emitError(unknownLoc, + "duplicate debug string found for result ") + << operands[0]; + + unsigned wordIndex = 1; + StringRef debugString = decodeStringLiteral(operands, wordIndex); + if (wordIndex != operands.size()) + return emitError(unknownLoc, + "unexpected trailing words in OpString instruction"); + + debugInfoMap[operands[0]] = debugString; + return success(); +} + //===----------------------------------------------------------------------===// // Instruction //===----------------------------------------------------------------------===// @@ -2085,10 +2132,13 @@ return processGlobalVariable(operands); } break; + case spirv::Opcode::OpLine: + return processDebugLine(operands); case spirv::Opcode::OpName: return processName(operands); - case spirv::Opcode::OpModuleProcessed: case spirv::Opcode::OpString: + return processDebugString(operands); + case spirv::Opcode::OpModuleProcessed: case spirv::Opcode::OpSource: case spirv::Opcode::OpSourceContinued: case spirv::Opcode::OpSourceExtension: @@ -2214,6 +2264,7 @@ interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); wordIndex++; } + opBuilder.create(unknownLoc, exec_model, opBuilder.getSymbolRefAttr(fnName), opBuilder.getArrayAttr(interface)); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -41,7 +41,7 @@ ArrayRef operands) { uint32_t wordCount = 1 + operands.size(); binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); - binary.append(operands.begin(), operands.end()); + binary.append(operands.begin(), operands.end()); return success(); } @@ -132,7 +132,7 @@ class Serializer { public: /// Creates a serializer for the given SPIR-V `module`. - explicit Serializer(spirv::ModuleOp module); + explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false); /// Serializes the remembered SPIR-V module. LogicalResult serialize(); @@ -189,6 +189,8 @@ void processCapability(); + void processDebugInfo(); + void processExtension(); void processMemoryModel(); @@ -382,6 +384,13 @@ /// An MLIR builder for getting MLIR constructs. mlir::Builder mlirBuilder; + /// A flag which indicates if the debuginfo should be emitted. + bool emitDebugInfo = false; + + /// The of the OpString instruction, which specifies a file name, for + /// use by other debug instructions. + uint32_t fileID = 0; + /// The next available result . uint32_t nextID = 1; @@ -394,7 +403,7 @@ SmallVector memoryModel; SmallVector entryPoints; SmallVector executionModes; - // TODO(antiagainst): debug instructions + SmallVector debug; SmallVector names; SmallVector decorations; SmallVector typesGlobalValues; @@ -482,8 +491,9 @@ }; } // namespace -Serializer::Serializer(spirv::ModuleOp module) - : module(module), mlirBuilder(module.getContext()) {} +Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) + : module(module), mlirBuilder(module.getContext()), + emitDebugInfo(emitDebugInfo) {} LogicalResult Serializer::serialize() { LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); @@ -495,6 +505,7 @@ processCapability(); processExtension(); processMemoryModel(); + processDebugInfo(); // Iterate over the module body to serialize it. Assumptions are that there is // only one basic block in the moduleOp @@ -525,6 +536,7 @@ binary.append(memoryModel.begin(), memoryModel.end()); binary.append(entryPoints.begin(), entryPoints.end()); binary.append(executionModes.begin(), executionModes.end()); + binary.append(debug.begin(), debug.end()); binary.append(names.begin(), names.end()); binary.append(decorations.begin(), decorations.end()); binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); @@ -569,6 +581,19 @@ {static_cast(cap)}); } +void Serializer::processDebugInfo() { + if (emitDebugInfo) { + auto fileLoc = module.getLoc().dyn_cast(); + auto fileName = fileLoc ? fileLoc.getFilename() : ""; + fileID = getNextID(); + SmallVector operands; + operands.push_back(fileID); + spirv::encodeStringLiteralInto(operands, fileName); + encodeInstructionInto(debug, spirv::Opcode::OpString, operands); + // TODO: Encode more debug instructions. + } +} + void Serializer::processExtension() { llvm::SmallVector extName; for (spirv::Extension ext : module.vce_triple()->getExtensions()) { @@ -1839,12 +1864,13 @@ } LogicalResult spirv::serialize(spirv::ModuleOp module, - SmallVectorImpl &binary) { + SmallVectorImpl &binary, + bool emitDebugInfo) { if (!module.vce_triple().hasValue()) return module.emitError( "module must have 'vce_triple' attribute to be serializeable"); - Serializer serializer(module); + Serializer serializer(module, emitDebugInfo); if (failed(serializer.serialize())) return failure(); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp @@ -91,7 +91,8 @@ if (spirvModules.size() != 1) return module.emitError("found more than one 'spv.module' op"); - if (failed(spirv::serialize(spirvModules[0], binary))) + if (failed( + spirv::serialize(spirvModules[0], binary, /*emitDebuginfo=*/false))) return failure(); output.write(reinterpret_cast(binary.data()), @@ -114,7 +115,7 @@ //===----------------------------------------------------------------------===// static LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr, - raw_ostream &output, + bool emitDebugInfo, raw_ostream &output, MLIRContext *context) { // Parse an MLIR module from the source manager. auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context)); @@ -131,7 +132,7 @@ if (std::next(spirvModules.begin()) != spirvModules.end()) return srcModule->emitError("found more than one 'spv.module' op"); - if (failed(spirv::serialize(*spirvModules.begin(), binary))) + if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo))) return failure(); // Then deserialize to get back a SPIR-V module. @@ -153,7 +154,18 @@ TranslateRegistration roundtrip( "test-spirv-roundtrip", [](llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) { - return roundTripModule(sourceMgr, output, context); + return roundTripModule(sourceMgr, /*emitDebugInfo=*/false, output, + context); + }); +} + +void registerTestRoundtripDebugSPIRV() { + TranslateRegistration roundtrip( + "test-spirv-roundtrip-debug", + [](llvm::SourceMgr &sourceMgr, raw_ostream &output, + MLIRContext *context) { + return roundTripModule(sourceMgr, /*emitDebugInfo=*/true, output, + context); }); } } // namespace mlir diff --git a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-translate -test-spirv-roundtrip-debug -mlir-print-debuginfo %s | FileCheck %s + +spv.module Logical GLSL450 requires #spv.vce { + spv.func @arithmetic(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" { + // CHECK: loc({{".*debug.mlir"}}:6:10) + %0 = spv.FAdd %arg0, %arg1 : vector<4xf32> + // CHECK: loc({{".*debug.mlir"}}:8:10) + %1 = spv.FNegate %arg0 : vector<4xf32> + spv.Return + } + + spv.func @atomic(%ptr: !spv.ptr, %value: i32, %comparator: i32) "None" { + // CHECK: loc({{".*debug.mlir"}}:14:10) + %1 = spv.AtomicAnd "Device" "None" %ptr, %value : !spv.ptr + spv.Return + } + + spv.func @bitwiser(%arg0 : i32, %arg1 : i32) "None" { + // CHECK: loc({{".*debug.mlir"}}:20:10) + %0 = spv.BitwiseAnd %arg0, %arg1 : i32 + spv.Return + } + + spv.func @convert(%arg0 : f32) "None" { + // CHECK: loc({{".*debug.mlir"}}:26:10) + %0 = spv.ConvertFToU %arg0 : f32 to i32 + spv.Return + } + + spv.func @composite(%arg0 : !spv.struct, f32>>, %arg1: !spv.array<4xf32>, %arg2 : f32, %arg3 : f32) "None" { + // CHECK: loc({{".*debug.mlir"}}:32:10) + %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct, f32>> + // CHECK: loc({{".*debug.mlir"}}:34:10) + %1 = spv.CompositeConstruct %arg2, %arg3 : vector<2xf32> + spv.Return + } + + spv.func @group_non_uniform(%val: f32) "None" { + // CHECK: loc({{".*debug.mlir"}}:40:10) + %0 = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %val : f32 + spv.Return + } + + spv.func @logical(%arg0: i32, %arg1: i32) "None" { + // CHECK: loc({{".*debug.mlir"}}:46:10) + %0 = spv.IEqual %arg0, %arg1 : i32 + spv.Return + } + + spv.func @memory_accesses(%arg0 : !spv.ptr>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" { + // CHECK: loc({{".*debug.mlir"}}:52:10) + %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr>, StorageBuffer> + // CHECK: loc({{".*debug.mlir"}}:54:10) + %3 = spv.Load "StorageBuffer" %2 : f32 + // CHECK: loc({{.*debug.mlir"}}:56:5) + spv.Store "StorageBuffer" %2, %3 : f32 + // CHECK: loc({{".*debug.mlir"}}:58:5) + spv.Return + } +} diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -492,6 +492,23 @@ os << "}\n"; } +/// Emits OpLine instruction for the given `opVar` and `binary`. +static void emitDebugLineOp(StringRef opVar, StringRef binary, + raw_ostream &os) { + // According to the SPIR-V spec section 2.4: + // "The section after the annotation section is the first section to allow use + // of OpLine and OpNoLine debug information." + os << formatv(" if (emitDebugInfo) {{\n"); + os << formatv(" auto fileLoc = {0}.getLoc().dyn_cast();\n", + opVar); + os << formatv(" if (fileLoc) {{\n"); + os << formatv(" encodeInstructionInto({0}, spirv::Opcode::OpLine, " + "{{fileID, fileLoc.getLine(), fileLoc.getColumn()});\n", + binary); + os << formatv(" }\n"); + os << formatv(" }\n"); +} + /// Forward declaration of function to return the SPIR-V opcode corresponding to /// an operation. This function will be generated for all SPV_Op instances that /// have hasOpcode = 1. @@ -660,6 +677,9 @@ opVar, record->getValueAsString("extendedInstSetName"), record->getValueAsInt("extendedInstOpcode"), operands); } else { + // Emit debug info. + emitDebugLineOp(opVar, "functionBody", os); + os << formatv(" encodeInstructionInto(" "functionBody, spirv::getOpcode<{0}>(), {1});\n", op.getQualCppClassName(), operands); @@ -900,10 +920,22 @@ emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex, operands, attributes, os); + os << formatv(" Location loc = unknownLoc;\n"); + os << formatv(" if (!debugLine.empty()) {{\n"); os << formatv( - " auto {1} = opBuilder.create<{0}>(unknownLoc, {2}, {3}, {4}); " - "(void){1};\n", - op.getQualCppClassName(), opVar, resultTypes, operands, attributes); + " auto fileName = debugInfoMap.lookup(debugLine[0]).str();\n"); + os << formatv(" if (fileName.empty())\n"); + os << formatv(" fileName = \"\";\n"); + os << formatv(" loc = " + "opBuilder.getFileLineColLoc(opBuilder.getIdentifier(fileName)," + " debugLine[1], debugLine[2]);\n"); + os << formatv(" debugLine.clear();\n"); + os << formatv(" }\n"); + + os << formatv(" auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); " + "(void){1};\n", + op.getQualCppClassName(), opVar, resultTypes, operands, + attributes); if (op.getNumResults() == 1) { os << formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar); } diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp --- a/mlir/tools/mlir-translate/mlir-translate.cpp +++ b/mlir/tools/mlir-translate/mlir-translate.cpp @@ -50,9 +50,13 @@ namespace mlir { // Defined in the test directory, no public header. void registerTestRoundtripSPIRV(); +void registerTestRoundtripDebugSPIRV(); } // namespace mlir -static void registerTestTranslations() { registerTestRoundtripSPIRV(); } +static void registerTestTranslations() { + registerTestRoundtripSPIRV(); + registerTestRoundtripDebugSPIRV(); +} int main(int argc, char **argv) { registerAllDialects();