diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -76,33 +76,6 @@ } }; -//===----------------------------------------------------------------------===// -// ToyDialect -//===----------------------------------------------------------------------===// - -/// Dialect creation, the instance will be owned by the context. This is the -/// point of registration of custom types and operations for the dialect. -ToyDialect::ToyDialect(mlir::MLIRContext *ctx) - : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get()) { - addOperations< -#define GET_OP_LIST -#include "toy/Ops.cpp.inc" - >(); - addInterfaces(); - addTypes(); -} - -mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, - mlir::Attribute value, - mlir::Type type, - mlir::Location loc) { - if (type.isa()) - return builder.create(loc, type, - value.cast()); - return builder.create(loc, type, - value.cast()); -} - //===----------------------------------------------------------------------===// // Toy Operations //===----------------------------------------------------------------------===// @@ -566,3 +539,30 @@ #define GET_OP_CLASSES #include "toy/Ops.cpp.inc" + +//===----------------------------------------------------------------------===// +// ToyDialect +//===----------------------------------------------------------------------===// + +/// Dialect creation, the instance will be owned by the context. This is the +/// point of registration of custom types and operations for the dialect. +ToyDialect::ToyDialect(mlir::MLIRContext *ctx) + : mlir::Dialect(getDialectNamespace(), ctx, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "toy/Ops.cpp.inc" + >(); + addInterfaces(); + addTypes(); +} + +mlir::Operation *ToyDialect::materializeConstant(mlir::OpBuilder &builder, + mlir::Attribute value, + mlir::Type type, + mlir::Location loc) { + if (type.isa()) + return builder.create(loc, type, + value.cast()); + return builder.create(loc, type, + value.cast()); +} diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td @@ -64,6 +64,9 @@ let name = "pdl"; let cppNamespace = "::mlir::pdl"; + let extraClassDeclaration = [{ + void registerTypes(); + }]; } #endif // MLIR_DIALECT_PDL_IR_PDLDIALECT diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -52,6 +52,9 @@ let hasRegionResultAttrVerify = 1; let extraClassDeclaration = [{ + void registerAttributes(); + void registerTypes(); + //===------------------------------------------------------------------===// // Attribute //===------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td --- a/mlir/include/mlir/IR/BuiltinDialect.td +++ b/mlir/include/mlir/IR/BuiltinDialect.td @@ -22,6 +22,17 @@ let name = ""; let cppNamespace = "::mlir"; + let extraClassDeclaration = [{ + private: + // Register the builtin Attributes. + void registerAttributes(); + // Register the builtin Location Attributes. + void registerLocationAttributes(); + // Register the builtin Types. + void registerTypes(); + + public: + }]; } #endif // BUILTIN_BASE diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -135,7 +135,13 @@ /// instances of this class type. `id` is the type identifier that will be /// used to identify this type when creating instances of it via 'get'. template void registerParametricStorageType(TypeID id) { - registerParametricStorageTypeImpl(id); + // If the storage is trivially destructible, we don't need a destructor + // function. + if (std::is_trivially_destructible::value) + return registerParametricStorageTypeImpl(id, nullptr); + registerParametricStorageTypeImpl(id, [](BaseStorage *storage) { + reinterpret_cast(storage)->~Storage(); + }); } /// Utility override when the storage type represents the type id. template void registerParametricStorageType() { @@ -244,8 +250,10 @@ function_ref ctorFn); /// Implementation for registering an instance of a derived type with - /// parametric storage. - void registerParametricStorageTypeImpl(TypeID id); + /// parametric storage. This method takes an optional destructor function that + /// destructs storage instances when necessary. + void registerParametricStorageTypeImpl( + TypeID id, function_ref destructorFn); /// Implementation for getting an instance of a derived type with default /// storage. diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -20,6 +20,12 @@ using namespace mlir; +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc" + void arm_sve::ArmSVEDialect::initialize() { addOperations< #define GET_OP_LIST @@ -31,12 +37,6 @@ >(); } -#define GET_OP_CLASSES -#include "mlir/Dialect/ArmSVE/ArmSVE.cpp.inc" - -#define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/ArmSVE/ArmSVETypes.cpp.inc" - //===----------------------------------------------------------------------===// // ScalableVectorType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "TypeDetail.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -25,10 +25,7 @@ #define GET_OP_LIST #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc" - >(); + registerTypes(); } /// Returns true if the given operation is used by a "binding" pdl operation diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp --- a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp @@ -26,6 +26,13 @@ // PDLDialect //===----------------------------------------------------------------------===// +void PDLDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc" + >(); +} + static Type parsePDLType(DialectAsmParser &parser) { StringRef keyword; if (parser.parseKeyword(&keyword)) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/IR/Builders.h" @@ -350,3 +351,11 @@ return success(); } + +//===----------------------------------------------------------------------===// +// SPIR-V Dialect +//===----------------------------------------------------------------------===// + +void spirv::SPIRVDialect::registerAttributes() { + addAttributes(); +} diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -115,10 +115,8 @@ //===----------------------------------------------------------------------===// void SPIRVDialect::initialize() { - addTypes(); - - addAttributes(); + registerAttributes(); + registerTypes(); // Add SPIR-V ops. addOperations< diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -1154,3 +1154,12 @@ // Add any capabilities associated with the underlying vectors (i.e., columns) getColumnType().cast().getCapabilities(capabilities, storage); } + +//===----------------------------------------------------------------------===// +// SPIR-V Dialect +//===----------------------------------------------------------------------===// + +void SPIRVDialect::registerTypes() { + addTypes(); +} diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -100,36 +100,6 @@ return false; } -//===----------------------------------------------------------------------===// -// VectorDialect -//===----------------------------------------------------------------------===// - -void VectorDialect::initialize() { - addAttributes(); - - addOperations< -#define GET_OP_LIST -#include "mlir/Dialect/Vector/VectorOps.cpp.inc" - >(); -} - -/// Materialize a single constant operation from a given attribute value with -/// the desired resultant type. -Operation *VectorDialect::materializeConstant(OpBuilder &builder, - Attribute value, Type type, - Location loc) { - return builder.create(loc, type, value); -} - -IntegerType vector::getVectorSubscriptType(Builder &builder) { - return builder.getIntegerType(64); -} - -ArrayAttr vector::getVectorSubscriptAttr(Builder &builder, - ArrayRef values) { - return builder.getI64ArrayAttr(values); -} - //===----------------------------------------------------------------------===// // CombiningKindAttr //===----------------------------------------------------------------------===// @@ -230,6 +200,36 @@ llvm_unreachable("Unknown attribute type"); } +//===----------------------------------------------------------------------===// +// VectorDialect +//===----------------------------------------------------------------------===// + +void VectorDialect::initialize() { + addAttributes(); + + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Vector/VectorOps.cpp.inc" + >(); +} + +/// Materialize a single constant operation from a given attribute value with +/// the desired resultant type. +Operation *VectorDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + return builder.create(loc, type, value); +} + +IntegerType vector::getVectorSubscriptType(Builder &builder) { + return builder.getIntegerType(64); +} + +ArrayAttr vector::getVectorSubscriptAttr(Builder &builder, + ArrayRef values) { + return builder.getI64ArrayAttr(values); +} + //===----------------------------------------------------------------------===// // ReductionOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -9,6 +9,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "AttributeDetail.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" @@ -28,6 +29,18 @@ #define GET_ATTRDEF_CLASSES #include "mlir/IR/BuiltinAttributes.cpp.inc" +//===----------------------------------------------------------------------===// +// BuiltinDialect +//===----------------------------------------------------------------------===// + +void BuiltinDialect::registerAttributes() { + addAttributes(); +} + //===----------------------------------------------------------------------===// // DictionaryAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -60,17 +60,9 @@ } // end anonymous namespace. void BuiltinDialect::initialize() { - addTypes(); - addAttributes(); - addAttributes(); + registerTypes(); + registerAttributes(); + registerLocationAttributes(); addOperations< #define GET_OP_LIST #include "mlir/IR/BuiltinOps.cpp.inc" diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -10,6 +10,7 @@ #include "TypeDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "llvm/ADT/APFloat.h" @@ -28,6 +29,17 @@ #define GET_TYPEDEF_CLASSES #include "mlir/IR/BuiltinTypes.cpp.inc" +//===----------------------------------------------------------------------===// +// BuiltinDialect +//===----------------------------------------------------------------------===// + +void BuiltinDialect::registerTypes() { + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/IR/BuiltinTypes.cpp.inc" + >(); +} + //===----------------------------------------------------------------------===// /// ComplexType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/Location.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Identifier.h" #include "llvm/ADT/SetVector.h" @@ -20,6 +21,17 @@ #define GET_ATTRDEF_CLASSES #include "mlir/IR/BuiltinLocationAttributes.cpp.inc" +//===----------------------------------------------------------------------===// +// BuiltinDialect +//===----------------------------------------------------------------------===// + +void BuiltinDialect::registerLocationAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/IR/BuiltinLocationAttributes.cpp.inc" + >(); +} + //===----------------------------------------------------------------------===// // LocationAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -100,12 +100,23 @@ return storage; } + /// Destroy all of the storage instances within the given shard. + void destroyShardInstances(Shard &shard) { + if (!destructorFn) + return; + for (HashedStorage &instance : shard.instances) + destructorFn(instance.storage); + } + public: #if LLVM_ENABLE_THREADS != 0 /// Initialize the storage uniquer with a given number of storage shards to - /// use. The provided shard number is required to be a valid power of 2. - ParametricStorageUniquer(size_t numShards = 8) - : shards(new std::atomic[numShards]), numShards(numShards) { + /// use. The provided shard number is required to be a valid power of 2. The + /// destructor function is used to destroy any allocated storage instances. + ParametricStorageUniquer(function_ref destructorFn, + size_t numShards = 8) + : shards(new std::atomic[numShards]), numShards(numShards), + destructorFn(destructorFn) { assert(llvm::isPowerOf2_64(numShards) && "the number of shards is required to be a power of 2"); for (size_t i = 0; i < numShards; i++) @@ -113,9 +124,12 @@ } ~ParametricStorageUniquer() { // Free all of the allocated shards. - for (size_t i = 0; i != numShards; ++i) - if (Shard *shard = shards[i].load()) + for (size_t i = 0; i != numShards; ++i) { + if (Shard *shard = shards[i].load()) { + destroyShardInstances(*shard); delete shard; + } + } } /// Get or create an instance of a parametric type. BaseStorage * @@ -204,10 +218,17 @@ /// The number of available shards. size_t numShards; + /// Function to used to destruct any allocated storage instances. + function_ref destructorFn; + #else /// If multi-threading is disabled, ignore the shard parameter as we will - /// always use one shard. - ParametricStorageUniquer(size_t numShards = 0) {} + /// always use one shard. The destructor function is used to destroy any + /// allocated storage instances. + ParametricStorageUniquer(function_ref destructorFn, + size_t numShards = 0) + : destructorFn(destructorFn) {} + ~ParametricStorageUniquer() { destroyShardInstances(shard); } /// Get or create an instance of a parametric type. BaseStorage * @@ -228,6 +249,9 @@ private: /// The main uniquer shard that is used for allocating storage instances. Shard shard; + + /// Function to used to destruct any allocated storage instances. + function_ref destructorFn; #endif }; } // end anonymous namespace @@ -323,9 +347,10 @@ /// Implementation for registering an instance of a derived type with /// parametric storage. -void StorageUniquer::registerParametricStorageTypeImpl(TypeID id) { +void StorageUniquer::registerParametricStorageTypeImpl( + TypeID id, function_ref destructorFn) { impl->parametricUniquers.try_emplace( - id, std::make_unique()); + id, std::make_unique(destructorFn)); } /// Implementation for getting an instance of a derived type with default diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -100,6 +100,13 @@ // TestDialect //===----------------------------------------------------------------------===// +void TestDialect::registerAttributes() { + addAttributes< +#define GET_ATTRDEF_LIST +#include "TestAttrDefs.cpp.inc" + >(); +} + Attribute TestDialect::parseAttribute(DialectAsmParser &parser, Type type) const { StringRef attrTag; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -165,20 +165,14 @@ //===----------------------------------------------------------------------===// void TestDialect::initialize() { + registerAttributes(); + registerTypes(); addOperations< #define GET_OP_LIST #include "TestOps.cpp.inc" >(); - addAttributes< -#define GET_ATTRDEF_LIST -#include "TestAttrDefs.cpp.inc" - >(); addInterfaces(); - addTypes(); allowUnknownOperations(); } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -29,6 +29,9 @@ let hasRegionResultAttrVerify = 1; let extraClassDeclaration = [{ + void registerAttributes(); + void registerTypes(); + Attribute parseAttribute(DialectAsmParser &parser, Type type) const override; void printAttribute(Attribute attr, diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -132,6 +132,13 @@ // TestDialect //===----------------------------------------------------------------------===// +void TestDialect::registerTypes() { + addTypes(); +} + static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, llvm::SetVector &stack) { StringRef typeTag;