diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -142,6 +142,9 @@
   ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
   ArrayAttr getTypeArrayAttr(TypeRange values);
 
+  // Dimension lists: a dense array of int64_t values, printed as `dims<...>`.
+  DimensionListAttr getDimListAttr(ArrayRef<int64_t> values);
+
   // Affine expressions and affine maps.
   AffineExpr getAffineDimExpr(unsigned position);
   AffineExpr getAffineSymbolExpr(unsigned position);
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -339,6 +339,56 @@
   let skipDefaultBuilders = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// DimensionListAttr
+//===----------------------------------------------------------------------===//
+
+def Builtin_DimensionListAttr : Builtin_Attr<"DimensionList"> {
+  let mnemonic = "dims";
+  let summary = "An Attribute containing a dense array of int64_t, suitable to "
+                "model a list of dimensions or a subscript for example.";
+  let description = [{
+    Syntax:
+
+    ```
+    `dims` `<` (integer-literal (`,` integer-literal)*)? `>`
+    ```
+
+    For example: `dims<42, -5>` or `dims<>`.
+
+    When used embedded in an operation, it is displayed with an array syntax:
+
+    ```
+    `[` (integer-literal (`,` integer-literal)*)? `]`
+    ```
+
+    For example: `test.op [42, -5]`.
+  }];
+  let parameters = (ins ArrayRefParameter<"int64_t">:$array);
+  let returnType = "::llvm::ArrayRef<int64_t>";
+  let convertFromStorage = "$_self.getArray()";
+  let assemblyFormat = "`[` $array `]`";
+  let extraClassDeclaration = [{
+    // Convenient implicit conversion to allow using this attribute in APIs
+    // taking an ArrayRef<int64_t>.
+    operator ArrayRef<int64_t>() const { return getArray(); }
+  }];
+}
+
+// A DimensionListAttr constrained to hold exactly `n` elements.
+class DimensionListAttrOfSize<int n> : Confined<Builtin_DimensionListAttr, [
+    AttrConstraint<
+      CPred<"$_self.cast<::mlir::DimensionListAttr>().getArray().size() == " # n>,
+      "with exactly " # n # " elements">
+  ]>;
+
+// A DimensionListAttr constrained to hold at most `n` elements.
+class DimensionListAttrOfMaxSize<int n> : Confined<Builtin_DimensionListAttr, [
+    AttrConstraint<
+      CPred<"$_self.cast<::mlir::DimensionListAttr>().getArray().size() <= " # n>,
+      "with at most " # n # " elements">
+  ]>;
+
 //===----------------------------------------------------------------------===//
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1293,6 +1293,9 @@
   /// used instead of individual elements when the elements attr is large.
   void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
 
+  /// Print the elements of a dimension list attribute, comma separated.
+  void printDimensionListAttr(DimensionListAttr attr);
+
   /// Print a dense string elements attribute.
   void printDenseStringElementsAttr(DenseStringElementsAttr attr);
 
@@ -1726,6 +1729,11 @@
       os << '>';
     }
 
+  } else if (auto dimsList = attr.dyn_cast<DimensionListAttr>()) {
+    os << "dims<";
+    printDimensionListAttr(dimsList);
+    os << '>';
+
   } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
     if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
       printElidedElementsAttr(os);
@@ -1899,6 +1907,10 @@
   }
 }
 
+void AsmPrinter::Impl::printDimensionListAttr(DimensionListAttr attr) {
+  llvm::interleaveComma(attr.getArray(), os);
+}
+
 void AsmPrinter::Impl::printDenseStringElementsAttr(
     DenseStringElementsAttr attr) {
   ArrayRef<StringRef> data = attr.getRawStringData();
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -227,6 +227,10 @@
   return getArrayAttr(attrs);
 }
 
+DimensionListAttr Builder::getDimListAttr(ArrayRef<int64_t> values) {
+  return DimensionListAttr::get(getContext(), values);
+}
+
 ArrayAttr Builder::getIndexArrayAttr(ArrayRef<int64_t> values) {
   auto attrs = llvm::to_vector<8>(
       llvm::map_range(values, [this](int64_t v) -> Attribute {
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp
b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -11,13 +11,16 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/DecodeAttributesInterfaces.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Endian.h" using namespace mlir; @@ -36,10 +39,10 @@ void BuiltinDialect::registerAttributes() { addAttributes(); + DimensionListAttr, DenseStringElementsAttr, DictionaryAttr, + FloatAttr, SymbolRefAttr, IntegerAttr, IntegerSetAttr, + OpaqueAttr, OpaqueElementsAttr, SparseElementsAttr, StringAttr, + TypeAttr, UnitAttr>(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -90,6 +90,10 @@ case Token::kw_dense: return parseDenseElementsAttr(type); + // Parse a dimension list attribute, e.g. `dims<1, -2>` or `dims<>`. + case Token::kw_dims: + return parseDimensionListAttr(); + // Parse a dictionary attribute.
case Token::l_brace: { NamedAttrList elements; @@ -831,6 +835,31 @@ return literalParser.getAttr(loc, type); } +Attribute Parser::parseDimensionListAttr() { + consumeToken(Token::kw_dims); + SmallVector dims; + auto parseElement = [&]() -> LogicalResult { + if (getToken().is(Token::greater)) + return success(); + bool negative = false; + if (getToken().is(Token::minus)) { + negative = true; + consumeToken(); + } + if (getToken().isNot(Token::integer)) + return emitError("expected integer in dimensions list"); + Optional value = getToken().getUInt64IntegerValue(); + if (!value || *value > std::numeric_limits::max()) + return emitError("integer overflow in dimensions list"); + dims.emplace_back(negative ? -*value : *value); + consumeToken(); + return success(); + }; + if (parseCommaSeparatedList(Delimiter::LessGreater, parseElement)) + return nullptr; + return DimensionListAttr::get(getContext(), dims); +} + /// Parse an opaque elements attribute. Attribute Parser::parseOpaqueElementsAttr(Type attrType) { consumeToken(Token::kw_opaque); diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -259,6 +259,9 @@ Attribute parseDenseElementsAttr(Type attrType); ShapedType parseElementsLiteralType(Type type); + /// Parse a dimension lists attribute. + Attribute parseDimensionListAttr(); + /// Parse a sparse elements attribute. 
Attribute parseSparseElementsAttr(Type attrType); diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -82,6 +82,7 @@ TOK_KEYWORD(ceildiv) TOK_KEYWORD(complex) TOK_KEYWORD(dense) +TOK_KEYWORD(dims) TOK_KEYWORD(f16) TOK_KEYWORD(f32) TOK_KEYWORD(f64) diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -603,6 +603,37 @@ return } +// ----- + +//===----------------------------------------------------------------------===// +// Test DimensionListAttr parsing and printing (generic and custom assembly) +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @dim_lists +func @dim_lists() { + // CHECK: test.dim_list_attr + // CHECK-SAME: attr1 = dims<6> + // CHECK-SAME: attr2 = dims<-1, 2, 3> + // CHECK-SAME: unregistered_attr = dims<> + "test.dim_list_attr"() { + attr1 = dims<6>, + attr2 = dims<-1, 2, 3>, + unregistered_attr = dims<> + } : () -> () + return +} + +// ----- + +// CHECK-LABEL: func @dim_list_attr_custom +func @dim_list_attr_custom() { + // CHECK: test.dim_list_attr_custom + // CHECK-SAME: theDims : [6] + test.dim_list_attr_custom theDims : [6] + return +} + + // ----- func @wrong_shape_fail() { diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -14,8 +14,8 @@ add_public_tablegen_target(MLIRTestInterfaceIncGen) set(LLVM_TARGET_DEFINITIONS TestOps.td) -mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls) -mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs) +mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls -attrdefs-dialect=test) +mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs -attrdefs-dialect=test) add_public_tablegen_target(MLIRTestAttrDefIncGen) set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td) diff --git
a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -12,6 +12,7 @@
 include "TestDialect.td"
 include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/IR/EnumAttr.td"
+include "mlir/IR/BuiltinAttributes.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/RegionKindInterface.td"
@@ -273,6 +274,20 @@
   );
 }
 
+def DimListAttrOp : TEST_Op<"dim_list_attr"> {
+  let arguments = (ins
+    Builtin_DimensionListAttr:$attr1,
+    Builtin_DimensionListAttr:$attr2
+  );
+}
+
+def DimListAttrCustomAsmOp : TEST_Op<"dim_list_attr_custom"> {
+  let arguments = (ins
+    Builtin_DimensionListAttr:$attr1
+  );
+  let assemblyFormat = "`theDims` `:` $attr1 attr-dict";
+}
+
 def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
   let results = (outs AnyTensor:$output);
   DerivedTypeAttr element_dtype =
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -146,11 +146,17 @@
     strip_include_prefix = "lib/Dialect/Test",
     tbl_outs = [
         (
-            ["-gen-attrdef-decls"],
+            [
+                "-gen-attrdef-decls",
+                "-attrdefs-dialect=test",
+            ],
             "lib/Dialect/Test/TestAttrDefs.h.inc",
         ),
         (
-            ["-gen-attrdef-defs"],
+            [
+                "-gen-attrdef-defs",
+                "-attrdefs-dialect=test",
+            ],
             "lib/Dialect/Test/TestAttrDefs.cpp.inc",
         ),
     ],