diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h
@@ -0,0 +1,61 @@
+//===- SPIRVModule.h - SPIR-V Module Utilities ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+#define MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+
+namespace mlir {
+namespace spirv {
+
+/// This class acts as an owning reference to a SPIR-V module, and will
+/// automatically destroy the held module on destruction if the held module
+/// is valid. It is move-only; ownership is given up via `release()`.
+class OwningSPIRVModuleRef {
+public:
+  OwningSPIRVModuleRef(std::nullptr_t = nullptr) {}
+  OwningSPIRVModuleRef(spirv::ModuleOp module) : module(module) {}
+  OwningSPIRVModuleRef(OwningSPIRVModuleRef &&other) noexcept
+      : module(other.release()) {}
+  ~OwningSPIRVModuleRef() {
+    if (module)
+      module.erase();
+  }
+
+  /// Assign from another module reference, destroying the currently held
+  /// module (if any). Self-move-assignment keeps the held module intact.
+  OwningSPIRVModuleRef &operator=(OwningSPIRVModuleRef &&other) noexcept {
+    if (this != &other && module)
+      module.erase();
+    module = other.release();
+    return *this;
+  }
+
+  /// Allow accessing the internal module.
+  spirv::ModuleOp get() const { return module; }
+  spirv::ModuleOp operator*() const { return module; }
+  spirv::ModuleOp *operator->() { return &module; }
+  explicit operator bool() const { return module; }
+
+  /// Release the referenced module and return it; the caller becomes
+  /// responsible for it and this reference is left null.
+  spirv::ModuleOp release() noexcept {
+    spirv::ModuleOp released;
+    std::swap(released, module);
+    return released;
+  }
+
+private:
+  spirv::ModuleOp module;
+};
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRVMODULE_H
diff --git a/mlir/include/mlir/Dialect/SPIRV/Serialization.h b/mlir/include/mlir/Dialect/SPIRV/Serialization.h --- a/mlir/include/mlir/Dialect/SPIRV/Serialization.h +++ b/mlir/include/mlir/Dialect/SPIRV/Serialization.h @@ -22,6 +22,7 @@ namespace spirv { class ModuleOp; +class OwningSPIRVModuleRef; /// Serializes the given SPIR-V `module` and writes to `binary`. On failure, /// reports errors to the error handler registered with the MLIR context for @@ -31,9 +32,10 @@ /// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp /// in the given `context`. Returns the ModuleOp on success; otherwise, reports -/// errors to the error handler registered with `context` and returns -/// llvm::None. -Optional deserialize(ArrayRef binary, MLIRContext *context); +/// errors to the error handler registered with `context` and returns a null +/// module. +OwningSPIRVModuleRef deserialize(ArrayRef binary, + MLIRContext *context); } // end namespace spirv } // end namespace mlir diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h --- a/mlir/include/mlir/IR/Module.h +++ b/mlir/include/mlir/IR/Module.h @@ -122,7 +122,7 @@ }; /// This class acts as an owning reference to a module, and will automatically -/// destroy the held module if valid. +/// destroy the held module on destruction if the held module is valid.
class OwningModuleRef { public: OwningModuleRef(std::nullptr_t = nullptr) {} diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/SPIRV/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" +#include "mlir/Dialect/SPIRV/SPIRVModule.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -2515,12 +2516,16 @@ #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" } // namespace -Optional spirv::deserialize(ArrayRef binary, - MLIRContext *context) { +spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef binary, + MLIRContext *context) { Deserializer deserializer(binary, context); if (failed(deserializer.deserialize())) - return llvm::None; + return nullptr; // Null ref signals failure to the caller. + + Optional module = deserializer.collect(); // May be empty. + if (!module) + return nullptr; - return deserializer.collect(); + return *module; } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/SPIRV/SPIRVModule.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/IR/Builders.h" @@ -49,13 +50,13 @@ auto binary = llvm::makeArrayRef(reinterpret_cast(start), size / sizeof(uint32_t)); - auto spirvModule = spirv::deserialize(binary, context); + spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context); if (!spirvModule) return {}; OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
input->getBufferIdentifier(), /*line=*/0, /*column=*/0, context))); - module->getBody()->push_front(spirvModule->getOperation()); + module->getBody()->push_front(spirvModule.release()); return module; } @@ -136,14 +137,14 @@ return failure(); // Then deserialize to get back a SPIR-V module. - auto spirvModule = spirv::deserialize(binary, context); + spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context); if (!spirvModule) return failure(); // Wrap around in a new MLIR module. OwningModuleRef dstModule(ModuleOp::create(FileLineColLoc::get( /*filename=*/"", /*line=*/0, /*column=*/0, context))); - dstModule->getBody()->push_front(spirvModule->getOperation()); + dstModule->getBody()->push_front(spirvModule.release()); dstModule->print(output); return mlir::success(); diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp --- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVModule.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Dialect/SPIRV/Serialization.h" #include "mlir/IR/Diagnostics.h" @@ -46,7 +47,7 @@ } /// Performs deserialization and returns the constructed spv.module op.
- Optional deserialize() { + spirv::OwningSPIRVModuleRef deserialize() { return spirv::deserialize(binary, &context); } @@ -130,27 +131,27 @@ //===----------------------------------------------------------------------===// TEST_F(DeserializationTest, EmptyModuleFailure) { - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("SPIR-V binary module must have a 5-word header"); } TEST_F(DeserializationTest, WrongMagicNumberFailure) { addHeader(); binary.front() = 0xdeadbeef; // Change to a wrong magic number - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("incorrect magic number"); } TEST_F(DeserializationTest, OnlyHeaderSuccess) { addHeader(); - EXPECT_NE(llvm::None, deserialize()); + EXPECT_TRUE(deserialize()); } TEST_F(DeserializationTest, ZeroWordCountFailure) { addHeader(); binary.push_back(0); // OpNop with zero word count - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("word count cannot be zero"); } @@ -160,7 +161,7 @@ static_cast(spirv::Opcode::OpTypeVoid)); // Missing word for type - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("insufficient words for the last instruction"); } @@ -172,7 +173,7 @@ addHeader(); addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32}); - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters"); } @@ -198,7 +199,7 @@ addInstruction(spirv::Opcode::OpMemberName, operands2); binary.append(typeDecl.begin(), typeDecl.end()); - EXPECT_NE(llvm::None, deserialize()); + EXPECT_TRUE(deserialize()); } TEST_F(DeserializationTest, OpMemberNameMissingOperands) { @@ -215,7 +216,7 @@ addInstruction(spirv::Opcode::OpMemberName, operands1); binary.append(typeDecl.begin(), typeDecl.end()); - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("OpMemberName must have at least 3 operands"); }
@@ -234,7 +235,7 @@ addInstruction(spirv::Opcode::OpMemberName, operands); binary.append(typeDecl.begin(), typeDecl.end()); - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("unexpected trailing words in OpMemberName instruction"); } @@ -249,7 +250,7 @@ addFunction(voidType, fnType); // Missing OpFunctionEnd - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("expected OpFunctionEnd instruction"); } @@ -261,7 +262,7 @@ addFunction(voidType, fnType); // Missing OpFunctionParameter - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("expected OpFunctionParameter instruction"); } @@ -274,7 +275,7 @@ addReturn(); addFunctionEnd(); - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("a basic block must start with OpLabel"); } @@ -287,6 +288,6 @@ addReturn(); addFunctionEnd(); - ASSERT_EQ(llvm::None, deserialize()); + ASSERT_FALSE(deserialize()); expectDiagnostic("OpLabel should only have result "); }