diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -9,7 +9,7 @@ #ifndef MLIR_IR_BUILTINATTRIBUTES_H #define MLIR_IR_BUILTINATTRIBUTES_H -#include "mlir/IR/Attributes.h" +#include "mlir/IR/SubElementInterfaces.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Sequence.h" #include diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -15,6 +15,7 @@ #define BUILTIN_ATTRIBUTES include "mlir/IR/BuiltinDialect.td" +include "mlir/IR/SubElementInterfaces.td" // TODO: Currently the attributes defined in this file are prefixed with // `Builtin_`. This is to differentiate the attributes here with the ones in @@ -22,8 +23,9 @@ // to this file instead. // Base class for Builtin dialect attributes. -class Builtin_Attr - : AttrDef { +class Builtin_Attr traits = [], + string baseCppClass = "::mlir::Attribute"> + : AttrDef { let mnemonic = ?; } @@ -62,7 +64,9 @@ // ArrayAttr //===----------------------------------------------------------------------===// -def Builtin_ArrayAttr : Builtin_Attr<"Array"> { +def Builtin_ArrayAttr : Builtin_Attr<"Array", [ + DeclareAttrInterfaceMethods + ]> { let summary = "A collection of other Attribute values"; let description = [{ Syntax: @@ -133,7 +137,7 @@ //===----------------------------------------------------------------------===// def Builtin_DenseIntOrFPElementsAttr - : Builtin_Attr<"DenseIntOrFPElements", "DenseElementsAttr"> { + : Builtin_Attr<"DenseIntOrFPElements", /*traits=*/[], "DenseElementsAttr"> { let summary = "An Attribute containing a dense multi-dimensional array of " "integer or floating-point values"; let description = [{ @@ -228,7 +232,7 @@ //===----------------------------------------------------------------------===// def
Builtin_DenseStringElementsAttr - : Builtin_Attr<"DenseStringElements", "DenseElementsAttr"> { + : Builtin_Attr<"DenseStringElements", /*traits=*/[], "DenseElementsAttr"> { let summary = "An Attribute containing a dense multi-dimensional array of " "strings"; let description = [{ @@ -277,7 +281,9 @@ // DictionaryAttr //===----------------------------------------------------------------------===// -def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> { +def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [ + DeclareAttrInterfaceMethods + ]> { let summary = "An dictionary of named Attribute values"; let description = [{ Syntax: @@ -588,7 +594,7 @@ //===----------------------------------------------------------------------===// def Builtin_OpaqueElementsAttr - : Builtin_Attr<"OpaqueElements", "ElementsAttr"> { + : Builtin_Attr<"OpaqueElements", /*traits=*/[], "ElementsAttr"> { let summary = "An opaque representation of a multi-dimensional array"; let description = [{ Syntax: @@ -654,7 +660,7 @@ //===----------------------------------------------------------------------===// def Builtin_SparseElementsAttr - : Builtin_Attr<"SparseElements", "ElementsAttr"> { + : Builtin_Attr<"SparseElements", /*traits=*/[], "ElementsAttr"> { let summary = "An opaque representation of a multi-dimensional array"; let description = [{ Syntax: @@ -883,7 +889,9 @@ // TypeAttr //===----------------------------------------------------------------------===// -def Builtin_TypeAttr : Builtin_Attr<"Type"> { +def Builtin_TypeAttr : Builtin_Attr<"Type", [ + DeclareAttrInterfaceMethods + ]> { let summary = "An Attribute containing a Type"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -9,8 +9,7 @@ #ifndef MLIR_IR_BUILTINTYPES_H #define MLIR_IR_BUILTINTYPES_H -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Types.h" +#include
"mlir/IR/SubElementInterfaces.h" namespace llvm { struct fltSemantics; diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -15,14 +15,16 @@ #define BUILTIN_TYPES include "mlir/IR/BuiltinDialect.td" +include "mlir/IR/SubElementInterfaces.td" // TODO: Currently the types defined in this file are prefixed with `Builtin_`. // This is to differentiate the types here with the ones in OpBase.td. We should // remove the definitions in OpBase.td, and repoint users to this file instead. // Base class for Builtin dialect types. -class Builtin_Type - : TypeDef { +class Builtin_Type traits = [], + string baseCppClass = "::mlir::Type"> + : TypeDef { let mnemonic = ?; } @@ -65,7 +67,8 @@ //===----------------------------------------------------------------------===// // Base class for Builtin dialect float types. -class Builtin_FloatType : Builtin_Type { +class Builtin_FloatType + : Builtin_Type { let extraClassDeclaration = [{ static }] # name # [{Type get(MLIRContext *context); }]; @@ -117,7 +120,9 @@ // FunctionType //===----------------------------------------------------------------------===// -def Builtin_Function : Builtin_Type<"Function"> { +def Builtin_Function : Builtin_Type<"Function", [ + DeclareTypeInterfaceMethods + ]> { let summary = "Map from a list of inputs to a list of results"; let description = [{ Syntax: @@ -252,7 +257,9 @@ // MemRefType //===----------------------------------------------------------------------===// -def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> { +def Builtin_MemRef : Builtin_Type<"MemRef", [ + DeclareTypeInterfaceMethods + ], "BaseMemRefType"> { let summary = "Shaped reference to a region of memory"; let description = [{ Syntax: @@ -628,7 +635,9 @@ // RankedTensorType //===----------------------------------------------------------------------===// -def Builtin_RankedTensor : Builtin_Type<"RankedTensor",
"TensorType"> { +def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [ + DeclareTypeInterfaceMethods + ], "TensorType"> { let summary = "Multi-dimensional array with a fixed number of dimensions"; let description = [{ Syntax: @@ -716,7 +725,9 @@ // TupleType //===----------------------------------------------------------------------===// -def Builtin_Tuple : Builtin_Type<"Tuple"> { +def Builtin_Tuple : Builtin_Type<"Tuple", [ + DeclareTypeInterfaceMethods + ]> { let summary = "Fixed-sized collection of other types"; let description = [{ Syntax: @@ -783,7 +794,9 @@ // UnrankedMemRefType //===----------------------------------------------------------------------===// -def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "BaseMemRefType"> { +def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [ + DeclareTypeInterfaceMethods + ], "BaseMemRefType"> { let summary = "Shaped reference, with unknown rank, to a region of memory"; let description = [{ Syntax: @@ -843,7 +856,9 @@ // UnrankedTensorType //===----------------------------------------------------------------------===// -def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "TensorType"> { +def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [ + DeclareTypeInterfaceMethods + ], "TensorType"> { let summary = "Multi-dimensional array with unknown dimensions"; let description = [{ Syntax: @@ -880,7 +895,9 @@ // VectorType //===----------------------------------------------------------------------===// -def Builtin_Vector : Builtin_Type<"Vector", "ShapedType"> { +def Builtin_Vector : Builtin_Type<"Vector", [ + DeclareTypeInterfaceMethods + ], "ShapedType"> { let summary = "Multi-dimensional SIMD vector type"; let description = [{ Syntax: diff --git a/mlir/include/mlir/IR/SubElementInterfaces.h b/mlir/include/mlir/IR/SubElementInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/SubElementInterfaces.h @@ -0,0 +1,24 @@ +//===- SubElementInterfaces.h - Attr and Type
SubElements -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains interfaces and utilities for querying the sub elements of +// an attribute or type. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_SUBELEMENTINTERFACES_H +#define MLIR_IR_SUBELEMENTINTERFACES_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Types.h" + +/// Include the definitions of the sub element interfaces. +#include "mlir/IR/SubElementAttrInterfaces.h.inc" +#include "mlir/IR/SubElementTypeInterfaces.h.inc" + +#endif // MLIR_IR_SUBELEMENTINTERFACES_H diff --git a/mlir/include/mlir/IR/SubElementInterfaces.td b/mlir/include/mlir/IR/SubElementInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/SubElementInterfaces.td @@ -0,0 +1,79 @@ +//===-- SubElementInterfaces.td - Sub-Element Interfaces ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a set of interfaces that can be used to interface with +// sub-elements, e.g. held attributes and types, of a composite attribute or +// type.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_SUBELEMENTINTERFACES +#define MLIR_IR_SUBELEMENTINTERFACES + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// SubElementInterfaceBase +//===----------------------------------------------------------------------===// + +class SubElementInterfaceBase { + string cppNamespace = "::mlir"; + + list methods = [ + InterfaceMethod< + /*desc=*/[{ + Walk the attributes and types directly held by this attribute or type. + This method does not recurse into sub elements. + }], "void", "walkSubElementsImpl", + (ins "llvm::function_ref":$walkAttrsFn, + "llvm::function_ref":$walkTypesFn) + >, + ]; + + code extraClassDeclaration = [{ + /// Recursively walk all nested sub-attributes. + void walkSubAttrs(llvm::function_ref walkFn) { + walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {}); + } + + /// Recursively walk all nested sub-types. + void walkSubTypes(llvm::function_ref walkFn) { + walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn); + } + + /// Recursively walk all nested sub-attributes and sub-types (post-order). + void walkSubElements(llvm::function_ref walkAttrsFn, + llvm::function_ref walkTypesFn); + }]; +} + +//===----------------------------------------------------------------------===// +// SubElementAttrInterface +//===----------------------------------------------------------------------===// + +def SubElementAttrInterface + : AttrInterface<"SubElementAttrInterface">, SubElementInterfaceBase { + let description = [{ + An interface used to query and manipulate sub-elements, such as sub-types + and sub-attributes of a composite attribute.
+ }]; +} + +//===----------------------------------------------------------------------===// +// SubElementTypeInterface +//===----------------------------------------------------------------------===// + +def SubElementTypeInterface + : TypeInterface<"SubElementTypeInterface">, SubElementInterfaceBase { + let description = [{ + An interface used to query and manipulate sub-elements, such as sub-types + and sub-attributes of a composite type. + }]; +} + +#endif // MLIR_IR_SUBELEMENTINTERFACES diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/SubElementInterfaces.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" @@ -607,14 +608,10 @@ return; } - if (auto arrayAttr = attr.dyn_cast()) { - for (Attribute element : arrayAttr.getValue()) - visit(element); - } else if (auto dictAttr = attr.dyn_cast()) { - for (const NamedAttribute &attr : dictAttr) - visit(attr.second); - } else if (auto typeAttr = attr.dyn_cast()) { - visit(typeAttr.getValue()); + // Check for any sub elements. + if (auto subElementInterface = attr.dyn_cast()) { + subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, + [&](Type type) { visit(type); }); } } @@ -626,20 +623,10 @@ if (succeeded(generateAlias(type, aliasToType))) return; - // Visit several subtypes that contain types or attributes. - if (auto funcType = type.dyn_cast()) { - // Visit input and result types for functions. - for (auto input : funcType.getInputs()) - visit(input); - for (auto result : funcType.getResults()) - visit(result); - } else if (auto shapedType = type.dyn_cast()) { - visit(shapedType.getElementType()); - - // Visit affine maps in memref type.
- if (auto memref = type.dyn_cast()) - for (auto map : memref.getAffineMaps()) - visit(AffineMapAttr::get(map)); + // Check for any sub elements. + if (auto subElementInterface = type.dyn_cast()) { + subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); }, + [&](Type type) { visit(type); }); } } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -41,6 +41,17 @@ UnitAttr>(); } +//===----------------------------------------------------------------------===// +// ArrayAttr +//===----------------------------------------------------------------------===// + +void ArrayAttr::walkSubElementsImpl( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Attribute attr : getValue()) + walkAttrsFn(attr); +} + //===----------------------------------------------------------------------===// // DictionaryAttr //===----------------------------------------------------------------------===// @@ -196,6 +207,13 @@ return Base::get(context, ArrayRef()); } +void DictionaryAttr::walkSubElementsImpl( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Attribute attr : llvm::make_second_range(getValue())) + walkAttrsFn(attr); +} + //===----------------------------------------------------------------------===// // FloatAttr //===----------------------------------------------------------------------===// @@ -1346,3 +1364,12 @@ {&*std::next(sparseIndexValues.begin(), i * rank), rank})); return flatSparseIndices; } + +//===----------------------------------------------------------------------===// +// TypeAttr +//===----------------------------------------------------------------------===// + +void TypeAttr::walkSubElementsImpl(function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getValue()); +} diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++
b/mlir/lib/IR/BuiltinTypes.cpp @@ -193,6 +193,13 @@ return get(getContext(), newInputTypes, newResultTypes); } +void FunctionType::walkSubElementsImpl( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Type type : llvm::concat(getInputs(), getResults())) + walkTypesFn(type); +} + //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===// @@ -413,6 +420,12 @@ return VectorType(); } +void VectorType::walkSubElementsImpl( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); +} + //===----------------------------------------------------------------------===// // TensorType //===----------------------------------------------------------------------===// @@ -453,6 +466,12 @@ return checkTensorElementType(emitError, elementType); } +void RankedTensorType::walkSubElementsImpl( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); +} + //===----------------------------------------------------------------------===// // UnrankedTensorType //===----------------------------------------------------------------------===// @@ -463,6 +482,12 @@ return checkTensorElementType(emitError, elementType); } +void UnrankedTensorType::walkSubElementsImpl( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); +} + //===----------------------------------------------------------------------===// // BaseMemRefType //===----------------------------------------------------------------------===// @@ -606,6 +631,15 @@ return success(); } +void MemRefType::walkSubElementsImpl( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); + walkAttrsFn(getMemorySpace()); + for (AffineMap map : getAffineMaps()) + walkAttrsFn(AffineMapAttr::get(map)); +} +
//===----------------------------------------------------------------------===// // UnrankedMemRefType //===----------------------------------------------------------------------===// @@ -773,6 +807,13 @@ return success(); } +void UnrankedMemRefType::walkSubElementsImpl( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + walkTypesFn(getElementType()); + walkAttrsFn(getMemorySpace()); +} + //===----------------------------------------------------------------------===// /// TupleType //===----------------------------------------------------------------------===// @@ -796,6 +837,13 @@ /// Return the number of element types. size_t TupleType::size() const { return getImpl()->size(); } +void TupleType::walkSubElementsImpl( + function_ref walkAttrsFn, + function_ref walkTypesFn) const { + for (Type type : getTypes()) + walkTypesFn(type); +} + //===----------------------------------------------------------------------===// // Type Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/SubElementInterfaces.cpp b/mlir/lib/IR/SubElementInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/IR/SubElementInterfaces.cpp @@ -0,0 +1,65 @@ +//===- SubElementInterfaces.cpp - Attr and Type SubElement Interfaces -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/SubElementInterfaces.h" + +using namespace mlir; + +template +static void walkSubElementsImpl(InterfaceT interface, + function_ref walkAttrsFn, + function_ref walkTypesFn) { + interface.walkSubElementsImpl( + [&](Attribute attr) { + // Guard against potentially null inputs. This removes the need for the + // derived attribute/type to do it.
+ if (!attr) + return; + + // Walk any sub elements first. + if (auto subInterface = attr.dyn_cast()) + walkSubElementsImpl(subInterface, walkAttrsFn, walkTypesFn); + + // Walk this attribute. + walkAttrsFn(attr); + }, + [&](Type type) { + // Guard against potentially null inputs. This removes the need for the + // derived attribute/type to do it. + if (!type) + return; + + // Walk any sub elements first. + if (auto subInterface = type.dyn_cast()) + walkSubElementsImpl(subInterface, walkAttrsFn, walkTypesFn); + + // Walk this type. + walkTypesFn(type); + }); +} + +void SubElementAttrInterface::walkSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) { + assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); + ::walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); +} + +void SubElementTypeInterface::walkSubElements( + function_ref walkAttrsFn, + function_ref walkTypesFn) { + assert(walkAttrsFn && walkTypesFn && "expected valid walk functions"); + ::walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn); +} + +//===----------------------------------------------------------------------===// +// SubElementInterface Tablegen definitions +//===----------------------------------------------------------------------===// + +#include "mlir/IR/SubElementAttrInterfaces.cpp.inc" +#include "mlir/IR/SubElementTypeInterfaces.cpp.inc"