diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -925,33 +925,59 @@
 }
 
 /// Parse a dense array attribute.
-Attribute Parser::parseDenseArrayAttr(Type type) {
+Attribute Parser::parseDenseArrayAttr(Type attrType) {
   consumeToken(Token::kw_array);
   if (parseToken(Token::less, "expected '<' after 'array'"))
     return {};
 
-  // Only bool or integer and floating point elements divisible by bytes are
-  // supported.
   SMLoc typeLoc = getToken().getLoc();
-  if (!type && !(type = parseType()))
+  Type eltType;
+  // If an attribute type was provided, use its element type.
+  if (attrType) {
+    auto tensorType = attrType.dyn_cast<RankedTensorType>();
+    if (!tensorType) {
+      emitError(typeLoc, "dense array attribute expected ranked tensor type");
+      return {};
+    }
+    eltType = tensorType.getElementType();
+
+    // Otherwise, parse a type.
+  } else if (!(eltType = parseType())) {
     return {};
-  if (!type.isIntOrIndexOrFloat()) {
-    emitError(typeLoc, "expected integer or float type, got: ") << type;
+  }
+
+  // Only bool or integer and floating point elements divisible by bytes are
+  // supported.
+  if (!eltType.isIntOrIndexOrFloat()) {
+    emitError(typeLoc, "expected integer or float type, got: ") << eltType;
     return {};
   }
-  if (!type.isInteger(1) && type.getIntOrFloatBitWidth() % 8 != 0) {
+  if (!eltType.isInteger(1) && eltType.getIntOrFloatBitWidth() % 8 != 0) {
     emitError(typeLoc, "element type bitwidth must be a multiple of 8");
     return {};
   }
 
+  // If a type was provided, check that it matches the parsed type.
+  auto checkProvidedType = [&](DenseArrayAttr result) -> Attribute {
+    if (attrType && result.getType() != attrType) {
+      emitError(typeLoc, "expected attribute type ")
+          << attrType << " does not match parsed type " << result.getType();
+      return {};
+    }
+    return result;
+  };
+
   // Check for empty list.
-  if (consumeIf(Token::greater))
-    return DenseArrayAttr::get(RankedTensorType::get(0, type), {});
-  if (parseToken(Token::colon, "expected ':' after dense array type"))
+  if (consumeIf(Token::greater)) {
+    return checkProvidedType(
+        DenseArrayAttr::get(RankedTensorType::get(0, eltType), {}));
+  }
+  if (!attrType &&
+      parseToken(Token::colon, "expected ':' after dense array type"))
     return {};
 
-  DenseArrayElementParser eltParser(type);
-  if (type.isIntOrIndex()) {
+  DenseArrayElementParser eltParser(eltType);
+  if (eltType.isIntOrIndex()) {
     if (parseCommaSeparatedList(
             [&] { return eltParser.parseIntegerElement(*this); }))
       return {};
@@ -962,7 +988,7 @@
   }
   if (parseToken(Token::greater, "expected '>' to close an array attribute"))
     return {};
-  return eltParser.getAttr();
+  return checkProvidedType(eltParser.getAttr());
 }
 
 /// Parse a dense elements attribute.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1864,13 +1864,16 @@
   } else if (auto stridedLayoutAttr = attr.dyn_cast<StridedLayoutAttr>()) {
     stridedLayoutAttr.print(os);
   } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
-    typeElision = AttrTypeElision::Must;
-    os << "array<" << denseArrayAttr.getType().getElementType();
+    os << "array<";
+    if (typeElision != AttrTypeElision::Must)
+      printType(denseArrayAttr.getType().getElementType());
     if (!denseArrayAttr.empty()) {
-      os << ": ";
+      if (typeElision != AttrTypeElision::Must)
+        os << ": ";
       printDenseArrayAttr(denseArrayAttr);
     }
     os << ">";
+    return;
   } else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
     os << "dense_resource<";
     printResourceHandle(resourceAttr.getRawHandle());
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -589,6 +589,9 @@
     x6_bf16 = array,
     x7_f16 = array
   }: () -> ()
+
+  // CHECK: test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
+  test.typed_attr tensor<4xi32> = array<1, 2, 3, 4>
   return
} diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir --- a/mlir/test/IR/invalid-builtin-attributes.mlir +++ b/mlir/test/IR/invalid-builtin-attributes.mlir @@ -546,3 +546,18 @@ // expected-error@below {{expected '>' to close an array attribute}} #attr = array + +// ----- + +// expected-error@below {{does not match parsed type}} +test.typed_attr tensor<1xi32> = array<> + +// ----- + +// expected-error@below {{does not match parsed type}} +test.typed_attr tensor<0xi32> = array<1> diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -460,6 +460,22 @@ return {}; } +//===----------------------------------------------------------------------===// +// TypedAttrOp +//===----------------------------------------------------------------------===// + +/// Parse an attribute with a given type. +static ParseResult parseAttrElideType(AsmParser &parser, TypeAttr type, + Attribute &attr) { + return parser.parseAttribute(attr, type.getValue()); +} + +/// Print an attribute without its type. 
+static void printAttrElideType(AsmPrinter &printer, Operation *op, + TypeAttr type, Attribute attr) { + printer.printAttributeWithoutType(attr); +} + //===----------------------------------------------------------------------===// // TestBranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -270,6 +270,13 @@ ); } +def TypedAttrOp : TEST_Op<"typed_attr"> { + let arguments = (ins TypeAttr:$type, AnyAttr:$attr); + let assemblyFormat = [{ + attr-dict $type `=` custom(ref($type), $attr) + }]; +} + def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> { let arguments = (ins DenseBoolArrayAttr:$i1attr,