diff --git a/mlir/include/mlir/Dialect/CommonFolders.h b/mlir/include/mlir/Dialect/CommonFolders.h --- a/mlir/include/mlir/Dialect/CommonFolders.h +++ b/mlir/include/mlir/Dialect/CommonFolders.h @@ -189,7 +189,7 @@ return {}; elementResults.push_back(*elementResult); } - return DenseElementsAttr::get(op.getType(), elementResults); + return DenseElementsAttr::get(op.getShapedType(), elementResults); } return {}; } diff --git a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinAttributeInterfaces.td @@ -15,11 +15,30 @@ include "mlir/IR/OpBase.td" +//===----------------------------------------------------------------------===// +// TypedAttrInterface +//===----------------------------------------------------------------------===// + +def TypedAttrInterface : AttrInterface<"TypedAttr"> { + let cppNamespace = "::mlir"; + + let description = [{ + This interface is used for attributes that have a type. The type of an + attribute is understood to represent the type of the data contained in the + attribute and is often used as the type of a value with this data. + }]; + + let methods = [InterfaceMethod< + "Get the attribute's type", + "::mlir::Type", "getType" + >]; +} + //===----------------------------------------------------------------------===// // ElementsAttrInterface //===----------------------------------------------------------------------===// -def ElementsAttrInterface : AttrInterface<"ElementsAttr"> { +def ElementsAttrInterface : AttrInterface<"ElementsAttr", [TypedAttrInterface]> { let cppNamespace = "::mlir"; let description = [{ This interface is used for attributes that contain the constant elements of @@ -78,7 +97,7 @@ } /// * Attribute auto value_begin_impl(OverloadToken) const { - mlir::Type elementType = getType().getElementType(); + mlir::Type elementType = getShapedType().getElementType(); auto it = llvm::map_range(getElements(), [=](uint64_t value) { return mlir::IntegerAttr::get(elementType, llvm::APInt(/*numBits=*/64, value)); @@ -154,13 +173,15 @@ InterfaceMethod<[{ Returns true if the attribute elements correspond to a splat, i.e. that all elements of the attribute are the same value. - }], "bool", "isSplat", (ins), /*defaultImplementation=*/[{}], [{ + }], "bool", "isSplat", (ins), [{}], /*defaultImplementation=*/[{ // By default, only check for a single element splat. return $_attr.getNumElements() == 1; }]>, InterfaceMethod<[{ Returns the shaped type of the elements attribute. - }], "::mlir::ShapedType", "getType"> + }], "::mlir::ShapedType", "getShapedType", (ins), [{}], /*defaultImplementation=*/[{ + return $_attr.getType(); + }]> ]; string ElementsAttrInterfaceAccessors = [{ @@ -325,7 +346,7 @@ ArrayRef index); static uint64_t getFlattenedIndex(ElementsAttr elementsAttr, ArrayRef index) { - return getFlattenedIndex(elementsAttr.getType(), index); + return getFlattenedIndex(elementsAttr.getShapedType(), index); } /// Returns the number of elements held by this attribute. @@ -357,7 +378,7 @@ /// Return the elements of this attribute as a value of type 'T'. template DefaultValueCheckT> getValues() const { - return {getType(), value_begin(), value_end()}; + return {getShapedType(), value_begin(), value_end()}; } template DefaultValueCheckT> value_begin() const; @@ -377,7 +398,7 @@ template > DerivedAttrValueIteratorRange getValues() const { auto castFn = [](Attribute attr) { return attr.template cast(); }; - return {getType(), llvm::map_range(getValues(), + return {getShapedType(), llvm::map_range(getValues(), static_cast(castFn))}; } template > @@ -397,7 +418,7 @@ template DefaultValueCheckT>> tryGetValues() const { if (std::optional> beginIt = try_value_begin()) - return iterator_range(getType(), *beginIt, value_end()); + return iterator_range(getShapedType(), *beginIt, value_end()); return std::nullopt; } template @@ -413,7 +434,7 @@ auto castFn = [](Attribute attr) { return attr.template cast(); }; return DerivedAttrValueIteratorRange( - getType(), + getShapedType(), llvm::map_range(*values, static_cast(castFn)) ); } @@ -474,23 +495,4 @@ ]; } -//===----------------------------------------------------------------------===// -// TypedAttrInterface -//===----------------------------------------------------------------------===// - -def TypedAttrInterface : AttrInterface<"TypedAttr"> { - let cppNamespace = "::mlir"; - - let description = [{ - This interface is used for attributes that have a type. The type of an - attribute is understood to represent the type of the data contained in the - attribute and is often used as the type of a value with this data. - }]; - - let methods = [InterfaceMethod< - "Get the attribute's type", - "::mlir::Type", "getType" - >]; -} - #endif // MLIR_IR_BUILTINATTRIBUTEINTERFACES_TD_ diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -218,7 +218,7 @@ //===----------------------------------------------------------------------===// def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< - "DenseIntOrFPElements", [ElementsAttrInterface, TypedAttrInterface], + "DenseIntOrFPElements", [ElementsAttrInterface], "DenseElementsAttr" > { let summary = "An Attribute containing a dense multi-dimensional array of " @@ -359,7 +359,7 @@ //===----------------------------------------------------------------------===// def Builtin_DenseStringElementsAttr : Builtin_Attr< - "DenseStringElements", [ElementsAttrInterface, TypedAttrInterface], + "DenseStringElements", [ElementsAttrInterface], "DenseElementsAttr" > { let summary = "An Attribute containing a dense multi-dimensional array of " @@ -430,7 +430,7 @@ //===----------------------------------------------------------------------===// def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [ - ElementsAttrInterface, TypedAttrInterface + ElementsAttrInterface ]> { let summary = "An Attribute containing a dense multi-dimensional array " "backed by a resource"; @@ -804,7 +804,7 @@ //===----------------------------------------------------------------------===// def Builtin_SparseElementsAttr : Builtin_Attr< - "SparseElements", [ElementsAttrInterface, TypedAttrInterface] + "SparseElements", [ElementsAttrInterface] > { let summary = "An opaque representation of a multi-dimensional array"; let description = [{ diff --git a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp --- a/mlir/lib/IR/BuiltinAttributeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinAttributeInterfaces.cpp @@ -25,11 +25,11 @@ //===----------------------------------------------------------------------===// Type ElementsAttr::getElementType(ElementsAttr elementsAttr) { - return elementsAttr.getType().getElementType(); + return elementsAttr.getShapedType().getElementType(); } int64_t ElementsAttr::getNumElements(ElementsAttr elementsAttr) { - return elementsAttr.getType().getNumElements(); + return elementsAttr.getShapedType().getNumElements(); } bool ElementsAttr::isValidIndex(ShapedType type, ArrayRef index) { @@ -49,7 +49,7 @@ } bool ElementsAttr::isValidIndex(ElementsAttr elementsAttr, ArrayRef index) { - return isValidIndex(elementsAttr.getType(), index); + return isValidIndex(elementsAttr.getShapedType(), index); } uint64_t ElementsAttr::getFlattenedIndex(Type type, ArrayRef index) { diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -407,8 +407,8 @@ // Fall back to element-by-element construction otherwise. if (auto elementsAttr = attr.dyn_cast()) { - assert(elementsAttr.getType().hasStaticShape()); - assert(!elementsAttr.getType().getShape().empty() && + assert(elementsAttr.getShapedType().hasStaticShape()); + assert(!elementsAttr.getShapedType().getShape().empty() && "unexpected empty elements attribute shape"); SmallVector constants; @@ -422,7 +422,7 @@ } ArrayRef constantsRef = constants; llvm::Constant *result = buildSequentialConstant( - constantsRef, elementsAttr.getType().getShape(), llvmType, loc); + constantsRef, elementsAttr.getShapedType().getShape(), llvmType, loc); assert(constantsRef.empty() && "did not consume all elemental constants"); return result; } diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -153,8 +153,8 @@ // Elide "big" elements attributes. auto elements = attr.dyn_cast(); if (elements && elements.getNumElements() > largeAttrLimit) { - os << std::string(elements.getType().getRank(), '[') << "..." - << std::string(elements.getType().getRank(), ']') << " : " + os << std::string(elements.getShapedType().getRank(), '[') << "..." + << std::string(elements.getShapedType().getRank(), ']') << " : " << elements.getType(); return; } diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -81,9 +81,7 @@ } // Test support for ElementsAttrInterface. -def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ - ElementsAttrInterface, TypedAttrInterface - ]> { +def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [ElementsAttrInterface]> { let mnemonic = "i64_elements"; let parameters = (ins AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type, @@ -269,9 +267,7 @@ } // Test simple extern 1D vector using ElementsAttrInterface. -def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [ - ElementsAttrInterface, TypedAttrInterface - ]> { +def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [ElementsAttrInterface]> { let mnemonic = "e1di64_elements"; let parameters = (ins AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,