Index: mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_LLVMIR_LLVMTYPES_H_ #define MLIR_DIALECT_LLVMIR_LLVMTYPES_H_ +#include "mlir/IR/SubElementInterfaces.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" @@ -73,7 +74,8 @@ /// type. class LLVMArrayType : public Type::TypeBase { + DataLayoutTypeInterface::Trait, + SubElementTypeInterface::Trait> { public: /// Inherit base constructors. using Base::Base; @@ -111,6 +113,9 @@ unsigned getPreferredAlignment(const DataLayout &dataLayout, DataLayoutEntryListRef params) const; + /// SubElementTypeInterface hook: visits the element type. + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// @@ -120,9 +125,9 @@ /// LLVM dialect function type. It consists of a single return type (unlike MLIR /// which can have multiple), a list of parameter types and can optionally be /// variadic. -class LLVMFunctionType - : public Type::TypeBase { +class LLVMFunctionType : public Type::TypeBase { public: /// Inherit base constructors. using Base::Base; @@ -150,11 +155,11 @@ LLVMFunctionType clone(TypeRange inputs, TypeRange results) const; /// Returns the result type of the function. - Type getReturnType(); + Type getReturnType() const; /// Returns the result type of the function as an ArrayRef, enabling better /// integration with generic MLIR utilities. - ArrayRef getReturnTypes(); + ArrayRef getReturnTypes() const; /// Returns the number of arguments to the function. unsigned getNumParams(); @@ -163,12 +168,15 @@ Type getParamType(unsigned i); /// Returns a list of argument types of the function.
- ArrayRef getParams(); - ArrayRef params() { return getParams(); } + ArrayRef getParams() const; + ArrayRef params() const { return getParams(); } /// Verifies that the type about to be constructed is well-formed. static LogicalResult verify(function_ref emitError, Type result, ArrayRef arguments, bool); + /// SubElementTypeInterface hook: visits the return type, then the params. + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// @@ -179,9 +187,10 @@ /// object in memory. Pointers may be opaque or parameterized by the element /// type. Both opaque and non-opaque pointers are additionally parameterized by /// the address space. -class LLVMPointerType : public Type::TypeBase { +class LLVMPointerType + : public Type::TypeBase< + LLVMPointerType, Type, detail::LLVMPointerTypeStorage, + DataLayoutTypeInterface::Trait, SubElementTypeInterface::Trait> { public: /// Inherit base constructors. using Base::Base; @@ -232,6 +241,9 @@ DataLayoutEntryListRef newLayout) const; LogicalResult verifyEntries(DataLayoutEntryListRef entries, Location loc) const; + /// SubElementTypeInterface hook: visits the element type. + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// @@ -264,7 +276,8 @@ /// structs, but does not in uniquing of identified structs. class LLVMStructType : public Type::TypeBase { + DataLayoutTypeInterface::Trait, + SubElementTypeInterface::Trait> { public: /// Inherit base constructors. using Base::Base; @@ -358,6 +371,9 @@ LogicalResult verifyEntries(DataLayoutEntryListRef entries, Location loc) const; + /// SubElementTypeInterface hook: visits each type in the body. + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// @@ -368,7 +384,8 @@ /// length that can be processed as one. class LLVMFixedVectorType : public Type::TypeBase { + detail::LLVMTypeAndSizeStorage, + SubElementTypeInterface::Trait> { public: /// Inherit base constructor.
using Base::Base; @@ -387,7 +404,7 @@ static bool isValidElementType(Type type); /// Returns the element type of the vector. - Type getElementType(); + Type getElementType() const; /// Returns the number of elements in the fixed vector. unsigned getNumElements(); @@ -395,6 +412,9 @@ /// Verifies that the type about to be constructed is well-formed. static LogicalResult verify(function_ref emitError, Type elementType, unsigned numElements); + /// SubElementTypeInterface hook: visits the element type. + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// @@ -406,7 +426,8 @@ /// elements can be processed as one in SIMD context. class LLVMScalableVectorType : public Type::TypeBase { + detail::LLVMTypeAndSizeStorage, + SubElementTypeInterface::Trait> { public: /// Inherit base constructor. using Base::Base; @@ -423,7 +444,7 @@ static bool isValidElementType(Type type); /// Returns the element type of the vector. - Type getElementType(); + Type getElementType() const; /// Returns the scaling factor of the number of elements in the vector. The /// vector contains at least the resulting number of elements, or any non-zero @@ -433,6 +454,9 @@ /// Verifies that the type about to be constructed is well-formed.
static LogicalResult verify(function_ref emitError, Type elementType, unsigned minNumElements); + /// SubElementTypeInterface hook: visits the element type. + void walkImmediateSubElements(function_ref walkAttrsFn, + function_ref walkTypesFn) const; }; //===----------------------------------------------------------------------===// Index: mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -86,6 +86,12 @@ return dataLayout.getTypePreferredAlignment(getElementType()); } +void LLVMArrayType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); // The element type is the only sub-element. +} + //===----------------------------------------------------------------------===// // Function type. //===----------------------------------------------------------------------===// @@ -119,8 +125,10 @@ return get(results[0], llvm::to_vector(inputs), isVarArg()); } -Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); } -ArrayRef LLVMFunctionType::getReturnTypes() { +Type LLVMFunctionType::getReturnType() const { + return getImpl()->getReturnType(); +} +ArrayRef LLVMFunctionType::getReturnTypes() const { return getImpl()->getReturnType(); } @@ -134,7 +142,7 @@ bool LLVMFunctionType::isVarArg() const { return getImpl()->isVariadic(); } -ArrayRef LLVMFunctionType::getParams() { +ArrayRef LLVMFunctionType::getParams() const { return getImpl()->getArgumentTypes(); } @@ -151,6 +159,13 @@ return success(); } +void LLVMFunctionType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Type type : llvm::concat(getReturnTypes(), getParams())) + walkTypesFn(type); // Return type first, then parameters. +} + //===----------------------------------------------------------------------===// // Pointer type.
//===----------------------------------------------------------------------===// @@ -353,6 +368,12 @@ return success(); } +void LLVMPointerType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); // Presumably null for opaque pointers. +} + //===----------------------------------------------------------------------===// // Struct type. //===----------------------------------------------------------------------===// @@ -589,6 +610,13 @@ return mlir::success(); } +void LLVMStructType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Type type : getBody()) // Body types may refer back to this struct. + walkTypesFn(type); +} + //===----------------------------------------------------------------------===// // Vector types. //===----------------------------------------------------------------------===// @@ -621,7 +649,7 @@ numElements); } -Type LLVMFixedVectorType::getElementType() { +Type LLVMFixedVectorType::getElementType() const { return static_cast(impl)->elementType; } @@ -640,6 +668,12 @@ emitError, elementType, numElements); } +void LLVMFixedVectorType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); // The element type is the only sub-element. +} + //===----------------------------------------------------------------------===// // LLVMScalableVectorType. //===----------------------------------------------------------------------===// @@ -658,7 +692,7 @@ minNumElements); } -Type LLVMScalableVectorType::getElementType() { +Type LLVMScalableVectorType::getElementType() const { return static_cast(impl)->elementType; } @@ -680,6 +714,12 @@ emitError, elementType, numElements); } +void LLVMScalableVectorType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); // The element type is the only sub-element. +} + //===----------------------------------------------------------------------===// // Utility functions.
//===----------------------------------------------------------------------===// Index: mlir/unittests/Dialect/CMakeLists.txt =================================================================== --- mlir/unittests/Dialect/CMakeLists.txt +++ mlir/unittests/Dialect/CMakeLists.txt @@ -7,6 +7,7 @@ MLIRDialect) add_subdirectory(Affine) +add_subdirectory(LLVMIR) add_subdirectory(Quant) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) Index: mlir/unittests/Dialect/LLVMIR/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/CMakeLists.txt @@ -0,0 +1,7 @@ +add_mlir_unittest(MLIRLLVMIRTests + LLVMSubElementTypeTest.cpp +) +target_link_libraries(MLIRLLVMIRTests + PRIVATE + MLIRLLVMIR + ) Index: mlir/unittests/Dialect/LLVMIR/LLVMSubElementTypeTest.cpp =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/LLVMSubElementTypeTest.cpp @@ -0,0 +1,45 @@ +//===- LLVMSubElementTypeTest.cpp - LLVM SubElementTypeInterface tests ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "LLVMTestBase.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/SubElementInterfaces.h" + +using namespace mlir; +using namespace mlir::LLVM; + +TEST_F(LLVMIRTest, MutualReferencedSubElementTypes) { + auto fooStructTy = LLVMStructType::getIdentified(&context, "foo"); + ASSERT_TRUE(bool(fooStructTy)); + auto barStructTy = LLVMStructType::getIdentified(&context, "bar"); + ASSERT_TRUE(bool(barStructTy)); + + // Create two structs that reference each other through pointers.
+ Type fooBody[] = {LLVMPointerType::get(barStructTy)}; + ASSERT_TRUE(succeeded(fooStructTy.setBody(fooBody, /*packed=*/false))); + Type barBody[] = {LLVMPointerType::get(fooStructTy)}; + ASSERT_TRUE(succeeded(barStructTy.setBody(barBody, /*packed=*/false))); + + auto subElementInterface = fooStructTy.dyn_cast(); + ASSERT_TRUE(bool(subElementInterface)); + // Walking must terminate despite the mutual recursion between the structs. + SmallVector subElementTypes; + subElementInterface.walkSubElements( + [](Attribute attr) {}, + [&](Type type) { subElementTypes.push_back(type); }); + ASSERT_EQ(subElementTypes.size(), 4U); // Each type is reported once. + + auto structType = subElementTypes[0].dyn_cast(); + ASSERT_TRUE(bool(structType)); + ASSERT_TRUE(structType.getName().equals("foo")); + ASSERT_TRUE(subElementTypes[1].isa()); + structType = subElementTypes[2].dyn_cast(); + ASSERT_TRUE(bool(structType)); + ASSERT_TRUE(structType.getName().equals("bar")); + ASSERT_TRUE(subElementTypes[3].isa()); +} Index: mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h =================================================================== --- /dev/null +++ mlir/unittests/Dialect/LLVMIR/LLVMTestBase.h @@ -0,0 +1,27 @@ +//===- LLVMTestBase.h - Test fixture for LLVM dialect tests -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Test fixture for LLVM dialect tests.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H +#define MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/MLIRContext.h" +#include "gtest/gtest.h" + +class LLVMIRTest : public ::testing::Test { +protected: + LLVMIRTest() { context.getOrLoadDialect(); } + + mlir::MLIRContext context; +}; + +#endif // MLIR_UNITTEST_DIALECT_LLVMIR_LLVMTESTBASE_H