diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3079,6 +3079,7 @@ def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>; def SPV_OC_OpMemberName : I32EnumAttrCase<"OpMemberName", 6>; def SPV_OC_OpString : I32EnumAttrCase<"OpString", 7>; +def SPV_OC_OpLine : I32EnumAttrCase<"OpLine", 8>; def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>; def SPV_OC_OpExtInstImport : I32EnumAttrCase<"OpExtInstImport", 11>; def SPV_OC_OpExtInst : I32EnumAttrCase<"OpExtInst", 12>; @@ -3204,6 +3205,7 @@ def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; def SPV_OC_OpUnreachable : I32EnumAttrCase<"OpUnreachable", 255>; +def SPV_OC_OpNoLine : I32EnumAttrCase<"OpNoLine", 317>; def SPV_OC_OpModuleProcessed : I32EnumAttrCase<"OpModuleProcessed", 330>; def SPV_OC_OpGroupNonUniformElect : I32EnumAttrCase<"OpGroupNonUniformElect", 333>; def SPV_OC_OpGroupNonUniformBallot : I32EnumAttrCase<"OpGroupNonUniformBallot", 339>; @@ -3223,7 +3225,7 @@ SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ SPV_OC_OpNop, SPV_OC_OpUndef, SPV_OC_OpSourceContinued, SPV_OC_OpSource, SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString, - SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, + SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, @@ -3262,14 +3264,14 @@ SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, - SPV_OC_OpReturnValue,
SPV_OC_OpUnreachable, SPV_OC_OpModuleProcessed, - SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot, - SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd, - SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul, - SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin, - SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax, - SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax, - SPV_OC_OpSubgroupBallotKHR + SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine, + SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect, + SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd, + SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul, + SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin, + SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin, + SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax, + SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! diff --git a/mlir/include/mlir/Dialect/SPIRV/Serialization.h b/mlir/include/mlir/Dialect/SPIRV/Serialization.h --- a/mlir/include/mlir/Dialect/SPIRV/Serialization.h +++ b/mlir/include/mlir/Dialect/SPIRV/Serialization.h @@ -26,7 +26,9 @@ /// Serializes the given SPIR-V `module` and writes to `binary`. On failure, /// reports errors to the error handler registered with the MLIR context for /// `module`. -LogicalResult serialize(ModuleOp module, SmallVectorImpl &binary); +/// If `emitDebugInfo` is true, OpString/OpLine debug info is also emitted. +LogicalResult serialize(ModuleOp module, SmallVectorImpl &binary, + bool emitDebugInfo = false); /// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp /// in the given `context`.
Returns the ModuleOp on success; otherwise, reports diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -68,6 +68,16 @@ : mergeBlock(m), continueBlock(c) {} }; +/// A struct for containing OpLine instruction information. +struct DebugLine { + uint32_t fileID; + uint32_t line; + uint32_t col; + + DebugLine(uint32_t fileIDNum, uint32_t lineNum, uint32_t colNum) + : fileID(fileIDNum), line(lineNum), col(colNum) {} +}; + /// Map from a selection/loop's header block to its merge (and continue) target. using BlockMergeInfoMap = DenseMap; @@ -233,6 +243,23 @@ LogicalResult processConstantNull(ArrayRef operands); //===--------------------------------------------------------------------===// + // Debug + //===--------------------------------------------------------------------===// + + /// Discontinues any source-level location information that might be active + /// from a previous OpLine instruction. + LogicalResult clearDebugLine(); + + /// Returns a FileLineColLoc for the active OpLine, or unknownLoc if none. + Location createFileLineColLoc(OpBuilder &opBuilder); + + /// Processes a SPIR-V OpLine instruction with the given `operands`. + LogicalResult processDebugLine(ArrayRef operands); + + /// Processes a SPIR-V OpString instruction with the given `operands`. + LogicalResult processDebugString(ArrayRef operands); + + //===--------------------------------------------------------------------===// // Control flow //===--------------------------------------------------------------------===// @@ -376,6 +403,10 @@ /// The SPIR-V binary module. ArrayRef binary; + /// Contains the data of the OpLine instruction which precedes the current + /// processing instruction. + llvm::Optional debugLine; + /// The current word offset into the binary module.
unsigned curOffset = 0; @@ -444,6 +475,9 @@ // Result to name mapping. DenseMap nameMap; + // Result to debug info mapping. + DenseMap debugInfoMap; + // Result to decorations mapping. DenseMap decorations; @@ -1506,6 +1540,7 @@ auto *target = getOrCreateBlock(operands[0]); opBuilder.create(unknownLoc, target); + clearDebugLine(); return success(); } @@ -1536,6 +1571,7 @@ /*trueArguments=*/ArrayRef(), falseBlock, /*falseArguments=*/ArrayRef(), weights); + clearDebugLine(); return success(); } @@ -1995,6 +2031,57 @@ } //===----------------------------------------------------------------------===// +// Debug +//===----------------------------------------------------------------------===// + +Location Deserializer::createFileLineColLoc(OpBuilder &opBuilder) { + if (!debugLine) + return unknownLoc; + + auto fileName = debugInfoMap.lookup(debugLine->fileID).str(); + if (fileName.empty()) + fileName = "<unknown>"; + return opBuilder.getFileLineColLoc(opBuilder.getIdentifier(fileName), + debugLine->line, debugLine->col); +} + +LogicalResult Deserializer::processDebugLine(ArrayRef operands) { + // According to SPIR-V spec: + // "This location information applies to the instructions physically + // following this instruction, up to the first occurrence of any of the + // following: the next end of block, the next OpLine instruction, or the next + // OpNoLine instruction."
+ if (operands.size() != 3) + return emitError(unknownLoc, "OpLine must have 3 operands"); + debugLine = DebugLine(operands[0], operands[1], operands[2]); + return success(); +} + +LogicalResult Deserializer::clearDebugLine() { + debugLine = llvm::None; + return success(); +} + +LogicalResult Deserializer::processDebugString(ArrayRef operands) { + if (operands.size() < 2) + return emitError(unknownLoc, "OpString needs at least 2 operands"); + + if (debugInfoMap.count(operands[0])) + return emitError(unknownLoc, + "duplicate debug string found for result ") + << operands[0]; + + unsigned wordIndex = 1; + StringRef debugString = decodeStringLiteral(operands, wordIndex); + if (wordIndex != operands.size()) + return emitError(unknownLoc, + "unexpected trailing words in OpString instruction"); + + debugInfoMap[operands[0]] = debugString; + return success(); +} + +//===----------------------------------------------------------------------===// // Instruction //===----------------------------------------------------------------------===// @@ -2085,10 +2172,15 @@ return processGlobalVariable(operands); } break; + case spirv::Opcode::OpLine: + return processDebugLine(operands); + case spirv::Opcode::OpNoLine: + return clearDebugLine(); case spirv::Opcode::OpName: return processName(operands); - case spirv::Opcode::OpModuleProcessed: case spirv::Opcode::OpString: + return processDebugString(operands); + case spirv::Opcode::OpModuleProcessed: case spirv::Opcode::OpSource: case spirv::Opcode::OpSourceContinued: case spirv::Opcode::OpSourceExtension: diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -41,7 +41,7 @@ ArrayRef operands) { uint32_t wordCount = 1 + operands.size(); binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); - binary.append(operands.begin(), operands.end());
+ binary.append(operands.begin(), operands.end()); return success(); } @@ -132,7 +132,7 @@ class Serializer { public: /// Creates a serializer for the given SPIR-V `module`. - explicit Serializer(spirv::ModuleOp module); + explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false); /// Serializes the remembered SPIR-V module. LogicalResult serialize(); @@ -189,6 +189,8 @@ void processCapability(); + void processDebugInfo(); + void processExtension(); void processMemoryModel(); @@ -375,6 +377,10 @@ LogicalResult emitDecoration(uint32_t target, spirv::Decoration decoration, ArrayRef params = {}); + /// Emits an OpLine for `loc` into `binary` when debug info emission is + /// enabled and `loc` is a FileLineColLoc; otherwise emits nothing. + LogicalResult emitDebugLine(SmallVectorImpl &binary, Location loc); + private: /// The SPIR-V module to be serialized. spirv::ModuleOp module; @@ -382,6 +388,13 @@ /// An MLIR builder for getting MLIR constructs. mlir::Builder mlirBuilder; + /// A flag which indicates if the debug info should be emitted. + bool emitDebugInfo = false; + + /// The result id of the OpString instruction, which specifies a file name, + /// for use by other debug instructions. + uint32_t fileID = 0; + + /// The next available result .
uint32_t nextID = 1; @@ -394,7 +407,7 @@ SmallVector memoryModel; SmallVector entryPoints; SmallVector executionModes; - // TODO(antiagainst): debug instructions + SmallVector debug; SmallVector names; SmallVector decorations; SmallVector typesGlobalValues; @@ -482,8 +495,9 @@ }; } // namespace -Serializer::Serializer(spirv::ModuleOp module) - : module(module), mlirBuilder(module.getContext()) {} +Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) + : module(module), mlirBuilder(module.getContext()), + emitDebugInfo(emitDebugInfo) {} LogicalResult Serializer::serialize() { LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); @@ -495,6 +509,7 @@ processCapability(); processExtension(); processMemoryModel(); + processDebugInfo(); // Iterate over the module body to serialize it. Assumptions are that there is // only one basic block in the moduleOp @@ -525,6 +540,7 @@ binary.append(memoryModel.begin(), memoryModel.end()); binary.append(entryPoints.begin(), entryPoints.end()); binary.append(executionModes.begin(), executionModes.end()); + binary.append(debug.begin(), debug.end()); binary.append(names.begin(), names.end()); binary.append(decorations.begin(), decorations.end()); binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); @@ -569,6 +585,19 @@ {static_cast(cap)}); } +void Serializer::processDebugInfo() { + if (!emitDebugInfo) + return; + auto fileLoc = module.getLoc().dyn_cast(); + auto fileName = fileLoc ? fileLoc.getFilename() : ""; + fileID = getNextID(); + SmallVector operands; + operands.push_back(fileID); + spirv::encodeStringLiteralInto(operands, fileName); + encodeInstructionInto(debug, spirv::Opcode::OpString, operands); + // TODO: Encode more debug instructions.
+} + void Serializer::processExtension() { llvm::SmallVector extName; for (spirv::Extension ext : module.vce_triple()->getExtensions()) { @@ -1838,13 +1867,26 @@ return success(); } +LogicalResult Serializer::emitDebugLine(SmallVectorImpl &binary, + Location loc) { + if (!emitDebugInfo) + return success(); + + auto fileLoc = loc.dyn_cast(); + if (fileLoc) + encodeInstructionInto(binary, spirv::Opcode::OpLine, + {fileID, fileLoc.getLine(), fileLoc.getColumn()}); + return success(); +} + LogicalResult spirv::serialize(spirv::ModuleOp module, - SmallVectorImpl &binary) { + SmallVectorImpl &binary, + bool emitDebugInfo) { if (!module.vce_triple().hasValue()) return module.emitError( "module must have 'vce_triple' attribute to be serializeable"); - Serializer serializer(module); + Serializer serializer(module, emitDebugInfo); if (failed(serializer.serialize())) return failure(); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp @@ -91,7 +91,8 @@ if (spirvModules.size() != 1) return module.emitError("found more than one 'spv.module' op"); - if (failed(spirv::serialize(spirvModules[0], binary))) + if (failed( + spirv::serialize(spirvModules[0], binary, /*emitDebugInfo=*/false))) return failure(); output.write(reinterpret_cast(binary.data()), @@ -114,7 +115,7 @@ //===----------------------------------------------------------------------===// static LogicalResult roundTripModule(llvm::SourceMgr &sourceMgr, - raw_ostream &output, + bool emitDebugInfo, raw_ostream &output, MLIRContext *context) { // Parse an MLIR module from the source manager.
auto srcModule = OwningModuleRef(parseSourceFile(sourceMgr, context)); @@ -131,7 +132,7 @@ if (std::next(spirvModules.begin()) != spirvModules.end()) return srcModule->emitError("found more than one 'spv.module' op"); - if (failed(spirv::serialize(*spirvModules.begin(), binary))) + if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo))) return failure(); // Then deserialize to get back a SPIR-V module. @@ -153,7 +154,18 @@ TranslateRegistration roundtrip( "test-spirv-roundtrip", [](llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) { - return roundTripModule(sourceMgr, output, context); + return roundTripModule(sourceMgr, /*emitDebugInfo=*/false, output, + context); + }); +} + +void registerTestRoundtripDebugSPIRV() { + TranslateRegistration roundtrip( + "test-spirv-roundtrip-debug", + [](llvm::SourceMgr &sourceMgr, raw_ostream &output, + MLIRContext *context) { + return roundTripModule(sourceMgr, /*emitDebugInfo=*/true, output, + context); }); } } // namespace mlir diff --git a/mlir/test/Dialect/SPIRV/Serialization/debug.mlir b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/debug.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-translate -test-spirv-roundtrip-debug -mlir-print-debuginfo %s | FileCheck %s + +spv.module Logical GLSL450 requires #spv.vce { + spv.func @arithmetic(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) "None" { + // CHECK: loc({{".*debug.mlir"}}:6:10) + %0 = spv.FAdd %arg0, %arg1 : vector<4xf32> + // CHECK: loc({{".*debug.mlir"}}:8:10) + %1 = spv.FNegate %arg0 : vector<4xf32> + spv.Return + } + + spv.func @atomic(%ptr: !spv.ptr, %value: i32, %comparator: i32) "None" { + // CHECK: loc({{".*debug.mlir"}}:14:10) + %1 = spv.AtomicAnd "Device" "None" %ptr, %value : !spv.ptr + spv.Return + } + + spv.func @bitwiser(%arg0 : i32, %arg1 : i32) "None" { + // CHECK: loc({{".*debug.mlir"}}:20:10) + %0 = spv.BitwiseAnd %arg0, %arg1 : i32 + spv.Return
} + + spv.func @convert(%arg0 : f32) "None" { + // CHECK: loc({{".*debug.mlir"}}:26:10) + %0 = spv.ConvertFToU %arg0 : f32 to i32 + spv.Return + } + + spv.func @composite(%arg0 : !spv.struct, f32>>, %arg1: !spv.array<4xf32>, %arg2 : f32, %arg3 : f32) "None" { + // CHECK: loc({{".*debug.mlir"}}:32:10) + %0 = spv.CompositeInsert %arg1, %arg0[1 : i32, 0 : i32] : !spv.array<4xf32> into !spv.struct, f32>> + // CHECK: loc({{".*debug.mlir"}}:34:10) + %1 = spv.CompositeConstruct %arg2, %arg3 : vector<2xf32> + spv.Return + } + + spv.func @group_non_uniform(%val: f32) "None" { + // CHECK: loc({{".*debug.mlir"}}:40:10) + %0 = spv.GroupNonUniformFAdd "Workgroup" "Reduce" %val : f32 + spv.Return + } + + spv.func @logical(%arg0: i32, %arg1: i32) "None" { + // CHECK: loc({{".*debug.mlir"}}:46:10) + %0 = spv.IEqual %arg0, %arg1 : i32 + spv.Return + } + + spv.func @memory_accesses(%arg0 : !spv.ptr>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" { + // CHECK: loc({{".*debug.mlir"}}:52:10) + %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr>, StorageBuffer> + // CHECK: loc({{".*debug.mlir"}}:54:10) + %3 = spv.Load "StorageBuffer" %2 : f32 + // CHECK: loc({{".*debug.mlir"}}:56:5) + spv.Store "StorageBuffer" %2, %3 : f32 + // CHECK: loc({{".*debug.mlir"}}:58:5) + spv.Return + } +} diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -660,6 +660,8 @@ opVar, record->getValueAsString("extendedInstSetName"), record->getValueAsInt("extendedInstOpcode"), operands); } else { + // Emit debug info.
+ os << formatv(" emitDebugLine(functionBody, {0}.getLoc());\n", opVar); os << formatv(" encodeInstructionInto(" "functionBody, spirv::getOpcode<{0}>(), {1});\n", op.getQualCppClassName(), operands); @@ -900,14 +902,22 @@ emitOperandDeserialization(op, record->getLoc(), " ", words, wordIndex, operands, attributes, os); - os << formatv( - " auto {1} = opBuilder.create<{0}>(unknownLoc, {2}, {3}, {4}); " - "(void){1};\n", - op.getQualCppClassName(), opVar, resultTypes, operands, attributes); + os << " Location loc = createFileLineColLoc(opBuilder);\n"; + os << formatv(" auto {1} = opBuilder.create<{0}>(loc, {2}, {3}, {4}); " + "(void){1};\n", + op.getQualCppClassName(), opVar, resultTypes, operands, + attributes); if (op.getNumResults() == 1) { os << formatv(" valueMap[{0}] = {1}.getResult();\n\n", valueID, opVar); } + // Per the SPIR-V spec, OpLine location information applies to the + // instructions that follow it only up to the next end of block (or the next + // OpLine/OpNoLine), so the active debug line is cleared after an op that + // ends a block. + os << formatv(" if ({0}.hasTrait())\n", opVar); + os << " clearDebugLine();\n"; + // Decorations emitDecorationDeserialization(op, " ", valueID, attributes, os); os << " return success();\n"; diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp --- a/mlir/tools/mlir-translate/mlir-translate.cpp +++ b/mlir/tools/mlir-translate/mlir-translate.cpp @@ -50,9 +50,13 @@ namespace mlir { // Defined in the test directory, no public header. void registerTestRoundtripSPIRV(); +void registerTestRoundtripDebugSPIRV(); } // namespace mlir -static void registerTestTranslations() { registerTestRoundtripSPIRV(); } +static void registerTestTranslations() { + registerTestRoundtripSPIRV(); + registerTestRoundtripDebugSPIRV(); +} int main(int argc, char **argv) { registerAllDialects();