diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md --- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -47,7 +47,8 @@ enum Kinds { // These kinds will be used in the examples below. Simple = Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE, - Complex + Complex, + Recursive }; } ``` @@ -58,13 +59,17 @@ implicitly internal storage object that holds the actual data for the type. When defining a new `Type` it isn't always necessary to define a new storage class. So before defining the derived `Type`, it's important to know which of the two -classes of `Type` we are defining. Some types are `primitives` meaning they do +classes of `Type` we are defining. Some types are _primitives_ meaning they do not have any parameters and are singletons uniqued by kind, like the [`index` type](LangRef.md#index-type). Parametric types on the other hand, have additional information that differentiates different instances of the same `Type` kind. For example the [`integer` type](LangRef.md#integer-type) has a bitwidth, making `i8` and `i16` be different instances of -[`integer` type](LangRef.md#integer-type). +[`integer` type](LangRef.md#integer-type). Types can also have a mutable +component, which can be used, for example, to construct self-referring recursive +types. The mutable component _cannot_ be used to differentiate types within the +same kind, so usually such types are also parametric where the parameters serve +to identify them. #### Simple non-parametric types @@ -240,6 +245,126 @@ }; ``` +#### Types with a mutable component + +Types with a mutable component require defining a type storage class regardless +of being parametric. The storage contains both the parameters and the mutable +component and is accessed in a thread-safe way by the type support +infrastructure. 
+ +##### Defining a type storage + +In addition to the requirements for the type storage class for parametric types, +the storage class for types with a mutable component must additionally obey the +following. + +* The mutable component must not participate in the storage key. +* Provide a mutation method that is used to modify an existing instance of the + storage. This method modifies the mutable component based on arguments, + using `allocator` for any new dynamically-allocated storage, and indicates + whether the modification was successful. + - `LogicalResult mutate(StorageAllocator &allocator, Args &&...args)` + +Let's define a simple storage for recursive types, where a type is identified by +its name and can contain another type including itself. + +```c++ +/// Here we define a storage class for a RecursiveType that is identified by its +/// name and contains another type. +struct RecursiveTypeStorage : public TypeStorage { + /// The type is uniquely identified by its name. Note that the contained type + /// is _not_ a part of the key. + using KeyTy = StringRef; + + /// Construct the storage from the type name. Explicitly initialize the + /// containedType to nullptr, which is used as a marker for the mutable + /// component being not yet initialized. + RecursiveTypeStorage(StringRef name) : name(name), containedType(nullptr) {} + + /// Define the comparison function. + bool operator==(const KeyTy &key) const { return key == name; } + + /// Define a construction method for creating a new instance of the storage. + static RecursiveTypeStorage *construct(StorageAllocator &allocator, + const KeyTy &key) { + // Note that the key string is copied into the allocator to ensure it + // remains live as long as the storage itself. + return new (allocator.allocate<RecursiveTypeStorage>()) + RecursiveTypeStorage(allocator.copyInto(key)); + } + + /// Define a mutation method for changing the type after it is created. 
In + /// many cases, we only want to set the mutable component once and reject + /// any further modification, which can be achieved by returning failure from + /// this function. + LogicalResult mutate(StorageAllocator &, Type body) { + // If the contained type has been initialized already, and the call tries + // to change it, reject the change. + if (containedType && containedType != body) + return failure(); + + // Change the body successfully. + containedType = body; + return success(); + } + + StringRef name; + Type containedType; +}; +``` + +##### Type class definition + +Having defined the storage class, we can define the type class itself. This is +similar to parametric types. `Type::TypeBase` provides a `mutate` method that +forwards its arguments to the `mutate` method of the storage and ensures the +modification happens under lock. + +```c++ +class RecursiveType : public Type::TypeBase { +public: + /// Inherit parent constructors. + using Base::Base; + + /// This static method is used to support type inquiry through isa, cast, + /// and dyn_cast. + static bool kindof(unsigned kind) { return kind == MyTypes::Recursive; } + + /// Creates an instance of the Recursive type. This only takes the type name + /// and returns the type with uninitialized body. + static RecursiveType get(MLIRContext *ctx, StringRef name) { + // Call into the base to get a uniqued instance of this type. The parameter + // (name) is passed after the kind. + return Base::get(ctx, MyTypes::Recursive, name); + } + + /// Now we can change the mutable component of the type. This is an instance + /// method callable on an already existing RecursiveType. + void setBody(Type body) { + // Call into the base to mutate the type. + LogicalResult result = Base::mutate(body); + // Most types expect mutation to always succeed, but types can implement + // custom logic for handling mutation failures. 
+ assert(succeeded(result) && + "attempting to change the body of an already-initialized type"); + // Avoid unused-variable warning when building without assertions. + (void) result; + } + + /// Returns the contained type, which may be null if it has not been + /// initialized yet. + Type getBody() { + return getImpl()->containedType; + } + + /// Returns the name. + StringRef getName() { + return getImpl()->name; + } +}; +``` + ### Registering types with a Dialect Once the dialect types have been defined, they must then be registered with a diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -139,6 +139,13 @@ kind, std::forward(args)...); } + template + static LogicalResult mutate(MLIRContext *ctx, ImplType *impl, + Args &&...args) { + assert(impl && "cannot mutate null attribute"); + return ctx->getAttributeUniquer().mutate(impl, std::forward(args)...); + } + private: /// Initialize the given attribute storage instance. static void initializeAttributeStorage(AttributeStorage *storage, diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -48,10 +48,10 @@ /// Attributes are known-constant values of operations and functions. /// -/// Instances of the Attribute class are references to immutable, uniqued, -/// and immortal values owned by MLIRContext. As such, an Attribute is a thin -/// wrapper around an underlying storage pointer. Attributes are usually passed -/// by value. +/// Instances of the Attribute class are references to immortal key-value pairs +/// with immutable, uniqued key owned by MLIRContext. As such, an Attribute is a +/// thin wrapper around an underlying storage pointer. Attributes are usually +/// passed by value. class Attribute { public: /// Integer identifier for all the concrete attribute kinds. 
diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -105,6 +105,14 @@ return UniquerT::template get(loc.getContext(), kind, args...); } + /// Mutate the current storage instance. This will not change the unique key. + /// The arguments are forwarded to 'ConcreteT::mutate'. + template + LogicalResult mutate(Args &&...args) { + return UniquerT::mutate(this->getContext(), getImpl(), + std::forward(args)...); + } + /// Default implementation that just returns success. template static LogicalResult verifyConstructionInvariants(Args... args) { diff --git a/mlir/include/mlir/IR/TypeSupport.h b/mlir/include/mlir/IR/TypeSupport.h --- a/mlir/include/mlir/IR/TypeSupport.h +++ b/mlir/include/mlir/IR/TypeSupport.h @@ -132,6 +132,15 @@ }, kind, std::forward(args)...); } + + /// Change the mutable component of the given type instance in the provided + /// context. + template + static LogicalResult mutate(MLIRContext *ctx, ImplType *impl, + Args &&...args) { + assert(impl && "cannot mutate null type"); + return ctx->getTypeUniquer().mutate(impl, std::forward(args)...); + } }; } // namespace detail diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -27,15 +27,17 @@ struct OpaqueTypeStorage; } // namespace detail -/// Instances of the Type class are immutable and uniqued. They wrap a pointer -/// to the storage object owned by MLIRContext. Therefore, instances of Type -/// are passed around by value. +/// Instances of the Type class are uniqued, have an immutable identifier and an +/// optional mutable component. They wrap a pointer to the storage object owned +/// by MLIRContext. Therefore, instances of Type are passed around by value. 
/// /// Some types are "primitives" meaning they do not have any parameters, for /// example the Index type. Parametric types have additional information that /// differentiates the types of the same kind between them, for example the /// Integer type has bitwidth, making i8 and i16 belong to the same kind by be -/// different instances of the IntegerType. +/// different instances of the IntegerType. Type parameters are part of the +/// unique immutable key. The mutable component of the type can be modified +/// after the type is created, but cannot affect the identity of the type. /// /// Types are constructed and uniqued via the 'detail::TypeUniquer' class. /// @@ -62,6 +64,7 @@ /// - The type kind (for LLVM-style RTTI). /// - The dialect that defined the type. /// - Any parameters of the type. +/// - An optional mutable component. /// For non-parametric types, a convenience DefaultTypeStorage is provided. /// Parametric storage types must derive TypeStorage and respect the following: /// - Define a type alias, KeyTy, to a type that uniquely identifies the @@ -75,11 +78,14 @@ /// - Provide a method, 'bool operator==(const KeyTy &) const', to /// compare the storage instance against an instance of the key type. /// -/// - Provide a construction method: +/// - Provide a static construction method: /// 'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)' /// that builds a unique instance of the derived storage. The arguments to /// this function are an allocator to store any uniqued data within the /// context and the key type for this storage. +/// +/// - If they have a mutable component, this component must not be a part of +/// the key. class Type { public: /// Integer identifier for all the concrete type kinds. 
diff --git a/mlir/include/mlir/Support/StorageUniquer.h b/mlir/include/mlir/Support/StorageUniquer.h --- a/mlir/include/mlir/Support/StorageUniquer.h +++ b/mlir/include/mlir/Support/StorageUniquer.h @@ -10,6 +10,7 @@ #define MLIR_SUPPORT_STORAGEUNIQUER_H #include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/DenseSet.h" #include "llvm/Support/Allocator.h" @@ -60,6 +61,20 @@ /// that is called when erasing a storage instance. This should cleanup any /// fields of the storage as necessary and not attempt to free the memory /// of the storage itself. +/// +/// Storage classes may have an optional mutable component, which must not take +/// part in the unique immutable key. In this case, storage classes may be +/// mutated with `mutate` and must additionally respect the following: +/// - Provide a mutation method: +/// 'LogicalResult mutate(StorageAllocator &, <...>)' +/// that is called when mutating a storage instance. The first argument is +/// an allocator to store any mutable data, and the remaining arguments are +/// forwarded from the call site. The storage can be mutated at any time +/// after creation. Care must be taken to avoid excessive mutation since +/// the allocated storage can keep containing previous states. The return +/// value of the function is used to indicate whether the mutation was +/// successful, e.g., to limit the number of mutations or enable deferred +/// one-time assignment of the mutable component. class StorageUniquer { public: StorageUniquer(); @@ -166,6 +181,17 @@ return static_cast(getImpl(kind, ctorFn)); } + /// Changes the mutable component of 'storage' by forwarding the trailing + /// arguments to the 'mutate' function of the derived class. 
+ template + LogicalResult mutate(Storage *storage, Args &&...args) { + auto mutationFn = [&](StorageAllocator &allocator) -> LogicalResult { + return static_cast(*storage).mutate( + allocator, std::forward(args)...); + }; + return mutateImpl(mutationFn); + } + /// Erases a uniqued instance of 'Storage'. This function is used for derived /// types that have complex storage or uniquing constraints. template @@ -206,6 +232,10 @@ function_ref isEqual, function_ref cleanupFn); + /// Implementation for mutating an instance of a derived storage. + LogicalResult + mutateImpl(function_ref mutationFn); + /// The internal implementation class. std::unique_ptr impl; diff --git a/mlir/lib/Support/StorageUniquer.cpp b/mlir/lib/Support/StorageUniquer.cpp --- a/mlir/lib/Support/StorageUniquer.cpp +++ b/mlir/lib/Support/StorageUniquer.cpp @@ -124,6 +124,16 @@ storageTypes.erase(existing); } + /// Mutates an instance of a derived storage in a thread-safe way. + LogicalResult + mutate(function_ref mutationFn) { + if (!threadingIsEnabled) + return mutationFn(allocator); + + llvm::sys::SmartScopedWriter lock(mutex); + return mutationFn(allocator); + } + //===--------------------------------------------------------------------===// // Instance Storage //===--------------------------------------------------------------------===// @@ -214,3 +224,9 @@ function_ref cleanupFn) { impl->erase(kind, hashValue, isEqual, cleanupFn); } + +/// Implementation for mutating an instance of a derived storage. 
+LogicalResult StorageUniquer::mutateImpl( + function_ref mutationFn) { + return impl->mutate(mutationFn); +} diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/recursive-type.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s -test-recursive-types | FileCheck %s + +// CHECK-LABEL: @roundtrip +func @roundtrip() { + // CHECK: !test.test_rec> + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec> + // CHECK: !test.test_rec> + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec> + return +} + +// CHECK-LABEL: @create +func @create() { + // CHECK: !test.test_rec> + return +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringSwitch.h" using namespace mlir; @@ -137,19 +138,73 @@ >(); addInterfaces(); - addTypes(); + addTypes(); allowUnknownOperations(); } -Type TestDialect::parseType(DialectAsmParser &parser) const { - if (failed(parser.parseKeyword("test_type"))) +static Type parseTestType(DialectAsmParser &parser, + llvm::SetVector &stack) { + StringRef typeTag; + if (failed(parser.parseKeyword(&typeTag))) + return Type(); + + if (typeTag == "test_type") + return TestType::get(parser.getBuilder().getContext()); + + if (typeTag != "test_rec") + return Type(); + + StringRef name; + if (parser.parseLess() || parser.parseKeyword(&name)) + return Type(); + auto rec = TestRecursiveType::create(parser.getBuilder().getContext(), name); + + // If this type already has been parsed above in the stack, expect just the + // name. 
+ if (stack.contains(rec)) { + if (failed(parser.parseGreater())) + return Type(); + return rec; + } + + // Otherwise, parse the body and update the type. + if (failed(parser.parseComma())) + return Type(); + stack.insert(rec); + Type subtype = parseTestType(parser, stack); + stack.pop_back(); + if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) return Type(); - return TestType::get(getContext()); + + return rec; +} + +Type TestDialect::parseType(DialectAsmParser &parser) const { + llvm::SetVector stack; + return parseTestType(parser, stack); +} + +static void printTestType(Type type, DialectAsmPrinter &printer, + llvm::SetVector &stack) { + if (type.isa()) { + printer << "test_type"; + return; + } + + auto rec = type.cast(); + printer << "test_rec<" << rec.getName(); + if (!stack.contains(rec)) { + printer << ", "; + stack.insert(rec); + printTestType(rec.getBody(), printer, stack); + stack.pop_back(); + } + printer << ">"; } void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { - assert(type.isa() && "unexpected type"); - printer << "test_type"; + llvm::SetVector stack; + printTestType(type, printer, stack); } LogicalResult TestDialect::verifyOperationAttribute(Operation *op, diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -39,6 +39,60 @@ emitRemark(loc) << *this << " - TestC"; } }; + +/// Storage for simple named recursive types, where the type is identified by +/// its name and can "contain" another type, including itself. 
+struct TestRecursiveTypeStorage : public TypeStorage { + using KeyTy = StringRef; + + explicit TestRecursiveTypeStorage(StringRef key) : name(key), body(Type()) {} + + bool operator==(const KeyTy &other) const { return name == other; } + + static TestRecursiveTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) + TestRecursiveTypeStorage(allocator.copyInto(key)); + } + + LogicalResult mutate(TypeStorageAllocator &allocator, Type newBody) { + // Cannot set a different body than before. + if (body && body != newBody) + return failure(); + + body = newBody; + return success(); + } + + StringRef name; + Type body; +}; + +/// Simple recursive type identified by its name and pointing to another named +/// type, potentially itself. This requires the body to be mutated separately +/// from type creation. +class TestRecursiveType + : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { + return kind == Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1; + } + + static TestRecursiveType create(MLIRContext *ctx, StringRef name) { + return Base::get(ctx, Type::Kind::FIRST_PRIVATE_EXPERIMENTAL_9_TYPE + 1, + name); + } + + /// Body getter and setter. + LogicalResult setBody(Type body) { return Base::mutate(body); } + Type getBody() { return getImpl()->body; } + + /// Name/key getter. 
+ StringRef getName() { return getImpl()->name; } +}; + } // end namespace mlir #endif // MLIR_TESTTYPES_H diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -5,6 +5,7 @@ TestMatchers.cpp TestSideEffects.cpp TestSymbolUses.cpp + TestTypes.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/IR/TestTypes.cpp b/mlir/test/lib/IR/TestTypes.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestTypes.cpp @@ -0,0 +1,78 @@ +//===- TestTypes.cpp - Test passes for MLIR types -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestTypes.h" +#include "TestDialect.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +struct TestRecursiveTypesPass + : public PassWrapper { + LogicalResult createIRWithTypes(); + + void runOnFunction() override { + FuncOp func = getFunction(); + + // Just make sure recursive types are printed and parsed. + if (func.getName() == "roundtrip") + return; + + // Create a recursive type and print it as a part of a dummy op. + if (func.getName() == "create") { + if (failed(createIRWithTypes())) + signalPassFailure(); + return; + } + + // Unknown function name. + func.emitOpError() << "unexpected function name"; + signalPassFailure(); + } +}; +} // end namespace + +LogicalResult TestRecursiveTypesPass::createIRWithTypes() { + MLIRContext *ctx = &getContext(); + FuncOp func = getFunction(); + auto type = TestRecursiveType::create(ctx, "some_long_and_unique_name"); + if (failed(type.setBody(type))) + return func.emitError("expected to be able to set the type body"); + + // Setting the same body is fine. 
+ if (failed(type.setBody(type))) + return func.emitError( + "expected to be able to set the type body to the same value"); + + // Setting a different body is not. + if (succeeded(type.setBody(IndexType::get(ctx)))) + return func.emitError( + "not expected to be able to change function body more than once"); + + // Expecting to get the same type for the same name. + auto other = TestRecursiveType::create(ctx, "some_long_and_unique_name"); + if (type != other) + return func.emitError("expected type name to be the uniquing key"); + + // Create the op to check how the type is printed. + OperationState state(func.getLoc(), "test.dummy_type_test_op"); + state.addTypes(type); + func.getBody().front().push_front(Operation::create(state)); + + return success(); +} + +namespace mlir { + +void registerTestRecursiveTypesPass() { + PassRegistration reg( + "test-recursive-types", "Test support for recursive types"); +} + +} // end namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -63,6 +63,7 @@ void registerTestMemRefStrideCalculation(); void registerTestOpaqueLoc(); void registerTestPreparationPassWithAllowedMemrefResults(); +void registerTestRecursiveTypesPass(); void registerTestReducer(); void registerTestGpuParallelLoopMappingPass(); void registerTestSCFUtilsPass(); @@ -138,6 +139,7 @@ registerTestMemRefStrideCalculation(); registerTestOpaqueLoc(); registerTestPreparationPassWithAllowedMemrefResults(); + registerTestRecursiveTypesPass(); registerTestReducer(); registerTestGpuParallelLoopMappingPass(); registerTestSCFUtilsPass();