diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVModule.h
@@ -0,0 +1,29 @@
+//===- SPIRVModule.h - SPIR-V Module Utilities ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+#define MLIR_DIALECT_SPIRV_SPIRVMODULE_H
+
+#include "mlir/Dialect/SPIRV/SPIRVOps.h"
+#include "mlir/IR/OwningOpRefBase.h"
+
+namespace mlir {
+namespace spirv {
+
+/// This class acts as an owning reference to a SPIR-V module, and will
+/// automatically destroy the held module on destruction if the held module
+/// is valid.
+class OwningSPIRVModuleRef : public OwningOpRefBase<spirv::ModuleOp> {
+public:
+  using OwningOpRefBase<spirv::ModuleOp>::OwningOpRefBase;
+};
+
+} // end namespace spirv
+} // end namespace mlir
+
+#endif // MLIR_DIALECT_SPIRV_SPIRVMODULE_H
diff --git a/mlir/include/mlir/Dialect/SPIRV/Serialization.h b/mlir/include/mlir/Dialect/SPIRV/Serialization.h
--- a/mlir/include/mlir/Dialect/SPIRV/Serialization.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Serialization.h
@@ -22,6 +22,7 @@
 
 namespace spirv {
 class ModuleOp;
+class OwningSPIRVModuleRef;
 
 /// Serializes the given SPIR-V `module` and writes to `binary`. On failure,
 /// reports errors to the error handler registered with the MLIR context for
@@ -31,9 +32,11 @@
 
 /// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp
 /// in the given `context`. Returns the ModuleOp on success; otherwise, reports
-/// errors to the error handler registered with `context` and returns
-/// llvm::None.
-Optional<ModuleOp> deserialize(ArrayRef<uint32_t> binary, MLIRContext *context);
+/// errors to the error handler registered with `context` and returns a null
+/// reference. The returned reference owns the module and erases it when
+/// destroyed; call `release()` to transfer ownership, e.g. into a block.
+OwningSPIRVModuleRef deserialize(ArrayRef<uint32_t> binary,
+                                 MLIRContext *context);
 
 } // end namespace spirv
 } // end namespace mlir
diff --git a/mlir/include/mlir/IR/Module.h b/mlir/include/mlir/IR/Module.h
--- a/mlir/include/mlir/IR/Module.h
+++ b/mlir/include/mlir/IR/Module.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_IR_MODULE_H
 #define MLIR_IR_MODULE_H
 
+#include "mlir/IR/OwningOpRefBase.h"
 #include "mlir/IR/SymbolTable.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
@@ -122,40 +123,10 @@
 };
 
 /// This class acts as an owning reference to a module, and will automatically
-/// destroy the held module if valid.
-class OwningModuleRef {
+/// destroy the held module on destruction if the held module is valid.
+class OwningModuleRef : public OwningOpRefBase<ModuleOp> {
 public:
-  OwningModuleRef(std::nullptr_t = nullptr) {}
-  OwningModuleRef(ModuleOp module) : module(module) {}
-  OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
-  ~OwningModuleRef() {
-    if (module)
-      module.erase();
-  }
-
-  // Assign from another module reference.
-  OwningModuleRef &operator=(OwningModuleRef &&other) {
-    if (module)
-      module.erase();
-    module = other.release();
-    return *this;
-  }
-
-  /// Allow accessing the internal module.
-  ModuleOp get() const { return module; }
-  ModuleOp operator*() const { return module; }
-  ModuleOp *operator->() { return &module; }
-  explicit operator bool() const { return module; }
-
-  /// Release the referenced module.
-  ModuleOp release() {
-    ModuleOp released;
-    std::swap(released, module);
-    return released;
-  }
-
-private:
-  ModuleOp module;
+  using OwningOpRefBase<ModuleOp>::OwningOpRefBase;
 };
 
 } // end namespace mlir
diff --git a/mlir/include/mlir/IR/OwningOpRefBase.h b/mlir/include/mlir/IR/OwningOpRefBase.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/IR/OwningOpRefBase.h
@@ -0,0 +1,70 @@
+//===- OwningOpRefBase.h - MLIR OwningOpRefBase -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file provides a base class for owning op refs.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_OWNINGOPREFBASE_H
+#define MLIR_IR_OWNINGOPREFBASE_H
+
+#include <cstddef>
+#include <utility>
+
+namespace mlir {
+
+/// This class acts as an owning reference to an op, and will automatically
+/// destroy the held op on destruction if the held op is valid.
+///
+/// Note that OpBuilder and related functionality should be highly preferred
+/// instead, and this should only be used in situations where existing solutions
+/// are not viable.
+template <typename OpTy>
+class OwningOpRefBase {
+public:
+  OwningOpRefBase(std::nullptr_t = nullptr) {}
+  OwningOpRefBase(OpTy op) : op(op) {}
+  OwningOpRefBase(OwningOpRefBase &&other) : op(other.release()) {}
+  ~OwningOpRefBase() {
+    if (op)
+      op.erase();
+  }
+
+  // Assign from another op reference. Self-move-assignment must be a no-op:
+  // otherwise the held op would be erased and then re-adopted, leading to a
+  // use-after-free and a second erase on destruction.
+  OwningOpRefBase &operator=(OwningOpRefBase &&other) {
+    if (this == &other)
+      return *this;
+    if (op)
+      op.erase();
+    op = other.release();
+    return *this;
+  }
+
+  /// Allow accessing the internal op.
+  OpTy get() const { return op; }
+  OpTy operator*() const { return op; }
+  OpTy *operator->() { return &op; }
+  explicit operator bool() const { return op; }
+
+  /// Release ownership of the referenced op and return it. The reference is
+  /// null afterwards, so the op is no longer erased on destruction.
+  OpTy release() {
+    OpTy released;
+    std::swap(released, op);
+    return released;
+  }
+
+private:
+  OpTy op;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_IR_OWNINGOPREFBASE_H
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
--- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
 #include "mlir/IR/BlockAndValueMapping.h"
@@ -2516,12 +2517,12 @@
 #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
 } // namespace
 
-Optional<spirv::ModuleOp> spirv::deserialize(ArrayRef<uint32_t> binary,
-                                             MLIRContext *context) {
+spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef<uint32_t> binary,
+                                               MLIRContext *context) {
   Deserializer deserializer(binary, context);
 
   if (failed(deserializer.deserialize()))
-    return llvm::None;
+    return nullptr;
 
-  return deserializer.collect();
+  return deserializer.collect().getValueOr(nullptr);
 }
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
--- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
@@ -11,6 +11,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Serialization.h"
 #include "mlir/IR/Builders.h"
@@ -49,13 +50,13 @@
   auto binary = llvm::makeArrayRef(reinterpret_cast<const uint32_t *>(start),
                                    size / sizeof(uint32_t));
 
-  auto spirvModule = spirv::deserialize(binary, context);
+  spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
   if (!spirvModule)
     return {};
 
   OwningModuleRef module(ModuleOp::create(FileLineColLoc::get(
       input->getBufferIdentifier(), /*line=*/0, /*column=*/0, context)));
-  module->getBody()->push_front(spirvModule->getOperation());
+  module->getBody()->push_front(spirvModule.release());
 
   return module;
 }
@@ -136,14 +137,14 @@
     return failure();
 
   // Then deserialize to get back a SPIR-V module.
-  auto spirvModule = spirv::deserialize(binary, context);
+  spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context);
   if (!spirvModule)
     return failure();
 
   // Wrap around in a new MLIR module.
   OwningModuleRef dstModule(ModuleOp::create(FileLineColLoc::get(
       /*filename=*/"", /*line=*/0, /*column=*/0, context)));
-  dstModule->getBody()->push_front(spirvModule->getOperation());
+  dstModule->getBody()->push_front(spirvModule.release());
   dstModule->print(output);
 
   return mlir::success();
diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
--- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/SPIRV/SPIRVBinaryUtils.h"
 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Serialization.h"
 #include "mlir/IR/Diagnostics.h"
@@ -46,7 +47,7 @@
   }
 
   /// Performs deserialization and returns the constructed spv.module op.
-  Optional<spirv::ModuleOp> deserialize() {
+  spirv::OwningSPIRVModuleRef deserialize() {
     return spirv::deserialize(binary, &context);
   }
 
@@ -130,27 +131,27 @@
 //===----------------------------------------------------------------------===//
 
 TEST_F(DeserializationTest, EmptyModuleFailure) {
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("SPIR-V binary module must have a 5-word header");
 }
 
 TEST_F(DeserializationTest, WrongMagicNumberFailure) {
   addHeader();
   binary.front() = 0xdeadbeef; // Change to a wrong magic number
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("incorrect magic number");
 }
 
 TEST_F(DeserializationTest, OnlyHeaderSuccess) {
   addHeader();
-  EXPECT_NE(llvm::None, deserialize());
+  EXPECT_TRUE(deserialize());
 }
 
 TEST_F(DeserializationTest, ZeroWordCountFailure) {
   addHeader();
   binary.push_back(0); // OpNop with zero word count
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("word count cannot be zero");
 }
 
@@ -160,7 +161,7 @@
                    static_cast<uint32_t>(spirv::Opcode::OpTypeVoid));
   // Missing word for type <id>
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("insufficient words for the last instruction");
 }
 
@@ -172,7 +173,7 @@
   addHeader();
   addInstruction(spirv::Opcode::OpTypeInt, {nextID++, 32});
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
 }
 
@@ -198,7 +199,7 @@
   addInstruction(spirv::Opcode::OpMemberName, operands2);
 
   binary.append(typeDecl.begin(), typeDecl.end());
-  EXPECT_NE(llvm::None, deserialize());
+  EXPECT_TRUE(deserialize());
 }
 
 TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
@@ -215,7 +216,7 @@
   addInstruction(spirv::Opcode::OpMemberName, operands1);
 
   binary.append(typeDecl.begin(), typeDecl.end());
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("OpMemberName must have at least 3 operands");
 }
 
@@ -234,7 +235,7 @@
   addInstruction(spirv::Opcode::OpMemberName, operands);
 
   binary.append(typeDecl.begin(), typeDecl.end());
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("unexpected trailing words in OpMemberName instruction");
 }
 
@@ -249,7 +250,7 @@
   addFunction(voidType, fnType);
   // Missing OpFunctionEnd
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("expected OpFunctionEnd instruction");
 }
 
@@ -261,7 +262,7 @@
   addFunction(voidType, fnType);
   // Missing OpFunctionParameter
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("expected OpFunctionParameter instruction");
 }
 
@@ -274,7 +275,7 @@
   addReturn();
   addFunctionEnd();
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("a basic block must start with OpLabel");
 }
 
@@ -287,6 +288,6 @@
   addReturn();
   addFunctionEnd();
 
-  ASSERT_EQ(llvm::None, deserialize());
+  ASSERT_FALSE(deserialize());
   expectDiagnostic("OpLabel should only have result <id>");
 }