diff --git a/mlir/include/mlir/IR/AttributeSupport.h b/mlir/include/mlir/IR/AttributeSupport.h --- a/mlir/include/mlir/IR/AttributeSupport.h +++ b/mlir/include/mlir/IR/AttributeSupport.h @@ -251,6 +251,75 @@ }; } // namespace detail +namespace AttributeTrait { + +/// This trait provides an "ArrayRef-like" behavior by forwarding the read-only +/// methods of ArrayRef<T>. This expects the attribute to define a `getArray()` +/// accessor returning the underlying ArrayRef<T>. +template <typename T> +struct ArrayRefAttr { + template <typename ConcreteType> + class Impl { + ArrayRef<T> crtp() const { + return static_cast<const ConcreteType *>(this)->getArray(); + } + + public: + using value_type = T; + using pointer = value_type *; + using const_pointer = const value_type *; + using reference = value_type &; + using const_reference = const value_type &; + using iterator = const_pointer; + using const_iterator = const_pointer; + using reverse_iterator = std::reverse_iterator<iterator>; + using const_reverse_iterator = std::reverse_iterator<const_iterator>; + using size_type = size_t; + using difference_type = ptrdiff_t; + + operator ArrayRef<T>() const { return crtp(); } + + iterator begin() const { return crtp().begin(); } + iterator end() const { return crtp().end(); } + reverse_iterator rbegin() const { return crtp().rbegin(); } + reverse_iterator rend() const { return crtp().rend(); } + bool empty() const { return crtp().empty(); } + const T *data() const { return crtp().data(); } + size_t size() const { return crtp().size(); } + const T &front() const { return crtp().front(); } + const T &back() const { return crtp().back(); } + template <typename Allocator> + ArrayRef<T> copy(Allocator &A) const { + return crtp().copy(A); + } + bool equals(ArrayRef<T> RHS) const { return crtp().equals(RHS); } + ArrayRef<T> slice(size_t N, size_t M) const { return crtp().slice(N, M); } + ArrayRef<T> slice(size_t N) const { return crtp().slice(N); } + ArrayRef<T> drop_front(size_t N = 1) const { return crtp().drop_front(N); } + ArrayRef<T> drop_back(size_t N = 1) const { return crtp().drop_back(N); } + template <class PredicateT> + ArrayRef<T> 
drop_while(PredicateT Pred) const { + return crtp().drop_while(Pred); + } + template <class PredicateT> + ArrayRef<T> drop_until(PredicateT Pred) const { + return crtp().drop_until(Pred); + } + ArrayRef<T> take_front(size_t N = 1) const { return crtp().take_front(N); } + ArrayRef<T> take_back(size_t N = 1) const { return crtp().take_back(N); } + template <class PredicateT> + ArrayRef<T> take_while(PredicateT Pred) const { + return crtp().take_while(Pred); + } + template <class PredicateT> + ArrayRef<T> take_until(PredicateT Pred) const { + return crtp().take_until(Pred); + } + const T &operator[](size_t Index) const { return crtp()[Index]; } + }; +}; + +} // namespace AttributeTrait } // namespace mlir #endif diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -142,6 +142,9 @@ ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values); ArrayAttr getTypeArrayAttr(TypeRange values); + // A dimension list is a dense array of int64_t values. + DimensionListAttr getDimListAttr(ArrayRef<int64_t> values); + // Affine expressions and affine maps. AffineExpr getAffineDimExpr(unsigned position); AffineExpr getAffineSymbolExpr(unsigned position); diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -339,6 +339,51 @@ let skipDefaultBuilders = 1; } + +//===----------------------------------------------------------------------===// +// DimensionListAttr +//===----------------------------------------------------------------------===// + +def Builtin_DimensionListAttr : Builtin_Attr< + "DimensionList", [ArrayRefAttrTrait<"int64_t">]> { + let mnemonic = "dims"; + let summary = "An Attribute containing a dense array of int64_t, suitable to " + "model a list of dimensions or a subscript for example."; + let description = [{ + Syntax: + + ``` + `dims` `<` (signed-integer (`,` signed-integer)*)? `>` + ``` + + For example: `dims<42, -5>` or `dims<>`. 
+ + When used embedded in an operation, it is displayed with an array syntax: + + ``` + `[` (signed-integer (`,` signed-integer)*)? `]` + ``` + + For example: `test.op [42, -5]`. + }]; + let parameters = (ins ArrayRefParameter<"int64_t">:$array); + let returnType = "llvm::ArrayRef<int64_t>"; + let assemblyFormat = "`[` $array `]`"; +} + +class DimensionListAttrOfSize<int n> : Confined<Builtin_DimensionListAttr, [ + AttrConstraint<CPred<"$_self.cast<::mlir::DimensionListAttr>().size() == " #n>, + "with exactly " # n # " elements"> + ]>; + +class DimensionListAttrOfMaxSize<int n> : Confined<Builtin_DimensionListAttr, [ + AttrConstraint<CPred<"$_self.cast<::mlir::DimensionListAttr>().size() <= " #n>, + "with at most " # n # " elements"> + ]>; + + //===----------------------------------------------------------------------===// // DictionaryAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -3168,4 +3168,8 @@ string derivedType = "::mlir::Type"> : AttrOrTypeParameter<derivedType, desc> {} +// Attribute Trait for the native `ArrayRefAttr`. +class ArrayRefAttrTrait<string T> : ParamNativeAttrTrait<"ArrayRefAttr", T> {} + + #endif // OP_BASE diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1293,6 +1293,9 @@ /// used instead of individual elements when the elements attr is large. void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex); + /// Print the elements of a dimension list attribute, comma-separated. + void printDimensionListAttr(DimensionListAttr attr); + /// Print a dense string elements attribute. 
void printDenseStringElementsAttr(DenseStringElementsAttr attr); @@ -1726,6 +1729,11 @@ os << '>'; } + } else if (auto dimsList = attr.dyn_cast<DimensionListAttr>()) { + os << "dims<"; + printDimensionListAttr(dimsList); + os << '>'; + } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) { if (printerFlags.shouldElideElementsAttr(strEltAttr)) { printElidedElementsAttr(os); @@ -1899,6 +1907,10 @@ } } +void AsmPrinter::Impl::printDimensionListAttr(DimensionListAttr attr) { + llvm::interleaveComma(attr, os); +} + void AsmPrinter::Impl::printDenseStringElementsAttr( DenseStringElementsAttr attr) { ArrayRef<StringRef> data = attr.getRawStringData(); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -227,6 +227,10 @@ return getArrayAttr(attrs); } +DimensionListAttr Builder::getDimListAttr(ArrayRef<int64_t> values) { + return DimensionListAttr::get(getContext(), values); +} + ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) { auto attrs = llvm::to_vector<8>( llvm::map_range(values, [this](int64_t v) -> Attribute { diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -11,13 +11,16 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/DecodeAttributesInterfaces.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Endian.h" using namespace mlir; @@ -36,10 +39,10 @@ void BuiltinDialect::registerAttributes() { addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr, - DenseStringElementsAttr, DictionaryAttr, FloatAttr, - SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr, - OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr, - UnitAttr>(); + DimensionListAttr, DenseStringElementsAttr, DictionaryAttr, + FloatAttr, SymbolRefAttr, IntegerAttr, IntegerSetAttr, + OpaqueAttr, OpaqueElementsAttr, 
SparseElementsAttr, StringAttr, + TypeAttr, UnitAttr>(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -90,6 +90,10 @@ case Token::kw_dense: return parseDenseElementsAttr(type); + // Parse a dimension list attribute. + case Token::kw_dims: + return parseDimensionListAttr(); + // Parse a dictionary attribute. case Token::l_brace: { NamedAttrList elements; @@ -831,6 +835,34 @@ return literalParser.getAttr(loc, type); } +/// Parse a dimension list attribute: +/// `dims` `<` (`-`? integer (`,` `-`? integer)*)? `>` +Attribute Parser::parseDimensionListAttr() { + consumeToken(Token::kw_dims); + SmallVector<int64_t> dims; + // The empty list `dims<>` is handled by parseCommaSeparatedList, so every + // invocation of this callback must parse an element (no trailing comma). + auto parseElement = [&]() -> LogicalResult { + bool negative = consumeIf(Token::minus); + if (getToken().isNot(Token::integer)) + return emitError("expected integer in dimensions list"); + // The magnitude of a negative value may be one past INT64_MAX (INT64_MIN). + uint64_t limit = std::numeric_limits<int64_t>::max(); + if (negative) + ++limit; + Optional<uint64_t> value = getToken().getUInt64IntegerValue(); + if (!value || *value > limit) + return emitError("integer overflow in dimensions list"); + // Negate in unsigned arithmetic so that INT64_MIN does not overflow. + dims.push_back(static_cast<int64_t>(negative ? 0 - *value : *value)); + consumeToken(Token::integer); + return success(); + }; + if (parseCommaSeparatedList(Delimiter::LessGreater, parseElement)) + return nullptr; + return DimensionListAttr::get(getContext(), dims); +} + /// Parse an opaque elements attribute. 
Attribute Parser::parseOpaqueElementsAttr(Type attrType) { consumeToken(Token::kw_opaque); diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -259,6 +259,9 @@ Attribute parseDenseElementsAttr(Type attrType); ShapedType parseElementsLiteralType(Type type); + /// Parse a dimension list attribute. + Attribute parseDimensionListAttr(); + /// Parse a sparse elements attribute. 
Attribute parseSparseElementsAttr(Type attrType); diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -82,6 +82,7 @@ TOK_KEYWORD(ceildiv) TOK_KEYWORD(complex) TOK_KEYWORD(dense) +TOK_KEYWORD(dims) TOK_KEYWORD(f16) TOK_KEYWORD(f32) TOK_KEYWORD(f64) diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -603,6 +603,38 @@ return } +// ----- + +//===----------------------------------------------------------------------===// +// Test DimensionListAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @dim_lists +func @dim_lists() { + // CHECK: test.dim_list_attr + // CHECK-SAME: attr1 = dims<6> + // CHECK-SAME: attr2 = dims<-1, 2, 3> + // CHECK-SAME: bounds = dims<-9223372036854775808, 9223372036854775807> + // CHECK-SAME: unregistered_attr = dims<> + "test.dim_list_attr"() { + attr1 = dims<6>, + attr2 = dims<-1, 2, 3>, + bounds = dims<-9223372036854775808, 9223372036854775807>, + unregistered_attr = dims<> + } : () -> () + return +} + +// ----- + +// CHECK-LABEL: func @dim_list_attr_custom +func @dim_list_attr_custom() { + // CHECK: test.dim_list_attr_custom + // CHECK-SAME: theDims : [6] + test.dim_list_attr_custom theDims : [6] + return +} + // ----- func @wrong_shape_fail() { diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -14,8 +14,8 @@ add_public_tablegen_target(MLIRTestInterfaceIncGen) set(LLVM_TARGET_DEFINITIONS TestOps.td) -mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls) -mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=test) +mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=test) add_public_tablegen_target(MLIRTestAttrDefIncGen) set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td) diff --git 
a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -12,6 +12,7 @@ include "TestDialect.td" include "mlir/Dialect/DLTI/DLTIBase.td" include "mlir/IR/EnumAttr.td" +include "mlir/IR/BuiltinAttributes.td" include "mlir/IR/OpBase.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/RegionKindInterface.td" @@ -273,6 +274,20 @@ ); } +def DimListAttrOp : TEST_Op<"dim_list_attr"> { + let arguments = (ins + Builtin_DimensionListAttr:$attr1, + Builtin_DimensionListAttr:$attr2 + ); +} + +def DimListAttrCustomAsmOp : TEST_Op<"dim_list_attr_custom"> { + let arguments = (ins + Builtin_DimensionListAttr:$attr1 + ); + let assemblyFormat = "`theDims` `:` $attr1 attr-dict"; +} + def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> { let results = (outs AnyTensor:$output); DerivedTypeAttr element_dtype = diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -146,11 +146,17 @@ strip_include_prefix = "lib/Dialect/Test", tbl_outs = [ ( - ["-gen-attrdef-decls"], + [ + "-gen-attrdef-decls", + "-attrdefs-dialect=test", + ], "lib/Dialect/Test/TestAttrDefs.h.inc", ), ( - ["-gen-attrdef-defs"], + [ + "-gen-attrdef-defs", + "-attrdefs-dialect=test", + ], "lib/Dialect/Test/TestAttrDefs.cpp.inc", ), ],