diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -337,6 +337,38 @@
   let skipDefaultBuilders = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// DimensionListAttr
+//===----------------------------------------------------------------------===//
+
+def Builtin_DimensionListAttr : Builtin_Attr<"DimensionList"> {
+  let mnemonic = "dims";
+  let summary = "An Attribute containing a dense array of int64_t, suitable to "
+                "model a list of dimensions or a subscript.";
+  let description = [{
+    Syntax:
+
+    ```
+    dimension-list-attribute ::= `dims` `<` (integer (`,` integer)*)? `>`
+    ```
+
+    Each element is an integer literal, optionally preceded by `-`, whose
+    value must fit in a signed 64-bit integer. When the attribute is printed
+    without its `dims` keyword (e.g. by an op's declarative assembly format),
+    the elements are wrapped in square brackets instead, e.g. `[6]`.
+
+    Examples:
+
+    ```mlir
+    dims<>
+    dims<6>
+    dims<-1, 2, 3>
+    ```
+  }];
+  let parameters = (ins ArrayRefParameter<"int64_t">:$dims);
+  let assemblyFormat = "`[` $dims `]`";
+}
+
 //===----------------------------------------------------------------------===//
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1293,6 +1293,9 @@
   /// used instead of individual elements when the elements attr is large.
   void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
 
+  /// Print the elements of a dimension list attribute, comma separated.
+  void printDimensionListAttr(DimensionListAttr attr);
+
   /// Print a dense string elements attribute.
   void printDenseStringElementsAttr(DenseStringElementsAttr attr);
 
@@ -1726,6 +1729,11 @@
       os << '>';
     }
 
+  } else if (auto dimsList = attr.dyn_cast<DimensionListAttr>()) {
+    os << "dims<";
+    printDimensionListAttr(dimsList);
+    os << '>';
+
   } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
     if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
       printElidedElementsAttr(os);
@@ -1899,6 +1907,10 @@
   }
 }
 
+void AsmPrinter::Impl::printDimensionListAttr(DimensionListAttr attr) {
+  llvm::interleaveComma(attr.getDims(), os);
+}
+
 void AsmPrinter::Impl::printDenseStringElementsAttr(
     DenseStringElementsAttr attr) {
   ArrayRef<StringRef> data = attr.getRawStringData();
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -11,13 +11,16 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/DecodeAttributesInterfaces.h"
 #include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Endian.h"
 
 using namespace mlir;
@@ -36,10 +39,10 @@
 
 void BuiltinDialect::registerAttributes() {
   addAttributes<AffineMapAttr, ArrayAttr, DenseIntOrFPElementsAttr,
-                DenseStringElementsAttr, DictionaryAttr, FloatAttr,
-                SymbolRefAttr, IntegerAttr, IntegerSetAttr, OpaqueAttr,
-                OpaqueElementsAttr, SparseElementsAttr, StringAttr, TypeAttr,
-                UnitAttr>();
+                DimensionListAttr, DenseStringElementsAttr, DictionaryAttr,
+                FloatAttr, SymbolRefAttr, IntegerAttr, IntegerSetAttr,
+                OpaqueAttr, OpaqueElementsAttr, SparseElementsAttr, StringAttr,
+                TypeAttr, UnitAttr>();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -90,6 +90,10 @@
   case Token::kw_dense:
     return parseDenseElementsAttr(type);
 
+  // Parse a dimension list attribute.
+  case Token::kw_dims:
+    return parseDimensionListAttr();
+
   // Parse a dictionary attribute.
   case Token::l_brace: {
     NamedAttrList elements;
@@ -830,6 +834,41 @@
   return literalParser.getAttr(loc, type);
 }
 
+/// Parse a dimension list attribute.
+///
+///   dimension-list-attr ::= `dims` `<` (integer (`,` integer)*)? `>`
+///
+/// Every element is an integer literal, optionally preceded by `-`, whose
+/// value must fit in an int64_t.
+Attribute Parser::parseDimensionListAttr() {
+  consumeToken(Token::kw_dims);
+  SmallVector<int64_t> dims;
+  auto parseDim = [&]() -> ParseResult {
+    // A leading '-' is lexed as a separate token; consume it here and apply
+    // the sign after the range check.
+    bool negative = consumeIf(Token::minus);
+    if (getToken().isNot(Token::integer))
+      return emitError("expected integer in dimensions list");
+    // The magnitude of INT64_MIN is one larger than INT64_MAX.
+    uint64_t maxMagnitude =
+        static_cast<uint64_t>(std::numeric_limits<int64_t>::max()) +
+        (negative ? 1 : 0);
+    Optional<uint64_t> value = getToken().getUInt64IntegerValue();
+    if (!value || *value > maxMagnitude)
+      return emitError("integer overflow in dimensions list");
+    // Negate in unsigned arithmetic so that INT64_MIN does not overflow.
+    dims.push_back(static_cast<int64_t>(negative ? 0 - *value : *value));
+    consumeToken(Token::integer);
+    return success();
+  };
+  // parseCommaSeparatedList already handles the empty list `dims<>`, so the
+  // element callback must not accept a bare `>`: doing so would also let a
+  // trailing comma such as `dims<1, >` through.
+  if (parseCommaSeparatedList(Delimiter::LessGreater, parseDim))
+    return nullptr;
+  return DimensionListAttr::get(getContext(), dims);
+}
+
 /// Parse an opaque elements attribute.
 Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
   consumeToken(Token::kw_opaque);
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -259,6 +259,9 @@
   Attribute parseDenseElementsAttr(Type attrType);
   ShapedType parseElementsLiteralType(Type type);
 
+  /// Parse a dimension list attribute.
+  Attribute parseDimensionListAttr();
+
   /// Parse a sparse elements attribute.
   Attribute parseSparseElementsAttr(Type attrType);
 
diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def
--- a/mlir/lib/Parser/TokenKinds.def
+++ b/mlir/lib/Parser/TokenKinds.def
@@ -82,6 +82,7 @@
 TOK_KEYWORD(ceildiv)
 TOK_KEYWORD(complex)
 TOK_KEYWORD(dense)
+TOK_KEYWORD(dims)
 TOK_KEYWORD(f16)
 TOK_KEYWORD(f32)
 TOK_KEYWORD(f64)
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -557,6 +557,54 @@
   return
 }
 
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test DimensionListAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @dim_lists
+func @dim_lists() {
+  // CHECK: test.dim_list_attr
+  // CHECK-SAME: attr1 = dims<6>
+  // CHECK-SAME: attr2 = dims<-1, 2, 3>
+  // CHECK-SAME: limits = dims<-9223372036854775808, 9223372036854775807>
+  // CHECK-SAME: unregistered_attr = dims<>
+  "test.dim_list_attr"() {
+    attr1 = dims<6>,
+    attr2 = dims<-1, 2, 3>,
+    limits = dims<-9223372036854775808, 9223372036854775807>,
+    unregistered_attr = dims<>
+  } : () -> ()
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @dim_list_attr_custom
+func @dim_list_attr_custom() {
+  // CHECK: test.dim_list_attr_custom
+  // CHECK-SAME: theDims : [6]
+  test.dim_list_attr_custom theDims : [6]
+  return
+}
+
+// -----
+
+func @dim_list_trailing_comma() {
+  // expected-error @+1 {{expected integer in dimensions list}}
+  "test.dim_list_attr"() {attr1 = dims<1, >, attr2 = dims<>} : () -> ()
+  return
+}
+
+// -----
+
+func @dim_list_overflow() {
+  // expected-error @+1 {{integer overflow in dimensions list}}
+  "test.dim_list_attr"() {attr1 = dims<9223372036854775808>, attr2 = dims<>} : () -> ()
+  return
+}
+
 // -----
 
 func @wrong_shape_fail() {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -12,6 +12,7 @@
 include "TestDialect.td"
 include "mlir/Dialect/DLTI/DLTIBase.td"
 include "mlir/IR/EnumAttr.td"
+include "mlir/IR/BuiltinAttributes.td"
 include "mlir/IR/OpBase.td"
 include "mlir/IR/OpAsmInterface.td"
 include "mlir/IR/RegionKindInterface.td"
@@ -273,6 +274,22 @@
   );
 }
 
+// Test ops for DimensionListAttr: `dim_list_attr` uses the generic `dims<...>`
+// syntax, `dim_list_attr_custom` the attribute's ODS `[...]` assembly format.
+def DimListAttrOp : TEST_Op<"dim_list_attr"> {
+  let arguments = (ins
+    Builtin_DimensionListAttr:$attr1,
+    Builtin_DimensionListAttr:$attr2
+  );
+}
+
+def DimListAttrCustomAsmOp : TEST_Op<"dim_list_attr_custom"> {
+  let arguments = (ins
+    Builtin_DimensionListAttr:$attr1
+  );
+  let assemblyFormat = "`theDims` `:` $attr1 attr-dict";
+}
+
 def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
   let results = (outs AnyTensor:$output);
   DerivedTypeAttr element_dtype =
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -146,11 +146,17 @@
     strip_include_prefix = "lib/Dialect/Test",
     tbl_outs = [
         (
-            ["-gen-attrdef-decls"],
+            [
+                "-gen-attrdef-decls",
+                "--attrdefs-dialect=test",
+            ],
             "lib/Dialect/Test/TestAttrDefs.h.inc",
         ),
         (
-            ["-gen-attrdef-defs"],
+            [
+                "-gen-attrdef-defs",
+                "--attrdefs-dialect=test",
+            ],
             "lib/Dialect/Test/TestAttrDefs.cpp.inc",
         ),
     ],