diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVModule.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVModule.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVModule.h +++ /dev/null @@ -1,30 +0,0 @@ -//===- SPIRVModule.h - SPIR-V Module Utilities ------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_SPIRV_IR_SPIRVMODULE_H -#define MLIR_DIALECT_SPIRV_IR_SPIRVMODULE_H - -#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" -#include "mlir/IR/OwningOpRef.h" - -namespace mlir { -namespace spirv { - -/// This class acts as an owning reference to a SPIR-V module, and will -/// automatically destroy the held module on destruction if the held module -/// is valid. -// TODO: Remove this class in favor of using OwningOpRef directly. -class OwningSPIRVModuleRef : public OwningOpRef { -public: - using OwningOpRef::OwningOpRef; -}; - -} // end namespace spirv -} // end namespace mlir - -#endif // MLIR_DIALECT_SPIRV_IR_SPIRVMODULE_H diff --git a/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h b/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h --- a/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h +++ b/mlir/include/mlir/Dialect/SPIRV/Linking/ModuleCombiner.h @@ -13,9 +13,8 @@ #ifndef MLIR_DIALECT_SPIRV_LINKING_MODULECOMBINER_H_ #define MLIR_DIALECT_SPIRV_LINKING_MODULECOMBINER_H_ -#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Support/LLVM.h" namespace mlir { class OpBuilder; @@ -67,11 +66,9 @@ /// function call. /// /// \return the combined module. -OwningSPIRVModuleRef -combine(llvm::MutableArrayRef modules, - OpBuilder &combinedModuleBuilder, - llvm::function_ref - symbRenameListener); +OwningOpRef +combine(MutableArrayRef modules, OpBuilder &combinedModuleBuilder, + function_ref symbRenameListener); } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Target/SPIRV/Deserialization.h b/mlir/include/mlir/Target/SPIRV/Deserialization.h --- a/mlir/include/mlir/Target/SPIRV/Deserialization.h +++ b/mlir/include/mlir/Target/SPIRV/Deserialization.h @@ -13,21 +13,21 @@ #ifndef MLIR_TARGET_SPIRV_DESERIALIZATION_H #define MLIR_TARGET_SPIRV_DESERIALIZATION_H +#include "mlir/IR/OwningOpRef.h" #include "mlir/Support/LLVM.h" namespace mlir { -struct LogicalResult; class MLIRContext; namespace spirv { -class OwningSPIRVModuleRef; +class ModuleOp; /// Deserializes the given SPIR-V `binary` module and creates a MLIR ModuleOp /// in the given `context`. Returns the ModuleOp on success; otherwise, reports /// errors to the error handler registered with `context` and returns a null /// module. -OwningSPIRVModuleRef deserialize(ArrayRef binary, - MLIRContext *context); +OwningOpRef deserialize(ArrayRef binary, + MLIRContext *context); } // end namespace spirv } // end namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp --- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp +++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp @@ -119,7 +119,7 @@ // TODO Properly test symbol rename listener mechanism. -OwningSPIRVModuleRef +OwningOpRef combine(llvm::MutableArrayRef modules, OpBuilder &combinedModuleBuilder, llvm::function_ref diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserialization.cpp @@ -10,9 +10,10 @@ #include "Deserializer.h" -namespace mlir { -spirv::OwningSPIRVModuleRef spirv::deserialize(ArrayRef binary, - MLIRContext *context) { +using namespace mlir; + +OwningOpRef spirv::deserialize(ArrayRef binary, + MLIRContext *context) { Deserializer deserializer(binary, context); if (failed(deserializer.deserialize())) @@ -20,4 +21,3 @@ return deserializer.collect(); } -} // namespace mlir diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -14,7 +14,6 @@ #define MLIR_TARGET_SPIRV_DESERIALIZER_H #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/IR/Builders.h" #include "llvm/ADT/ArrayRef.h" @@ -142,7 +141,7 @@ LogicalResult deserialize(); /// Collects the final SPIR-V ModuleOp. - spirv::OwningSPIRVModuleRef collect(); + OwningOpRef collect(); private: //===--------------------------------------------------------------------===// @@ -150,7 +149,7 @@ //===--------------------------------------------------------------------===// /// Initializes the `module` ModuleOp in this deserializer instance. - spirv::OwningSPIRVModuleRef createModuleOp(); + OwningOpRef createModuleOp(); /// Processes SPIR-V module header in `binary`. LogicalResult processHeader(); @@ -507,7 +506,7 @@ Location unknownLoc; /// The SPIR-V ModuleOp. - spirv::OwningSPIRVModuleRef module; + OwningOpRef module; /// The current function under construction. Optional curFunction; diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -88,7 +87,7 @@ return success(); } -spirv::OwningSPIRVModuleRef spirv::Deserializer::collect() { +OwningOpRef spirv::Deserializer::collect() { return std::move(module); } @@ -96,7 +95,7 @@ // Module structure //===----------------------------------------------------------------------===// -spirv::OwningSPIRVModuleRef spirv::Deserializer::createModuleOp() { +OwningOpRef spirv::Deserializer::createModuleOp() { OpBuilder builder(context); OperationState state(unknownLoc, spirv::ModuleOp::getOperationName()); spirv::ModuleOp::build(builder, state); diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp --- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp @@ -12,7 +12,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -52,7 +51,8 @@ auto binary = llvm::makeArrayRef(reinterpret_cast(start), size / sizeof(uint32_t)); - spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, context); + OwningOpRef spirvModule = + spirv::deserialize(binary, context); if (!spirvModule) return {}; @@ -140,7 +140,7 @@ // TODO: we should only load the required dialects instead of all dialects. deserializationContext.loadAllAvailableDialects(); // Then deserialize to get back a SPIR-V module. - spirv::OwningSPIRVModuleRef spirvModule = + OwningOpRef spirvModule = spirv::deserialize(binary, &deserializationContext); if (!spirvModule) return failure(); diff --git a/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestModuleCombiner.cpp @@ -25,7 +25,7 @@ void runOnOperation() override; private: - mlir::spirv::OwningSPIRVModuleRef combinedModule; + OwningOpRef combinedModule; }; } // namespace diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp --- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp @@ -14,7 +14,6 @@ #include "mlir/Target/SPIRV/Deserialization.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" @@ -45,7 +44,7 @@ } /// Performs deserialization and returns the constructed spv.module op. - spirv::OwningSPIRVModuleRef deserialize() { + OwningOpRef deserialize() { return spirv::deserialize(binary, &context); } diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp --- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp @@ -14,7 +14,6 @@ #include "mlir/Target/SPIRV/Serialization.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" -#include "mlir/Dialect/SPIRV/IR/SPIRVModule.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/Builders.h" @@ -102,7 +101,7 @@ protected: MLIRContext context; - spirv::OwningSPIRVModuleRef module; + OwningOpRef module; SmallVector binary; };