Index: mlir/include/mlir/IR/AttributeSupport.h =================================================================== --- mlir/include/mlir/IR/AttributeSupport.h +++ mlir/include/mlir/IR/AttributeSupport.h @@ -41,8 +41,9 @@ /// attributes they contain. template static AbstractAttribute get(Dialect &dialect) { - return AbstractAttribute(dialect, T::getInterfaceMap(), T::getHasTraitFn(), - T::getTypeID()); + return AbstractAttribute( + dialect, T::getInterfaceMap(), T::getHasTraitFn(), T::getTypeID(), + detail::is_storage_mutable::value); } /// This method is used by Dialect objects to register attributes with @@ -51,9 +52,10 @@ /// 'get(dialect)'. static AbstractAttribute get(Dialect &dialect, detail::InterfaceMap &&interfaceMap, - HasTraitFn &&hasTrait, TypeID typeID) { + HasTraitFn &&hasTrait, TypeID typeID, + bool isMutable = false) { return AbstractAttribute(dialect, std::move(interfaceMap), - std::move(hasTrait), typeID); + std::move(hasTrait), typeID, isMutable); } /// Return the dialect this attribute was registered to. @@ -85,11 +87,14 @@ /// Return the unique identifier representing the concrete attribute class. TypeID getTypeID() const { return typeID; } + /// Returns true if the attribute is mutable. + bool isMutableAttr() const { return isMutable; }; + private: AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap, - HasTraitFn &&hasTrait, TypeID typeID) + HasTraitFn &&hasTrait, TypeID typeID, bool isMutable) : dialect(dialect), interfaceMap(std::move(interfaceMap)), - hasTraitFn(std::move(hasTrait)), typeID(typeID) {} + hasTraitFn(std::move(hasTrait)), typeID(typeID), isMutable(isMutable) {} /// Give StorageUserBase access to the mutable lookup. template +class is_storage_mutable { + // SFINAE utilities to check the function signature. + template + struct match_function_signature : public std::false_type {}; + + template + struct match_function_signature + : public std::true_type {}; + +public: + template ::value, + int> = 0> + static char test(int); + + template + static double test(...); + + static constexpr bool value = (sizeof(test(0)) == 1); +}; + namespace storage_user_base_impl { /// Returns true if this given Trait ID matches the IDs of any of the provided /// trait types `Traits`. Index: mlir/include/mlir/IR/TypeSupport.h =================================================================== --- mlir/include/mlir/IR/TypeSupport.h +++ mlir/include/mlir/IR/TypeSupport.h @@ -39,8 +39,9 @@ /// types they contain. template static AbstractType get(Dialect &dialect) { - return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(), - T::getTypeID()); + return AbstractType( + dialect, T::getInterfaceMap(), T::getHasTraitFn(), T::getTypeID(), + detail::is_storage_mutable::value); } /// This method is used by Dialect objects to register types with @@ -48,9 +49,10 @@ /// The use of this method is in general discouraged in favor of /// 'get(dialect)'; static AbstractType get(Dialect &dialect, detail::InterfaceMap &&interfaceMap, - HasTraitFn &&hasTrait, TypeID typeID) { + HasTraitFn &&hasTrait, TypeID typeID, + bool isMutable = false) { return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait), - typeID); + typeID, isMutable); } /// Return the dialect this type was registered to. @@ -80,11 +82,14 @@ /// Return the unique identifier representing the concrete type class. TypeID getTypeID() const { return typeID; } + /// Returns true if this type is mutable. + bool isMutableType() const { return isMutable; } + private: AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap, - HasTraitFn &&hasTrait, TypeID typeID) + HasTraitFn &&hasTrait, TypeID typeID, bool isMutable) : dialect(dialect), interfaceMap(std::move(interfaceMap)), - hasTraitFn(std::move(hasTrait)), typeID(typeID) {} + hasTraitFn(std::move(hasTrait)), typeID(typeID), isMutable(isMutable) {} /// Give StorageUserBase access to the mutable lookup. template static void walkSubElementsImpl(InterfaceT interface, function_ref walkAttrsFn, - function_ref walkTypesFn) { + function_ref walkTypesFn, + llvm::DenseSet &visitedAttrs, + llvm::DenseSet &visitedTypes) { interface.walkImmediateSubElements( [&](Attribute attr) { // Guard against potentially null inputs. This removes the need for the @@ -21,9 +25,17 @@ if (!attr) return; + // Avoid infinite recursion when visiting sub attributes later, if this + // is a mutable attribute. + if (LLVM_UNLIKELY(attr.getAbstractAttribute().isMutableAttr())) { + if (!visitedAttrs.insert(attr).second) + return; + } + // Walk any sub elements first. if (auto interface = attr.dyn_cast()) - walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn); + walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); // Walk this attribute. walkAttrsFn(attr); @@ -34,9 +46,17 @@ if (!type) return; + // Avoid infinite recursion when visiting sub types later, if this + // is a mutable type. + if (LLVM_UNLIKELY(type.getAbstractType().isMutableType())) { + if (!visitedTypes.insert(type).second) + return; + } + // Walk any sub elements first. if (auto interface = type.dyn_cast()) - walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn); + walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); // Walk this type. walkTypesFn(type); @@ -47,14 +67,20 @@ function_ref walkAttrsFn, function_ref walkTypesFn) { assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); - walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); + llvm::DenseSet visitedAttrs; + llvm::DenseSet visitedTypes; + walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); } void SubElementTypeInterface::walkSubElements( function_ref walkAttrsFn, function_ref walkTypesFn) { assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); - walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); + llvm::DenseSet visitedAttrs; + llvm::DenseSet visitedTypes; + walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs, + visitedTypes); } //===----------------------------------------------------------------------===// Index: mlir/unittests/Dialect/CMakeLists.txt =================================================================== --- mlir/unittests/Dialect/CMakeLists.txt +++ mlir/unittests/Dialect/CMakeLists.txt @@ -7,6 +7,7 @@ MLIRDialect) add_subdirectory(Affine) +add_subdirectory(LLVMIR) add_subdirectory(Quant) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) Index: mlir/unittests/Dialect/LLVMIR/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(MLIRLLVMIRTests + LLVMTypeTest.cpp +) +target_link_libraries(MLIRLLVMIRTests + PRIVATE + MLIRLLVMIR + ) Index: mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h @@ -0,0 +1,27 @@ +//===- LLVMTestBase.h - Test fixure for LLVM dialect tests ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Test fixure for LLVM dialect tests. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H +#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/MLIRContext.h" +#include "gtest/gtest.h" + +class LLVMIRTest : public ::testing::Test { +protected: + LLVMIRTest() { context.getOrLoadDialect(); } + + mlir::MLIRContext context; +}; + +#endif Index: mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/LLVMTypeTest.cpp @@ -0,0 +1,20 @@ +//===- LLVMTypeTest.cpp - Tests for LLVM types ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "LLVMTestBase.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/SubElementInterfaces.h" + +using namespace mlir; +using namespace mlir::LLVM; + +TEST_F(LLVMIRTest, IsStructTypeMutable) { + auto structTy = LLVMStructType::getIdentified(&context, "foo"); + ASSERT_TRUE(bool(structTy)); + ASSERT_TRUE(structTy.getAbstractType().isMutableType()); +}