Summary: add a mutable trait for MLIR types and attributes.

* Types.h / Attributes.h: introduce TypeTrait::IsMutable and
  AttributeTrait::IsMutable, and pass them to detail::StorageUserBase through
  a new `MutableTrait` template parameter (AttributeSupport.h and
  TypeSupport.h update their friend declarations to match).
* StorageUniquerSupport.h: StorageUserBase::mutate now static_asserts (via
  std::is_base_of) that the mutable trait is attached to the concrete class.
* Attach TypeTrait::IsMutable to LLVMStructType, the SPIR-V StructType and
  the test dialect's TestRecursiveType.
* SubElementInterfaces.cpp: walkSubElementsImpl threads DenseSets of visited
  attributes/types and skips mutable ones it has already seen, so walking a
  self-referential (recursive) type or attribute terminates.
* Add an MLIRLLVMIRTests unit test checking that an identified
  LLVMStructType carries the mutable trait.

NOTE(review): this copy of the patch has lost its line breaks and appears to
have had every angle-bracketed span stripped (e.g. `template class... Traits>`,
`hasTrait()`, `DenseSet visitedAttrs;`), which also dropped some context and
removed-line content. It will not apply as-is; regenerate it from source
control rather than hand-repairing it.

NOTE(review): walkSubElements() does not insert its root element into the
visited sets, so a self-referential mutable root is descended into once more
(and reported to the walk callback) as its own sub-element before the cycle
is cut. Harmless for termination, but worth confirming it is intended.

Index: mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -264,7 +264,8 @@ /// structs, but does not in uniquing of identified structs. class LLVMStructType : public Type::TypeBase { + DataLayoutTypeInterface::Trait, + TypeTrait::IsMutable> { public: /// Inherit base constructors. using Base::Base; Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -275,8 +275,9 @@ /// In the above, expressing recursive struct types is accomplished by giving a /// recursive struct a unique identified and using that identifier in the struct /// definition for recursive references. -class StructType : public Type::TypeBase { +class StructType + : public Type::TypeBase { public: using Base::Base; Index: mlir/include/mlir/IR/AttributeSupport.h =================================================================== --- mlir/include/mlir/IR/AttributeSupport.h +++ mlir/include/mlir/IR/AttributeSupport.h @@ -93,7 +93,8 @@ /// Give StorageUserBase access to the mutable lookup. template class... Traits> + typename UniquerT, template class MutableTrait, + template class... Traits> friend class detail::StorageUserBase; /// Look up the specified abstract attribute in the MLIRContext and return a Index: mlir/include/mlir/IR/Attributes.h =================================================================== --- mlir/include/mlir/IR/Attributes.h +++ mlir/include/mlir/IR/Attributes.h @@ -15,6 +15,11 @@ namespace mlir { class StringAttr; +namespace AttributeTrait { +template +class IsMutable; +} // namespace AttributeTrait + /// Attributes are known-constant values of operations.
/// /// Instances of the Attribute class are references to immortal key-value pairs @@ -26,8 +31,10 @@ /// Utility class for implementing attributes. template class... Traits> - using AttrBase = detail::StorageUserBase; + using AttrBase = + detail::StorageUserBase; using ImplType = AttributeStorage; using ValueType = void; @@ -231,6 +238,18 @@ friend InterfaceBase; }; +//===----------------------------------------------------------------------===// +// Core AttributeTrait +//===----------------------------------------------------------------------===// + +/// This trait is used to determine if an attribute is mutable or not. It is +/// attached on an attribute if and only if the corresponding ImplType defines a +/// `mutate` function with a proper signature. +namespace AttributeTrait { +template +class IsMutable : public TraitBase {}; +} // namespace AttributeTrait + } // namespace mlir. namespace llvm { Index: mlir/include/mlir/IR/StorageUniquerSupport.h =================================================================== --- mlir/include/mlir/IR/StorageUniquerSupport.h +++ mlir/include/mlir/IR/StorageUniquerSupport.h @@ -80,13 +80,15 @@ /// StorageUniquer. Clients are not expected to interact with this class /// directly. template class... Traits> + typename UniquerT, template class MutableTrait, + template class... Traits> class StorageUserBase : public BaseT, public Traits... { public: using BaseT::BaseT; /// Utility declarations for the concrete attribute class. - using Base = StorageUserBase; + using Base = StorageUserBase; using ImplType = StorageT; using HasTraitFn = bool (*)(TypeID); @@ -173,6 +175,9 @@ /// Mutate the current storage instance. This will not change the unique key. /// The arguments are forwarded to 'ConcreteT::mutate'. template LogicalResult mutate(Args &&...args) { + static_assert(std::is_base_of, ConcreteT>::value, + "The `mutate` function expects mutable trait " + "(e.g.
TypeTrait::IsMutable) to be attached on parent."); return UniquerT::template mutate(this->getContext(), getImpl(), std::forward(args)...); } Index: mlir/include/mlir/IR/TypeSupport.h =================================================================== --- mlir/include/mlir/IR/TypeSupport.h +++ mlir/include/mlir/IR/TypeSupport.h @@ -88,7 +88,8 @@ /// Give StorageUserBase access to the mutable lookup. template class... Traits> + typename UniquerT, template class MutableTrait, + template class... Traits> friend class detail::StorageUserBase; /// Look up the specified abstract type in the MLIRContext and return a Index: mlir/include/mlir/IR/Types.h =================================================================== --- mlir/include/mlir/IR/Types.h +++ mlir/include/mlir/IR/Types.h @@ -15,6 +15,12 @@ #include "llvm/Support/PointerLikeTypeTraits.h" namespace mlir { + +namespace TypeTrait { +template +class IsMutable; +} // namespace TypeTrait + /// Instances of the Type class are uniqued, have an immutable identifier and an /// optional mutable component. They wrap a pointer to the storage object owned /// by MLIRContext. Therefore, instances of Type are passed around by value. @@ -75,7 +81,8 @@ template class... Traits> using TypeBase = detail::StorageUserBase; + detail::TypeUniquer, + TypeTrait::IsMutable, Traits...>; using ImplType = TypeStorage; @@ -222,6 +229,18 @@ friend InterfaceBase; }; +//===----------------------------------------------------------------------===// +// Core TypeTrait +//===----------------------------------------------------------------------===// + +/// This trait is used to determine if a type is mutable or not. It is attached +/// on a type if and only if the corresponding ImplType defines a `mutate` +/// function with a proper signature.
+namespace TypeTrait { +template +class IsMutable : public TypeTrait::TraitBase {}; +} // namespace TypeTrait + //===----------------------------------------------------------------------===// // Type Utils //===----------------------------------------------------------------------===// Index: mlir/lib/IR/SubElementInterfaces.cpp =================================================================== --- mlir/lib/IR/SubElementInterfaces.cpp +++ mlir/lib/IR/SubElementInterfaces.cpp @@ -8,12 +8,16 @@ #include "mlir/IR/SubElementInterfaces.h" +#include "llvm/ADT/DenseSet.h" + using namespace mlir; template static void walkSubElementsImpl(InterfaceT interface, function_ref walkAttrsFn, - function_ref walkTypesFn) { + function_ref walkTypesFn, + DenseSet &visitedAttrs, + DenseSet &visitedTypes) { interface.walkImmediateSubElements( [&](Attribute attr) { // Guard against potentially null inputs. This removes the need for the @@ -21,9 +25,17 @@ if (!attr) return; + // Avoid infinite recursion when visiting sub attributes later, if this + // is a mutable attribute. + if (LLVM_UNLIKELY(attr.hasTrait())) { + if (!visitedAttrs.insert(attr).second) + return; + } + // Walk any sub elements first. if (auto interface = attr.dyn_cast()) - walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn); + walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); // Walk this attribute. walkAttrsFn(attr); @@ -34,9 +46,17 @@ if (!type) return; + // Avoid infinite recursion when visiting sub types later, if this + // is a mutable type. + if (LLVM_UNLIKELY(type.hasTrait())) { + if (!visitedTypes.insert(type).second) + return; + } + // Walk any sub elements first. if (auto interface = type.dyn_cast()) - walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn); + walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); // Walk this type.
walkTypesFn(type); @@ -47,14 +67,20 @@ function_ref walkAttrsFn, function_ref walkTypesFn) { assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); - walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); + DenseSet visitedAttrs; + DenseSet visitedTypes; + walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); } void SubElementTypeInterface::walkSubElements( function_ref walkAttrsFn, function_ref walkTypesFn) { assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); - walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); + DenseSet visitedAttrs; + DenseSet visitedTypes; + walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); } //===----------------------------------------------------------------------===// Index: mlir/test/lib/Dialect/Test/TestTypes.h =================================================================== --- mlir/test/lib/Dialect/Test/TestTypes.h +++ mlir/test/lib/Dialect/Test/TestTypes.h @@ -130,7 +130,8 @@ /// from type creation.
class TestRecursiveType : public ::mlir::Type::TypeBase { + TestRecursiveTypeStorage, + ::mlir::TypeTrait::IsMutable> { public: using Base::Base; Index: mlir/unittests/Dialect/CMakeLists.txt =================================================================== --- mlir/unittests/Dialect/CMakeLists.txt +++ mlir/unittests/Dialect/CMakeLists.txt @@ -7,6 +7,7 @@ MLIRDialect) add_subdirectory(Affine) +add_subdirectory(LLVMIR) add_subdirectory(Quant) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) Index: mlir/unittests/Dialect/LLVMIR/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(MLIRLLVMIRTests + LLVMTypeTest.cpp +) +target_link_libraries(MLIRLLVMIRTests + PRIVATE + MLIRLLVMDialect + ) Index: mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h @@ -0,0 +1,27 @@ +//===- LLVMTestBase.h - Test fixture for LLVM dialect tests -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Test fixture for LLVM dialect tests.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H +#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/MLIRContext.h" +#include "gtest/gtest.h" + +class LLVMIRTest : public ::testing::Test { +protected: + LLVMIRTest() { context.getOrLoadDialect(); } + + mlir::MLIRContext context; +}; + +#endif // MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H Index: mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp @@ -0,0 +1,20 @@ +//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "LLVMTestBase.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/SubElementInterfaces.h" + +using namespace mlir; +using namespace mlir::LLVM; + +TEST_F(LLVMIRTest, IsStructTypeMutable) { + auto structTy = LLVMStructType::getIdentified(&context, "foo"); + ASSERT_TRUE(bool(structTy)); + ASSERT_TRUE(structTy.hasTrait()); +}