Index: mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -264,7 +264,8 @@
 /// structs, but does not in uniquing of identified structs.
 class LLVMStructType : public Type::TypeBase<LLVMStructType, Type,
                                              detail::LLVMStructTypeStorage,
-                                             DataLayoutTypeInterface::Trait> {
+                                             DataLayoutTypeInterface::Trait,
+                                             TypeTrait::IsMutable> {
 public:
   /// Inherit base constructors.
   using Base::Base;
Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -275,8 +275,9 @@
 /// In the above, expressing recursive struct types is accomplished by giving a
 /// recursive struct a unique identified and using that identifier in the struct
 /// definition for recursive references.
-class StructType : public Type::TypeBase<StructType, CompositeType,
-                                         detail::StructTypeStorage> {
+class StructType
+    : public Type::TypeBase<StructType, CompositeType,
+                            detail::StructTypeStorage, TypeTrait::IsMutable> {
 public:
   using Base::Base;
 
Index: mlir/include/mlir/IR/Attributes.h
===================================================================
--- mlir/include/mlir/IR/Attributes.h
+++ mlir/include/mlir/IR/Attributes.h
@@ -231,6 +231,18 @@
   friend InterfaceBase;
 };
 
+//===----------------------------------------------------------------------===//
+// Core AttributeTrait
+//===----------------------------------------------------------------------===//
+
+/// This trait is used to determine if an attribute is mutable or not. It is
+/// attached to an attribute if the corresponding ImplType defines a `mutate`
+/// function with a proper signature.
+namespace AttributeTrait {
+template <typename ConcreteType>
+using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+} // namespace AttributeTrait
+
 } // namespace mlir.
 
 namespace llvm {
Index: mlir/include/mlir/IR/StorageUniquerSupport.h
===================================================================
--- mlir/include/mlir/IR/StorageUniquerSupport.h
+++ mlir/include/mlir/IR/StorageUniquerSupport.h
@@ -53,6 +53,16 @@
   }
 };
 
+namespace StorageUserTrait {
+/// This trait is used to determine if a storage user, like Type, is mutable
+/// or not. A storage user is mutable if ImplType of the derived class defines
+/// a `mutate` function with a proper signature. Note that this trait is not
+/// supposed to be used publicly. Users should use alias names like
+/// `TypeTrait::IsMutable` instead.
+template <typename ConcreteType>
+struct IsMutable : public StorageUserTraitBase<ConcreteType, IsMutable> {};
+} // namespace StorageUserTrait
+
 //===----------------------------------------------------------------------===//
 // StorageUserBase
 //===----------------------------------------------------------------------===//
@@ -173,6 +183,10 @@
   /// Mutate the current storage instance. This will not change the unique key.
   /// The arguments are forwarded to 'ConcreteT::mutate'.
   template <typename... Args> LogicalResult mutate(Args &&...args) {
+    static_assert(std::is_base_of<StorageUserTrait::IsMutable<ConcreteT>,
+                                  ConcreteT>::value,
+                  "The `mutate` function expects mutable trait "
+                  "(e.g. TypeTrait::IsMutable) to be attached on parent.");
     return UniquerT::template mutate<ConcreteT>(this->getContext(), getImpl(),
                                                 std::forward<Args>(args)...);
   }
Index: mlir/include/mlir/IR/Types.h
===================================================================
--- mlir/include/mlir/IR/Types.h
+++ mlir/include/mlir/IR/Types.h
@@ -222,6 +222,18 @@
   friend InterfaceBase;
 };
 
+//===----------------------------------------------------------------------===//
+// Core TypeTrait
+//===----------------------------------------------------------------------===//
+
+/// This trait is used to determine if a type is mutable or not. It is attached
+/// to a type if the corresponding ImplType defines a `mutate` function with
+/// a proper signature.
+namespace TypeTrait {
+template <typename ConcreteType>
+using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+} // namespace TypeTrait
+
 //===----------------------------------------------------------------------===//
 // Type Utils
 //===----------------------------------------------------------------------===//
Index: mlir/lib/IR/SubElementInterfaces.cpp
===================================================================
--- mlir/lib/IR/SubElementInterfaces.cpp
+++ mlir/lib/IR/SubElementInterfaces.cpp
@@ -8,12 +8,16 @@
 
 #include "mlir/IR/SubElementInterfaces.h"
 
+#include "llvm/ADT/DenseSet.h"
+
 using namespace mlir;
 
 template <typename InterfaceT>
 static void walkSubElementsImpl(InterfaceT interface,
                                 function_ref<void(Attribute)> walkAttrsFn,
-                                function_ref<void(Type)> walkTypesFn) {
+                                function_ref<void(Type)> walkTypesFn,
+                                DenseSet<Attribute> &visitedAttrs,
+                                DenseSet<Type> &visitedTypes) {
   interface.walkImmediateSubElements(
       [&](Attribute attr) {
         // Guard against potentially null inputs. This removes the need for the
@@ -21,9 +25,17 @@
         if (!attr)
           return;
 
+        // Avoid infinite recursion when visiting sub attributes later, if this
+        // is a mutable attribute.
+        if (LLVM_UNLIKELY(attr.hasTrait<AttributeTrait::IsMutable>())) {
+          if (!visitedAttrs.insert(attr).second)
+            return;
+        }
+
         // Walk any sub elements first.
         if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
-          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
+                              visitedTypes);
 
         // Walk this attribute.
         walkAttrsFn(attr);
@@ -34,9 +46,17 @@
         if (!type)
           return;
 
+        // Avoid infinite recursion when visiting sub types later, if this
+        // is a mutable type.
+        if (LLVM_UNLIKELY(type.hasTrait<TypeTrait::IsMutable>())) {
+          if (!visitedTypes.insert(type).second)
+            return;
+        }
+
         // Walk any sub elements first.
         if (auto interface = type.dyn_cast<SubElementTypeInterface>())
-          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
+                              visitedTypes);
 
         // Walk this type.
         walkTypesFn(type);
@@ -47,14 +67,20 @@
     function_ref<void(Attribute)> walkAttrsFn,
     function_ref<void(Type)> walkTypesFn) {
   assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
-  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+  DenseSet<Attribute> visitedAttrs;
+  DenseSet<Type> visitedTypes;
+  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
+                      visitedTypes);
 }
 
 void SubElementTypeInterface::walkSubElements(
     function_ref<void(Attribute)> walkAttrsFn,
     function_ref<void(Type)> walkTypesFn) {
   assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
-  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+  DenseSet<Attribute> visitedAttrs;
+  DenseSet<Type> visitedTypes;
+  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
+                      visitedTypes);
 }
 
 //===----------------------------------------------------------------------===//
Index: mlir/test/lib/Dialect/Test/TestTypes.h
===================================================================
--- mlir/test/lib/Dialect/Test/TestTypes.h
+++ mlir/test/lib/Dialect/Test/TestTypes.h
@@ -130,7 +130,8 @@
 /// from type creation.
 class TestRecursiveType
     : public ::mlir::Type::TypeBase<TestRecursiveType, ::mlir::Type,
-                                    TestRecursiveTypeStorage> {
+                                    TestRecursiveTypeStorage,
+                                    ::mlir::TypeTrait::IsMutable> {
 public:
   using Base::Base;
 
Index: mlir/unittests/Dialect/CMakeLists.txt
===================================================================
--- mlir/unittests/Dialect/CMakeLists.txt
+++ mlir/unittests/Dialect/CMakeLists.txt
@@ -7,6 +7,7 @@
   MLIRDialect)
 
 add_subdirectory(Affine)
+add_subdirectory(LLVMIR)
 add_subdirectory(Quant)
 add_subdirectory(SparseTensor)
 add_subdirectory(SPIRV)
Index: mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/unittests/Dialect/LLVMIR/CMakeLists.txt
@@ -0,0 +1,7 @@
+add_mlir_unittest(MLIRLLVMIRTests
+  LLVMTypeTest.cpp
+)
+target_link_libraries(MLIRLLVMIRTests
+  PRIVATE
+  MLIRLLVMDialect
+  )
Index: mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h
===================================================================
--- /dev/null
+++ mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h
@@ -0,0 +1,27 @@
+//===- LLVMTestBase.h - Test fixture for LLVM dialect tests -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Test fixture for LLVM dialect tests.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
+#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "gtest/gtest.h"
+
+class LLVMIRTest : public ::testing::Test {
+protected:
+  LLVMIRTest() { context.getOrLoadDialect<mlir::LLVM::LLVMDialect>(); }
+
+  mlir::MLIRContext context;
+};
+
+#endif // MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H
Index: mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
===================================================================
--- /dev/null
+++ mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp
@@ -0,0 +1,20 @@
+//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "LLVMTestBase.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/IR/SubElementInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+TEST_F(LLVMIRTest, IsStructTypeMutable) {
+  auto structTy = LLVMStructType::getIdentified(&context, "foo");
+  ASSERT_TRUE(bool(structTy));
+  ASSERT_TRUE(structTy.hasTrait<TypeTrait::IsMutable>());
+}