diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -918,33 +918,57 @@
 }
 
 /// Parse a dense array attribute.
-Attribute Parser::parseDenseArrayAttr(Type type) {
+Attribute Parser::parseDenseArrayAttr(Type attrType) {
   consumeToken(Token::kw_array);
   if (parseToken(Token::less, "expected '<' after 'array'"))
     return {};
 
-  // Only bool or integer and floating point elements divisible by bytes are
-  // supported.
   SMLoc typeLoc = getToken().getLoc();
-  if (!type && !(type = parseType()))
+  Type eltType;
+  // If an attribute type was provided, use its element type.
+  if (attrType) {
+    auto tensorType = attrType.dyn_cast<RankedTensorType>();
+    if (!tensorType) {
+      emitError(typeLoc, "dense array attribute expected ranked tensor type");
+      return {};
+    }
+    eltType = tensorType.getElementType();
+
+    // Otherwise, parse a type.
+  } else if (!(eltType = parseType())) {
     return {};
-  if (!type.isIntOrIndexOrFloat()) {
-    emitError(typeLoc, "expected integer or float type, got: ") << type;
+  }
+
+  // Only bool or integer and floating point elements divisible by bytes are
+  // supported.
+  if (!eltType.isIntOrIndexOrFloat()) {
+    emitError(typeLoc, "expected integer or float type, got: ") << eltType;
     return {};
   }
-  if (!type.isInteger(1) && type.getIntOrFloatBitWidth() % 8 != 0) {
+  if (!eltType.isInteger(1) && eltType.getIntOrFloatBitWidth() % 8 != 0) {
     emitError(typeLoc, "element type bitwidth must be a multiple of 8");
     return {};
   }
 
   // Check for empty list.
-  if (consumeIf(Token::greater))
-    return DenseArrayAttr::get(RankedTensorType::get(0, type), {});
-  if (parseToken(Token::colon, "expected ':' after dense array type"))
+  if (consumeIf(Token::greater)) {
+    // An empty list must still agree with a caller-provided attribute type,
+    // e.g. `array<>` must not silently produce a 0-element attribute when a
+    // non-empty `tensor<3xi32>` type was expected.
+    RankedTensorType emptyType = RankedTensorType::get(0, eltType);
+    if (attrType && emptyType != attrType) {
+      emitError(typeLoc, "expected attribute type ")
+          << attrType << " does not match parsed type " << emptyType;
+      return {};
+    }
+    return DenseArrayAttr::get(emptyType, {});
+  }
+  if (!attrType &&
+      parseToken(Token::colon, "expected ':' after dense array type"))
     return {};
 
-  DenseArrayElementParser eltParser(type);
-  if (type.isIntOrIndex()) {
+  DenseArrayElementParser eltParser(eltType);
+  if (eltType.isIntOrIndex()) {
     if (parseCommaSeparatedList(
             [&] { return eltParser.parseIntegerElement(*this); }))
       return {};
@@ -953,9 +977,15 @@
             [&] { return eltParser.parseFloatElement(*this); }))
       return {};
   }
+  DenseArrayAttr result = eltParser.getAttr();
+  if (attrType && result.getType() != attrType) {
+    emitError(typeLoc, "expected attribute type ")
+        << attrType << " does not match parsed type " << result.getType();
+    return {};
+  }
   if (parseToken(Token::greater, "expected '>' to close an array attribute"))
     return {};
-  return eltParser.getAttr();
+  return result;
 }
 
 /// Parse a dense elements attribute.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1862,13 +1862,16 @@
       os << '>';
     }
   } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttr>()) {
-    typeElision = AttrTypeElision::Must;
-    os << "array<" << denseArrayAttr.getType().getElementType();
+    os << "array<";
+    if (typeElision != AttrTypeElision::Must)
+      os << denseArrayAttr.getType().getElementType();
     if (!denseArrayAttr.empty()) {
-      os << ": ";
+      if (typeElision != AttrTypeElision::Must)
+        os << ": ";
       printDenseArrayAttr(denseArrayAttr);
     }
     os << ">";
+    return;
   } else if (auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>()) {
     os << "dense_resource<";
     printResourceHandle(resourceAttr.getRawHandle());