diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -560,6 +560,11 @@
   /// same total number of elements as well as element type.
   DenseElementsAttr reshape(ShapedType newType);
 
+  /// Return a new DenseElementsAttr that has the same data as the current
+  /// attribute, but has its elements bitcast to 'newElType'. The new element
+  /// type must have the same bitwidth as the current element type.
+  DenseElementsAttr bitcast(Type newElType);
+
   /// Generates a new DenseElementsAttr by mapping each int value to a new
   /// underlying APInt. The new values can represent either an integer or float.
   /// This underlying type must be an DenseIntElementsAttr.
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -512,16 +512,9 @@
 
   Type resType = getResult().getType();
 
-  if (auto denseAttr = operand.dyn_cast<DenseFPElementsAttr>()) {
-    Type elType = getElementTypeOrSelf(resType);
-    return denseAttr.mapValues(
-        elType, [](const APFloat &f) { return f.bitcastToAPInt(); });
-  }
-  if (auto denseAttr = operand.dyn_cast<DenseIntElementsAttr>()) {
-    Type elType = getElementTypeOrSelf(resType);
-    // mapValues does its own bitcast to the target type.
-    return denseAttr.mapValues(elType, [](const APInt &i) { return i; });
-  }
+  // Dense constants fold by reinterpreting their raw data as the result type.
+  if (auto denseAttr = operand.dyn_cast<DenseElementsAttr>())
+    return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
 
   APInt bits;
   if (auto floatAttr = operand.dyn_cast<FloatAttr>())
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -1032,6 +1032,26 @@
   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
 }
 
+/// Return a new DenseElementsAttr that has the same data as the current
+/// attribute, but has its elements bitcast to 'newElType'. The shape is
+/// unchanged; the new element type must have the same bitwidth as the current
+/// element type.
+DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
+  ShapedType curType = getType();
+  Type curElType = curType.getElementType();
+  if (curElType == newElType)
+    return *this;
+
+  assert(curElType.isIntOrFloat() && newElType.isIntOrFloat() &&
+         "expected integer or float element types");
+  assert(newElType.getIntOrFloatBitWidth() ==
+             curElType.getIntOrFloatBitWidth() &&
+         "expected element types with the same bitwidth");
+  // Reuse the raw data as-is; only the element type of the shape changes.
+  return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
+                                          getRawData(), isSplat());
+}
+
 DenseElementsAttr
 DenseElementsAttr::mapValues(Type newElementType,
                              function_ref<APInt(const APInt &)> mapping) const {
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -445,10 +445,14 @@
     return emitError() << "vector types must have at least one dimension";
 
   if (!isValidElementType(elementType))
-    return emitError() << "vector elements must be int/index/float type";
+    return emitError()
+           << "vector elements must be int/index/float type but got "
+           << elementType;
 
   if (any_of(shape, [](int64_t i) { return i <= 0; }))
-    return emitError() << "vector types must have positive constant sizes";
+    return emitError()
+           << "vector types must have positive constant sizes but got "
+           << shape;
 
   return success();
 }