Index: mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -264,7 +264,8 @@ /// structs, but does not in uniquing of identified structs. class LLVMStructType : public Type::TypeBase { + DataLayoutTypeInterface::Trait, + TypeTrait::IsMutableType> { public: /// Inherit base constructors. using Base::Base; Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h =================================================================== --- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -276,7 +276,8 @@ /// recursive struct a unique identified and using that identifier in the struct /// definition for recursive references. class StructType : public Type::TypeBase { + detail::StructTypeStorage, + TypeTrait::IsMutableType> { public: using Base::Base; Index: mlir/include/mlir/IR/AttributeSupport.h =================================================================== --- mlir/include/mlir/IR/AttributeSupport.h +++ mlir/include/mlir/IR/AttributeSupport.h @@ -23,6 +23,11 @@ class MLIRContext; class Type; +namespace AttributeTrait { +template +class IsMutableAttr; +} // namespace AttributeTrait + //===----------------------------------------------------------------------===// // AbstractAttribute //===----------------------------------------------------------------------===// @@ -257,6 +262,10 @@ static typename std::enable_if_t< !std::is_same::value> registerAttribute(MLIRContext *ctx, TypeID typeID) { + static_assert(std::is_base_of, T>::value == + detect_mutate_function::value, + "IsMutableAttr trait requires `mutate` function to be defined" + " in the ImplType (and vice versa)."); ctx->getAttributeUniquer() .registerParametricStorageType(typeID); } @@ -267,6 +276,10 @@ static typename std::enable_if_t< std::is_same::value> registerAttribute(MLIRContext *ctx, TypeID 
typeID) { + static_assert(std::is_base_of, T>::value == + detect_mutate_function::value, + "IsMutableAttr trait requires `mutate` function to be defined" + " in the ImplType (and vice versa)."); ctx->getAttributeUniquer() .registerSingletonStorageType( typeID, [ctx, typeID](AttributeStorage *storage) { Index: mlir/include/mlir/IR/Attributes.h =================================================================== --- mlir/include/mlir/IR/Attributes.h +++ mlir/include/mlir/IR/Attributes.h @@ -231,6 +231,18 @@ friend InterfaceBase; }; +//===----------------------------------------------------------------------===// +// Core AttributeTrait +//===----------------------------------------------------------------------===// + +/// This trait is used to determine if an attribute is mutable or not. It is +/// attached to an attribute if and only if the corresponding ImplType +/// defines a `mutate` function with proper signature. +namespace AttributeTrait { +template +class IsMutableAttr : public TraitBase {}; +} // namespace AttributeTrait + } // namespace mlir. namespace llvm { Index: mlir/include/mlir/IR/StorageUniquerSupport.h =================================================================== --- mlir/include/mlir/IR/StorageUniquerSupport.h +++ mlir/include/mlir/IR/StorageUniquerSupport.h @@ -57,6 +57,33 @@ // StorageUserBase //===----------------------------------------------------------------------===// +/// Utility to check if a storage is mutable. 
+/// A storage is considered mutable if it has a non-static member function +/// ``` +/// LogicalResult mutate(AllocatorType &, ...); +/// ``` +template +class detect_mutate_function { + template + struct match_variadic_signature {}; + + template + struct match_variadic_signature { + using type = int; + }; + + template + using has_mutate_t = + typename match_variadic_signature::type; + +public: + static constexpr bool value = llvm::is_detected::value; +}; + namespace storage_user_base_impl { /// Returns true if this given Trait ID matches the IDs of any of the provided /// trait types `Traits`. Index: mlir/include/mlir/IR/TypeSupport.h =================================================================== --- mlir/include/mlir/IR/TypeSupport.h +++ mlir/include/mlir/IR/TypeSupport.h @@ -21,6 +21,11 @@ class Dialect; class MLIRContext; +namespace TypeTrait { +template +class IsMutableType; +} // namespace TypeTrait + //===----------------------------------------------------------------------===// // AbstractType //===----------------------------------------------------------------------===// @@ -232,6 +237,10 @@ static typename std::enable_if_t< !std::is_same::value> registerType(MLIRContext *ctx, TypeID typeID) { + static_assert(std::is_base_of, T>::value == + detect_mutate_function::value, + "IsMutableType trait requires `mutate` function to be defined" + " in the ImplType (and vice versa)."); ctx->getTypeUniquer().registerParametricStorageType( typeID); } @@ -242,6 +251,10 @@ static typename std::enable_if_t< std::is_same::value> registerType(MLIRContext *ctx, TypeID typeID) { + static_assert(std::is_base_of, T>::value == + detect_mutate_function::value, + "IsMutableType trait requires `mutate` function to be defined" + " in the ImplType (and vice versa)."); ctx->getTypeUniquer().registerSingletonStorageType( typeID, [&ctx, typeID](TypeStorage *storage) { storage->initialize(AbstractType::lookup(typeID, ctx)); Index: mlir/include/mlir/IR/Types.h 
=================================================================== --- mlir/include/mlir/IR/Types.h +++ mlir/include/mlir/IR/Types.h @@ -222,6 +222,19 @@ friend InterfaceBase; }; +//===----------------------------------------------------------------------===// +// Core TypeTrait +//===----------------------------------------------------------------------===// + +/// This trait is used to determine if a type is mutable or not. It is attached +/// to a type if and only if the corresponding ImplType defines a `mutate` +/// function with proper signature. +namespace TypeTrait { +template +class IsMutableType : public TypeTrait::TraitBase { +}; +} // namespace TypeTrait + //===----------------------------------------------------------------------===// // Type Utils //===----------------------------------------------------------------------===// Index: mlir/lib/IR/SubElementInterfaces.cpp =================================================================== --- mlir/lib/IR/SubElementInterfaces.cpp +++ mlir/lib/IR/SubElementInterfaces.cpp @@ -8,12 +8,16 @@ #include "mlir/IR/SubElementInterfaces.h" +#include "llvm/ADT/DenseSet.h" + using namespace mlir; template static void walkSubElementsImpl(InterfaceT interface, function_ref walkAttrsFn, - function_ref walkTypesFn) { + function_ref walkTypesFn, + DenseSet &visitedAttrs, + DenseSet &visitedTypes) { interface.walkImmediateSubElements( [&](Attribute attr) { // Guard against potentially null inputs. This removes the need for the @@ -21,9 +25,17 @@ if (!attr) return; + // Avoid infinite recursion when visiting sub attributes later, if this + // is a mutable attribute. + if (LLVM_UNLIKELY(attr.hasTrait())) { + if (!visitedAttrs.insert(attr).second) + return; + } + // Walk any sub elements first. if (auto interface = attr.dyn_cast()) - walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn); + walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); // Walk this attribute. 
walkAttrsFn(attr); @@ -34,9 +46,17 @@ if (!type) return; + // Avoid infinite recursion when visiting sub types later, if this + // is a mutable type. + if (LLVM_UNLIKELY(type.hasTrait())) { + if (!visitedTypes.insert(type).second) + return; + } + // Walk any sub elements first. if (auto interface = type.dyn_cast()) - walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn); + walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); // Walk this type. walkTypesFn(type); @@ -47,14 +67,20 @@ function_ref walkAttrsFn, function_ref walkTypesFn) { assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); - walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); + DenseSet visitedAttrs; + DenseSet visitedTypes; + walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); } void SubElementTypeInterface::walkSubElements( function_ref walkAttrsFn, function_ref walkTypesFn) { assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); - walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); + DenseSet visitedAttrs; + DenseSet visitedTypes; + walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); } //===----------------------------------------------------------------------===// Index: mlir/test/lib/Dialect/Test/TestTypes.h =================================================================== --- mlir/test/lib/Dialect/Test/TestTypes.h +++ mlir/test/lib/Dialect/Test/TestTypes.h @@ -130,7 +130,8 @@ /// from type creation. 
class TestRecursiveType : public ::mlir::Type::TypeBase { + TestRecursiveTypeStorage, + ::mlir::TypeTrait::IsMutableType> { public: using Base::Base; Index: mlir/unittests/Dialect/CMakeLists.txt =================================================================== --- mlir/unittests/Dialect/CMakeLists.txt +++ mlir/unittests/Dialect/CMakeLists.txt @@ -7,6 +7,7 @@ MLIRDialect) add_subdirectory(Affine) +add_subdirectory(LLVMIR) add_subdirectory(Quant) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) Index: mlir/unittests/Dialect/LLVMIR/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(MLIRLLVMIRTests + LLVMTypeTest.cpp +) +target_link_libraries(MLIRLLVMIRTests + PRIVATE + MLIRLLVMDialect + ) Index: mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h @@ -0,0 +1,27 @@ +//===- LLVMTestBase.h - Test fixture for LLVM dialect tests -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Test fixture for LLVM dialect tests. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H +#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/MLIRContext.h" +#include "gtest/gtest.h" + +class LLVMIRTest : public ::testing::Test { +protected: + LLVMIRTest() { context.getOrLoadDialect(); } + + mlir::MLIRContext context; +}; + +#endif // MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H Index: mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp @@ -0,0 +1,20 @@ +//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "LLVMTestBase.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/SubElementInterfaces.h" + +using namespace mlir; +using namespace mlir::LLVM; + +TEST_F(LLVMIRTest, IsStructTypeMutable) { + auto structTy = LLVMStructType::getIdentified(&context, "foo"); + ASSERT_TRUE(bool(structTy)); + ASSERT_TRUE(structTy.hasTrait()); +}