diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -1066,12 +1066,12 @@ return nullptr; } - if (!type.isa()) { - emitError("elements literal must be a ranked tensor or vector type"); + auto sType = type.dyn_cast(); + if (!sType) { + emitError("elements literal must be a shaped type"); return nullptr; } - auto sType = type.cast(); if (!sType.hasStaticShape()) return (emitError("elements literal type must have static shape"), nullptr); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1381,8 +1381,6 @@ DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, ArrayRef data) { - assert((type.isa()) && - "type must be ranked tensor or vector"); assert(type.hasStaticShape() && "type must have static shape"); bool isSplat = false; bool isValid = isValidRawBuffer(type, data, isSplat); @@ -1498,16 +1496,7 @@ size_t bitWidth = getDenseElementBitWidth(newElementType); size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); - ShapedType newArrayType; - if (inType.isa()) - newArrayType = RankedTensorType::get(inType.getShape(), newElementType); - else if (inType.isa()) - newArrayType = RankedTensorType::get(inType.getShape(), newElementType); - else if (auto vType = inType.dyn_cast()) - newArrayType = VectorType::get(vType.getShape(), newElementType, - vType.getNumScalableDims()); - else - assert(newArrayType && "Unhandled tensor type"); + ShapedType newArrayType = inType.cloneWith(inType.getShape(), newElementType); size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT)); diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir --- a/mlir/test/IR/invalid-builtin-attributes.mlir +++ b/mlir/test/IR/invalid-builtin-attributes.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics func.func @elementsattr_non_tensor_type() -> () { - "foo"(){bar = dense<[4]> : i32} : () -> () // expected-error {{elements literal must be a ranked tensor or vector type}} + "foo"(){bar = dense<[4]> : i32} : () -> () // expected-error {{elements literal must be a shaped type}} } // -----