[mlir][spirv] Add SerializationOptions; allow disabling OpName emission

NOTE(review): in this copy the unified diff below has been collapsed so that
many diff lines share one physical line. The contents of angle brackets also
appear to have been stripped, for example `SmallVectorImpl &binary`, a bare
`#include`, `reinterpret_cast(...)` and `context->loadDialect();`.
It will therefore not apply with `git apply`/`patch` as-is; regenerate it
from the source tree instead of hand-editing it.

What the hunks below change:
* Serialization.h:
  * Adds `spirv::SerializationOptions` with two fields:
    `emitSymbolName` (default true; whether to emit `OpName`) and
    `emitDebugInfo` (default false; whether to emit `OpLine`).
  * `spirv::serialize` now takes `const SerializationOptions &options = {}`
    in place of `bool emitDebugInfo = false`.
* Serializer.h / Serializer.cpp:
  * The serializer keeps a copy of the options.
  * `processName` returns success without emitting anything when
    `options.emitSymbolName` is false. The non-empty-name assert still runs
    first.
  * `processDebugInfo` and `emitDebugLine` read `options.emitDebugInfo`.
  * The constructor's second parameter no longer has a default value.
* SPIRVBinaryUtils.h:
  * `decodeStringLiteral` moves here from the deserializer's private
    Deserializer.h into `mlir::spirv` as an `inline` function. There it was
    a `static inline` function declared before `namespace mlir`.
  * DeserializeOps.cpp now includes this header.
  * The header swaps its SPIRVOps.h / LogicalResult.h includes for
    SPIRVEnums.h / mlir/Support/LLVM.h.
* Serialization.cpp, TranslateRegistration.cpp: callers are updated to pass
  options (or the default) instead of a bool.
* SerializationTest.cpp:
  * `BlockDecorationTest` becomes `ContainsBlockDecoration` and calls
    `spirv::serialize` before searching.
  * `createModuleOp` is renamed `initModuleOp`, and the other fixture
    helpers now return the type/op they create.
  * New `ContainsSymbolName` / `DoesNotContainSymbolName` tests look for an
    `OpName` whose string is "var0", with `emitSymbolName` on and off.

Review notes (not addressed by this patch):
* `decodeStringLiteral` constructs a StringRef from a single character
  pointer (the cast's template argument is stripped in this copy).
  * The length therefore comes from an unbounded scan for a NUL byte.
  * It does not check `wordIndex < words.size()`.
  * On a malformed binary without a terminator, this reads past the end of
    `words`. Consider bounding the scan to `(words.size() - wordIndex) * 4`
    bytes.
* The `hasVarName` matcher lambda is duplicated verbatim in the two new
  tests and could be a fixture helper.
* Callers outside these hunks may still use `Serializer(module)` or the old
  `serialize(..., bool)` signature. They would need updating; this cannot be
  checked from this diff.

diff --git a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h --- a/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h +++ b/mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h @@ -13,8 +13,8 @@ #ifndef MLIR_TARGET_SPIRV_BINARY_UTILS_H_ #define MLIR_TARGET_SPIRV_BINARY_UTILS_H_ -#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" -#include "mlir/Support/LogicalResult.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" +#include "mlir/Support/LLVM.h" #include @@ -41,6 +41,16 @@ /// Encodes an SPIR-V `literal` string into the given `binary` vector. LogicalResult encodeStringLiteralInto(SmallVectorImpl &binary, StringRef literal); + +/// Decodes a string literal in `words` starting at `wordIndex`. Update the +/// latter to point to the position in words after the string literal. +inline StringRef decodeStringLiteral(ArrayRef words, + unsigned &wordIndex) { + StringRef str(reinterpret_cast(words.data() + wordIndex)); + wordIndex += str.size() / 4 + 1; + return str; +} + } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Target/SPIRV/Serialization.h b/mlir/include/mlir/Target/SPIRV/Serialization.h --- a/mlir/include/mlir/Target/SPIRV/Serialization.h +++ b/mlir/include/mlir/Target/SPIRV/Serialization.h @@ -22,11 +22,18 @@ namespace spirv { class ModuleOp; +struct SerializationOptions { + /// Whether to emit `OpName` instructions for SPIR-V symbol ops. + bool emitSymbolName = true; + /// Whether to emit `OpLine` location information for SPIR-V ops. + bool emitDebugInfo = false; +}; + /// Serializes the given SPIR-V `module` and writes to `binary`. On failure, /// reports errors to the error handler registered with the MLIR context for /// `module`.
LogicalResult serialize(ModuleOp module, SmallVectorImpl &binary, - bool emitDebugInfo = false); + const SerializationOptions &options = {}); } // namespace spirv } // namespace mlir diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Location.h" +#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -21,19 +21,6 @@ #include "llvm/ADT/StringRef.h" #include -//===----------------------------------------------------------------------===// -// Utility Functions -//===----------------------------------------------------------------------===// - -/// Decodes a string literal in `words` starting at `wordIndex`. Update the -/// latter to point to the position in words after the string literal.
-static inline llvm::StringRef -decodeStringLiteral(llvm::ArrayRef words, unsigned &wordIndex) { - llvm::StringRef str(reinterpret_cast(words.data() + wordIndex)); - wordIndex += str.size() / 4 + 1; - return str; -} - namespace mlir { namespace spirv { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serialization.cpp @@ -23,12 +23,12 @@ namespace mlir { LogicalResult spirv::serialize(spirv::ModuleOp module, SmallVectorImpl &binary, - bool emitDebugInfo) { + const SerializationOptions &options) { if (!module.vce_triple().hasValue()) return module.emitError( "module must have 'vce_triple' attribute to be serializeable"); - Serializer serializer(module, emitDebugInfo); + Serializer serializer(module, options); if (failed(serializer.serialize())) return failure(); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -15,6 +15,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/IR/Builders.h" +#include "mlir/Target/SPIRV/Serialization.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/raw_ostream.h" @@ -42,7 +43,8 @@ class Serializer { public: /// Creates a serializer for the given SPIR-V `module`. - explicit Serializer(spirv::ModuleOp module, bool emitDebugInfo = false); + explicit Serializer(spirv::ModuleOp module, + const SerializationOptions &options); /// Serializes the remembered SPIR-V module. LogicalResult serialize(); @@ -316,8 +318,8 @@ /// An MLIR builder for getting MLIR constructs. mlir::Builder mlirBuilder; - /// A flag which indicates if the debuginfo should be emitted. - bool emitDebugInfo = false; + /// Serialization options.
+ SerializationOptions options; /// A flag which indicates if the last processed instruction was a merge /// instruction. diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -81,9 +81,9 @@ return success(); } -Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) - : module(module), mlirBuilder(module.getContext()), - emitDebugInfo(emitDebugInfo) {} +Serializer::Serializer(spirv::ModuleOp module, + const SerializationOptions &options) + : module(module), mlirBuilder(module.getContext()), options(options) {} LogicalResult Serializer::serialize() { LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); @@ -172,7 +172,7 @@ } void Serializer::processDebugInfo() { - if (!emitDebugInfo) + if (!options.emitDebugInfo) return; auto fileLoc = module.getLoc().dyn_cast(); auto fileName = fileLoc ? fileLoc.getFilename().strref() : ""; @@ -254,12 +254,13 @@ LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { assert(!name.empty() && "unexpected empty string for OpName"); + if (!options.emitSymbolName) + return success(); SmallVector nameOperands; nameOperands.push_back(resultID); - if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) { + if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) return failure(); - } return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); } @@ -1170,7 +1171,7 @@ LogicalResult Serializer::emitDebugLine(SmallVectorImpl &binary, Location loc) { - if (!emitDebugInfo) + if (!options.emitDebugInfo) return success(); if (lastProcessedWasMergeInst) { diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp --- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp @@ -40,7 +40,7 @@
context->loadDialect(); // Make sure the input stream can be treated as a stream of SPIR-V words - auto start = input->getBufferStart(); + auto *start = input->getBufferStart(); auto size = input->getBufferSize(); if (size % sizeof(uint32_t) != 0) { emitError(UnknownLoc::get(context)) @@ -94,8 +94,7 @@ if (spirvModules.size() != 1) return module.emitError("found more than one 'spv.module' op"); - if (failed( - spirv::serialize(spirvModules[0], binary, /*emitDebuginfo=*/false))) + if (failed(spirv::serialize(spirvModules[0], binary))) return failure(); output.write(reinterpret_cast(binary.data()), @@ -133,7 +132,9 @@ if (std::next(spirvModules.begin()) != spirvModules.end()) return srcModule.emitError("found more than one 'spv.module' op"); - if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo))) + spirv::SerializationOptions options; + options.emitDebugInfo = emitDebugInfo; + if (failed(spirv::serialize(*spirvModules.begin(), binary, options))) return failure(); MLIRContext deserializationContext(context->getDialectRegistry()); diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -37,10 +37,11 @@ protected: SerializationTest() { context.getOrLoadDialect(); - createModuleOp(); + initModuleOp(); } - void createModuleOp() { + /// Initializes an empty SPIR-V module op. + void initModuleOp() { OpBuilder builder(&context); OperationState state(UnknownLoc::get(&context), spirv::ModuleOp::getOperationName()); @@ -58,27 +59,29 @@ module = cast(Operation::create(state)); } - Type getFloatStructType() { - OpBuilder opBuilder(module->getRegion()); - llvm::SmallVector elementTypes{opBuilder.getF32Type()}; + /// Gets the `struct { float }` type.
+ spirv::StructType getFloatStructType() { + OpBuilder builder(module->getRegion()); + llvm::SmallVector elementTypes{builder.getF32Type()}; llvm::SmallVector offsetInfo{0}; - auto structType = spirv::StructType::get(elementTypes, offsetInfo); - return structType; + return spirv::StructType::get(elementTypes, offsetInfo); } - void addGlobalVar(Type type, llvm::StringRef name) { - OpBuilder opBuilder(module->getRegion()); + /// Inserts a global variable of the given `type` and `name`. + spirv::GlobalVariableOp addGlobalVar(Type type, llvm::StringRef name) { + OpBuilder builder(module->getRegion()); auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform); - opBuilder.create( + return builder.create( UnknownLoc::get(&context), TypeAttr::get(ptrType), - opBuilder.getStringAttr(name), nullptr); + builder.getStringAttr(name), nullptr); } + /// Returns true if we can find a matching instruction in the SPIR-V blob. bool findInstruction(llvm::function_ref operands)> matchFn) { auto binarySize = binary.size(); - auto begin = binary.begin(); + auto *begin = binary.begin(); auto currOffset = spirv::kHeaderWordCount; while (currOffset < binarySize) { @@ -109,10 +112,12 @@ // Block decoration //===----------------------------------------------------------------------===// -TEST_F(SerializationTest, BlockDecorationTest) { +TEST_F(SerializationTest, ContainsBlockDecoration) { auto structType = getFloatStructType(); addGlobalVar(structType, "var0"); + ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary))); + auto hasBlockDecoration = [](spirv::Opcode opcode, ArrayRef operands) -> bool { if (opcode != spirv::Opcode::OpDecorate || operands.size() != 2) @@ -121,3 +126,35 @@ }; EXPECT_TRUE(findInstruction(hasBlockDecoration)); } + +TEST_F(SerializationTest, ContainsSymbolName) { + auto structType = getFloatStructType(); + addGlobalVar(structType, "var0"); + + spirv::SerializationOptions options; + options.emitSymbolName = true; +
ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); + + auto hasVarName = [](spirv::Opcode opcode, ArrayRef operands) { + unsigned index = 1; // Skip the result + return opcode == spirv::Opcode::OpName && + spirv::decodeStringLiteral(operands, index) == "var0"; + }; + EXPECT_TRUE(findInstruction(hasVarName)); +} + +TEST_F(SerializationTest, DoesNotContainSymbolName) { + auto structType = getFloatStructType(); + addGlobalVar(structType, "var0"); + + spirv::SerializationOptions options; + options.emitSymbolName = false; + ASSERT_TRUE(succeeded(spirv::serialize(module.get(), binary, options))); + + auto hasVarName = [](spirv::Opcode opcode, ArrayRef operands) { + unsigned index = 1; // Skip the result + return opcode == spirv::Opcode::OpName && + spirv::decodeStringLiteral(operands, index) == "var0"; + }; + EXPECT_FALSE(findInstruction(hasVarName)); +}