diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -9,7 +9,7 @@ #ifndef MLIR_IR_BUILTINATTRIBUTES_H #define MLIR_IR_BUILTINATTRIBUTES_H -#include "mlir/IR/Attributes.h" +#include "SubElementInterfaces.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Sequence.h" #include diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -15,6 +15,7 @@ #define BUILTIN_ATTRIBUTES include "mlir/IR/BuiltinDialect.td" +include "mlir/IR/SubElementInterfaces.td" // TODO: Currently the attributes defined in this file are prefixed with // `Builtin_`. This is to differentiate the attributes here with the ones in @@ -22,8 +23,9 @@ // to this file instead. // Base class for Builtin dialect attributes. -class Builtin_Attr - : AttrDef { +class Builtin_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { let mnemonic = ?; } @@ -62,7 +64,9 @@ // ArrayAttr //===----------------------------------------------------------------------===// -def Builtin_ArrayAttr : Builtin_Attr<"Array"> { +def Builtin_ArrayAttr : Builtin_Attr<"Array", [ + DeclareAttrInterfaceMethods + ]> { let summary = "A collection of other Attribute values"; let description = [{ Syntax: @@ -133,7 +137,7 @@ //===----------------------------------------------------------------------===// def Builtin_DenseIntOrFPElementsAttr - : Builtin_Attr<"DenseIntOrFPElements", "DenseElementsAttr"> { + : Builtin_Attr<"DenseIntOrFPElements", /*traits=*/[], "DenseElementsAttr"> { let summary = "An Attribute containing a dense multi-dimensional array of " "integer or floating-point values"; let description = [{ @@ -228,7 +232,7 @@ //===----------------------------------------------------------------------===// def 
Builtin_DenseStringElementsAttr - : Builtin_Attr<"DenseStringElements", "DenseElementsAttr"> { + : Builtin_Attr<"DenseStringElements", /*traits=*/[], "DenseElementsAttr"> { let summary = "An Attribute containing a dense multi-dimensional array of " "strings"; let description = [{ @@ -277,7 +281,9 @@ // DictionaryAttr //===----------------------------------------------------------------------===// -def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> { +def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [ + DeclareAttrInterfaceMethods + ]> { let summary = "An dictionary of named Attribute values"; let description = [{ Syntax: @@ -589,7 +595,7 @@ //===----------------------------------------------------------------------===// def Builtin_OpaqueElementsAttr - : Builtin_Attr<"OpaqueElements", "ElementsAttr"> { + : Builtin_Attr<"OpaqueElements", /*traits=*/[], "ElementsAttr"> { let summary = "An opaque representation of a multi-dimensional array"; let description = [{ Syntax: @@ -655,7 +661,7 @@ //===----------------------------------------------------------------------===// def Builtin_SparseElementsAttr - : Builtin_Attr<"SparseElements", "ElementsAttr"> { + : Builtin_Attr<"SparseElements", /*traits=*/[], "ElementsAttr"> { let summary = "An opaque representation of a multi-dimensional array"; let description = [{ Syntax: @@ -892,7 +898,9 @@ // TypeAttr //===----------------------------------------------------------------------===// -def Builtin_TypeAttr : Builtin_Attr<"Type"> { +def Builtin_TypeAttr : Builtin_Attr<"Type", [ + DeclareAttrInterfaceMethods + ]> { let summary = "An Attribute containing a Type"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -9,8 +9,7 @@ #ifndef MLIR_IR_BUILTINTYPES_H #define MLIR_IR_BUILTINTYPES_H -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Types.h" +#include 
"SubElementInterfaces.h" namespace llvm { struct fltSemantics; diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -16,14 +16,16 @@ include "mlir/IR/BuiltinDialect.td" include "mlir/IR/BuiltinTypeInterfaces.td" +include "mlir/IR/SubElementInterfaces.td" // TODO: Currently the types defined in this file are prefixed with `Builtin_`. // This is to differentiate the types here with the ones in OpBase.td. We should // remove the definitions in OpBase.td, and repoint users to this file instead. // Base class for Builtin dialect types. -class Builtin_Type - : TypeDef { +class Builtin_Type traits = [], + string baseCppClass = "::mlir::Type"> + : TypeDef { let mnemonic = ?; } @@ -66,7 +68,8 @@ //===----------------------------------------------------------------------===// // Base class for Builtin dialect float types. -class Builtin_FloatType : Builtin_Type { +class Builtin_FloatType + : Builtin_Type { let extraClassDeclaration = [{ static }] # name # [{Type get(MLIRContext *context); }]; @@ -118,7 +121,9 @@ // FunctionType //===----------------------------------------------------------------------===// -def Builtin_Function : Builtin_Type<"Function"> { +def Builtin_Function : Builtin_Type<"Function", [ + DeclareTypeInterfaceMethods + ]> { let summary = "Map from a list of inputs to a list of results"; let description = [{ Syntax: @@ -253,7 +258,9 @@ // MemRefType //===----------------------------------------------------------------------===// -def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> { +def Builtin_MemRef : Builtin_Type<"MemRef", [ + DeclareTypeInterfaceMethods + ], "BaseMemRefType"> { let summary = "Shaped reference to a region of memory"; let description = [{ Syntax: @@ -638,7 +645,9 @@ // RankedTensorType //===----------------------------------------------------------------------===// -def Builtin_RankedTensor : 
Builtin_Type<"RankedTensor", "TensorType"> { +def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [ + DeclareTypeInterfaceMethods + ], "TensorType"> { let summary = "Multi-dimensional array with a fixed number of dimensions"; let description = [{ Syntax: @@ -726,7 +735,9 @@ // TupleType //===----------------------------------------------------------------------===// -def Builtin_Tuple : Builtin_Type<"Tuple"> { +def Builtin_Tuple : Builtin_Type<"Tuple", [ + DeclareTypeInterfaceMethods + ]> { let summary = "Fixed-sized collection of other types"; let description = [{ Syntax: @@ -793,7 +804,9 @@ // UnrankedMemRefType //===----------------------------------------------------------------------===// -def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "BaseMemRefType"> { +def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [ + DeclareTypeInterfaceMethods + ], "BaseMemRefType"> { let summary = "Shaped reference, with unknown rank, to a region of memory"; let description = [{ Syntax: @@ -853,7 +866,9 @@ // UnrankedTensorType //===----------------------------------------------------------------------===// -def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "TensorType"> { +def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [ + DeclareTypeInterfaceMethods + ], "TensorType"> { let summary = "Multi-dimensional array with unknown dimensions"; let description = [{ Syntax: @@ -890,7 +905,9 @@ // VectorType //===----------------------------------------------------------------------===// -def Builtin_Vector : Builtin_Type<"Vector", "ShapedType"> { +def Builtin_Vector : Builtin_Type<"Vector", [ + DeclareTypeInterfaceMethods + ], "ShapedType"> { let summary = "Multi-dimensional SIMD vector type"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -31,6 +31,13 @@ 
mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs) add_public_tablegen_target(MLIRBuiltinTypeInterfacesIncGen) +set(LLVM_TARGET_DEFINITIONS SubElementInterfaces.td) +mlir_tablegen(SubElementAttrInterfaces.h.inc -gen-attr-interface-decls) +mlir_tablegen(SubElementAttrInterfaces.cpp.inc -gen-attr-interface-defs) +mlir_tablegen(SubElementTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(SubElementTypeInterfaces.cpp.inc -gen-type-interface-defs) +add_public_tablegen_target(MLIRSubElementInterfacesIncGen) + set(LLVM_TARGET_DEFINITIONS TensorEncoding.td) mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls) mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs) diff --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/SubElementInterfaces.h @@ -0,0 +1,24 @@ +//===- SubElementInterfaces.h - Attr and Type SubElements -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains interfaces and utilities for querying the sub elements of +// an attribute or type. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_SUBELEMENTINTERFACES_H +#define MLIR_INTERFACES_SUBELEMENTINTERFACES_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Types.h" + +/// Include the definitions of the sub element interfaces.
+#include "mlir/IR/SubElementAttrInterfaces.h.inc" +#include "mlir/IR/SubElementTypeInterfaces.h.inc" + +#endif // MLIR_INTERFACES_SUBELEMENTINTERFACES_H diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/SubElementInterfaces.td @@ -0,0 +1,100 @@ +//===-- SubElementInterfaces.td - Sub-Element Interfaces ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of interfaces that can be used to interface with +// sub-elements, e.g. held attributes and types, of a composite attribute or +// type. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_SUBELEMENTINTERFACES_TD_ +#define MLIR_IR_SUBELEMENTINTERFACES_TD_ + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// SubElementInterfaceBase +//===----------------------------------------------------------------------===// + +class SubElementInterfaceBase { + string cppNamespace = "::mlir"; + + list methods = [ + InterfaceMethod< + /*desc=*/[{ + Walk all of the immediately nested sub-attributes and sub-types. This + method does not recurse into sub elements. + }], "void", "walkImmediateSubElements", + (ins "llvm::function_ref":$walkAttrsFn, + "llvm::function_ref":$walkTypesFn) + >, + ]; + + code extraClassDeclaration = [{ + /// Walk all of the held sub-attributes. + void walkSubAttrs(llvm::function_ref walkFn) { + walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {}); + } + + /// Walk all of the held sub-types. 
+ void walkSubTypes(llvm::function_ref walkFn) { + walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn); + } + + /// Walk all of the held sub-attributes and sub-types. + void walkSubElements(llvm::function_ref walkAttrsFn, + llvm::function_ref walkTypesFn); + }]; + + code extraTraitClassDeclaration = [{ + /// Walk all of the held sub-attributes. + void walkSubAttrs(llvm::function_ref walkFn) { + walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {}); + } + + /// Walk all of the held sub-types. + void walkSubTypes(llvm::function_ref walkFn) { + walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn); + } + + /// Walk all of the held sub-attributes and sub-types. + void walkSubElements(llvm::function_ref walkAttrsFn, + llvm::function_ref walkTypesFn) { + }] # interfaceName # " interface(" # derivedValue # [{); + interface.walkSubElements(walkAttrsFn, walkTypesFn); + } + }]; +} + +//===----------------------------------------------------------------------===// +// SubElementAttrInterface +//===----------------------------------------------------------------------===// + +def SubElementAttrInterface + : AttrInterface<"SubElementAttrInterface">, + SubElementInterfaceBase<"SubElementAttrInterface", "$_attr"> { + let description = [{ + An interface used to query and manipulate sub-elements, such as sub-types + and sub-attributes of a composite attribute. + }]; +} + +//===----------------------------------------------------------------------===// +// SubElementTypeInterface +//===----------------------------------------------------------------------===// + +def SubElementTypeInterface + : TypeInterface<"SubElementTypeInterface">, + SubElementInterfaceBase<"SubElementTypeInterface", "$_type"> { + let description = [{ + An interface used to query and manipulate sub-elements, such as sub-types + and sub-attributes of a composite type. 
+ }]; +} + +#endif // MLIR_IR_SUBELEMENTINTERFACES_TD_ diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/SubElementInterfaces.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" @@ -626,14 +627,10 @@ return; } - if (auto arrayAttr = attr.dyn_cast()) { - for (Attribute element : arrayAttr.getValue()) - visit(element); - } else if (auto dictAttr = attr.dyn_cast()) { - for (const NamedAttribute &attr : dictAttr) - visit(attr.second); - } else if (auto typeAttr = attr.dyn_cast()) { - visit(typeAttr.getValue()); + // Check for any sub elements. + if (auto subElementInterface = attr.dyn_cast()) { + subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, + [&](Type type) { visit(type); }); } } @@ -645,20 +642,10 @@ if (succeeded(generateAlias(type, aliasToType))) return; - // Visit several subtypes that contain types or attributes. - if (auto funcType = type.dyn_cast()) { - // Visit input and result types for functions. - for (auto input : funcType.getInputs()) - visit(input); - for (auto result : funcType.getResults()) - visit(result); - } else if (auto shapedType = type.dyn_cast()) { - visit(shapedType.getElementType()); - - // Visit affine maps in memref type. - if (auto memref = type.dyn_cast()) - for (auto map : memref.getAffineMaps()) - visit(AffineMapAttr::get(map)); + // Check for any sub elements. 
+ if (auto subElementInterface = type.dyn_cast()) { + subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, + [&](Type type) { visit(type); }); } } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -42,6 +42,17 @@ UnitAttr>(); } +//===----------------------------------------------------------------------===// +// ArrayAttr +//===----------------------------------------------------------------------===// + +void ArrayAttr::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Attribute attr : getValue()) + walkAttrsFn(attr); +} + //===----------------------------------------------------------------------===// // DictionaryAttr //===----------------------------------------------------------------------===// @@ -197,6 +208,13 @@ return Base::get(context, ArrayRef()); } +void DictionaryAttr::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Attribute attr : llvm::make_second_range(getValue())) + walkAttrsFn(attr); +} + //===----------------------------------------------------------------------===// // StringAttr //===----------------------------------------------------------------------===// @@ -1370,3 +1388,13 @@ {&*std::next(sparseIndexValues.begin(), i * rank), rank})); return flatSparseIndices; } + +//===----------------------------------------------------------------------===// +// TypeAttr +//===----------------------------------------------------------------------===// + +void TypeAttr::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getValue()); +} diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -199,6 +199,13 @@ return get(getContext(), newInputTypes, newResultTypes); } +void 
FunctionType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Type type : llvm::concat(getInputs(), getResults())) + walkTypesFn(type); +} + //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===// @@ -419,6 +426,12 @@ return VectorType(); } +void VectorType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); +} + //===----------------------------------------------------------------------===// // TensorType //===----------------------------------------------------------------------===// @@ -459,6 +472,12 @@ return checkTensorElementType(emitError, elementType); } +void RankedTensorType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); +} + //===----------------------------------------------------------------------===// // UnrankedTensorType //===----------------------------------------------------------------------===// @@ -469,6 +488,12 @@ return checkTensorElementType(emitError, elementType); } +void UnrankedTensorType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); +} + //===----------------------------------------------------------------------===// // BaseMemRefType //===----------------------------------------------------------------------===// @@ -612,6 +637,15 @@ return success(); } +void MemRefType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); + walkAttrsFn(getMemorySpace()); + for (AffineMap map : getAffineMaps()) + walkAttrsFn(AffineMapAttr::get(map)); +} + //===----------------------------------------------------------------------===// // UnrankedMemRefType 
//===----------------------------------------------------------------------===// @@ -779,6 +813,13 @@ return success(); } +void UnrankedMemRefType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); + walkAttrsFn(getMemorySpace()); +} + //===----------------------------------------------------------------------===// /// TupleType //===----------------------------------------------------------------------===// @@ -802,6 +843,13 @@ /// Return the number of element types. size_t TupleType::size() const { return getImpl()->size(); } +void TupleType::walkImmediateSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Type type : getTypes()) + walkTypesFn(type); +} + //===----------------------------------------------------------------------===// // Type Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -21,6 +21,7 @@ PatternMatch.cpp Region.cpp RegionKindInterface.cpp + SubElementInterfaces.cpp SymbolTable.cpp TensorEncoding.cpp Types.cpp @@ -46,6 +47,7 @@ MLIROpAsmInterfaceIncGen MLIRRegionKindInterfaceIncGen MLIRSideEffectInterfacesIncGen + MLIRSubElementInterfacesIncGen MLIRSymbolInterfacesIncGen MLIRTensorEncodingIncGen diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/SubElementInterfaces.cpp @@ -0,0 +1,65 @@ +//===- SubElementInterfaces.cpp - Attr and Type SubElement Interfaces -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/SubElementInterfaces.h" + +using namespace mlir; + +template +static void walkSubElementsImpl(InterfaceT interface, + function_ref walkAttrsFn, + function_ref walkTypesFn) { + interface.walkImmediateSubElements( + [&](Attribute attr) { + // Guard against potentially null inputs. This removes the need for the + // derived attribute/type to do it. + if (!attr) + return; + + // Walk any sub elements first. + if (auto interface = attr.dyn_cast()) + walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn); + + // Walk this attribute. + walkAttrsFn(attr); + }, + [&](Type type) { + // Guard against potentially null inputs. This removes the need for the + // derived attribute/type to do it. + if (!type) + return; + + // Walk any sub elements first. + if (auto interface = type.dyn_cast()) + walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn); + + // Walk this type. 
+ walkTypesFn(type); + }); +} + +void SubElementAttrInterface::walkSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) { + assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); + walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); +} + +void SubElementTypeInterface::walkSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) { + assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); + walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); +} + +//===----------------------------------------------------------------------===// +// SubElementInterface Tablegen definitions +//===----------------------------------------------------------------------===// + +#include "mlir/IR/SubElementAttrInterfaces.cpp.inc" +#include "mlir/IR/SubElementTypeInterfaces.cpp.inc" diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -4,6 +4,7 @@ MemRefTypeTest.cpp OperationSupportTest.cpp ShapedTypeTest.cpp + SubElementInterfaceTest.cpp ) target_link_libraries(MLIRIRTests PRIVATE diff --git a/mlir/unittests/IR/SubElementInterfaceTest.cpp b/mlir/unittests/IR/SubElementInterfaceTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/IR/SubElementInterfaceTest.cpp @@ -0,0 +1,35 @@ +//===- SubElementInterfaceTest.cpp - SubElementInterface unit tests -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/SubElementInterfaces.h" +#include "gtest/gtest.h" +#include + +using namespace mlir; +using namespace mlir::detail; + +namespace { +TEST(SubElementInterfaceTest, Nested) { + MLIRContext context; + Builder builder(&context); + + BoolAttr trueAttr = builder.getBoolAttr(true); + BoolAttr falseAttr = builder.getBoolAttr(false); + ArrayAttr boolArrayAttr = builder.getArrayAttr({trueAttr, falseAttr}); + DictionaryAttr dictAttr = + builder.getDictionaryAttr(builder.getNamedAttr("array", boolArrayAttr)); + + SmallVector subAttrs; + dictAttr.walkSubAttrs([&](Attribute attr) { subAttrs.push_back(attr); }); + EXPECT_EQ(llvm::makeArrayRef(subAttrs), + ArrayRef({trueAttr, falseAttr, boolArrayAttr})); +} + +} // end namespace