diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -781,18 +781,14 @@ void printWithoutBraces(raw_ostream &os) const; /// Parse the short form `[42, 100, -1]` without any type prefix. - static Attribute parse(AsmParser &parser, Type odsType); + static Attribute parse(AsmParser &parser, Type type); /// Parse the short form `42, 100, -1` without any type prefix or braces. - static Attribute parseWithoutBraces(AsmParser &parser, Type odsType); + static Attribute parseWithoutBraces(AsmParser &parser, Type type); /// Support for isa<>/cast<>. static bool classof(Attribute attr); }; -template <> -void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const; -template <> -void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const; extern template class DenseArrayAttr; extern template class DenseArrayAttr; diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -188,12 +188,8 @@ ``` }]; let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, - "DenseArrayBaseAttr::EltType":$elementType, Builtin_DenseArrayRawDataParameter:$rawData); let extraClassDeclaration = [{ - // All possible supported element type. - enum class EltType { I1, I8, I16, I32, I64, F32, F64 }; - /// Allow implicit conversion to ElementsAttr. operator ElementsAttr() const { return *this ? 
cast() : nullptr; diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/Types.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Endian.h" using namespace mlir; @@ -718,30 +719,10 @@ } void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const { - switch (getElementType()) { - case DenseArrayBaseAttr::EltType::I1: - this->cast().printWithoutBraces(os); - return; - case DenseArrayBaseAttr::EltType::I8: - this->cast().printWithoutBraces(os); - return; - case DenseArrayBaseAttr::EltType::I16: - this->cast().printWithoutBraces(os); - return; - case DenseArrayBaseAttr::EltType::I32: - this->cast().printWithoutBraces(os); - return; - case DenseArrayBaseAttr::EltType::I64: - this->cast().printWithoutBraces(os); - return; - case DenseArrayBaseAttr::EltType::F32: - this->cast().printWithoutBraces(os); - return; - case DenseArrayBaseAttr::EltType::F64: - this->cast().printWithoutBraces(os); - return; - } - llvm_unreachable(""); + llvm::TypeSwitch(*this) + .Case([&](auto attr) { attr.printWithoutBraces(os); }).Default([](auto) { llvm_unreachable("unsupported DenseArrayBaseAttr element type"); }); } void DenseArrayBaseAttr::print(raw_ostream &os) const { @@ -750,6 +731,89 @@ os << "]"; } +namespace { +/// Instantiations of this class provide utilities for interacting with native +/// data types in the context of DenseArrayAttr.
+template +struct DenseArrayAttrIntUtil { + static bool checkElementType(Type eltType) { + auto type = eltType.dyn_cast(); + if (!type || type.getWidth() != width) + return false; + return type.getSignedness() == signedness; + } + + static Type getElementType(MLIRContext *ctx) { + return IntegerType::get(ctx, width, signedness); + } + + template + static void printElement(raw_ostream &os, T value) { + os << value; + } + + template + static ParseResult parseElement(AsmParser &parser, T &value) { + return parser.parseInteger(value); + } +}; +template +struct DenseArrayAttrUtil; + +/// Specialization for boolean elements to print 'true' and 'false' literals for +/// elements. Parsing is inherited from DenseArrayAttrIntUtil (parseInteger). +template <> +struct DenseArrayAttrUtil : public DenseArrayAttrIntUtil<1> { + static void printElement(raw_ostream &os, bool value) { + os << (value ? "true" : "false"); + } +}; + +/// Specialization for 8-bit integers to ensure values are printed as integers +/// and not characters. +template <> +struct DenseArrayAttrUtil : public DenseArrayAttrIntUtil<8> { + static void printElement(raw_ostream &os, int8_t value) { + os << static_cast(value); + } +}; +template <> +struct DenseArrayAttrUtil : public DenseArrayAttrIntUtil<16> {}; +template <> +struct DenseArrayAttrUtil : public DenseArrayAttrIntUtil<32> {}; +template <> +struct DenseArrayAttrUtil : public DenseArrayAttrIntUtil<64> {}; + +/// Specialization for 32-bit floats. +template <> +struct DenseArrayAttrUtil { + static bool checkElementType(Type eltType) { return eltType.isF32(); } + static Type getElementType(MLIRContext *ctx) { return Float32Type::get(ctx); } + static void printElement(raw_ostream &os, float value) { os << value; } + + /// Parse a double and cast it to a float. NOTE(review): the narrowing is unchecked; out-of-range literals are not diagnosed. + static ParseResult parseElement(AsmParser &parser, float &value) { + double doubleVal; + if (parser.parseFloat(doubleVal)) + return failure(); + value = doubleVal; + return success(); + } +}; + +/// Specialization for 64-bit floats.
+template <> +struct DenseArrayAttrUtil { + static bool checkElementType(Type eltType) { return eltType.isF64(); } + static Type getElementType(MLIRContext *ctx) { return Float64Type::get(ctx); } + static void printElement(raw_ostream &os, double value) { os << value; } + static ParseResult parseElement(AsmParser &parser, double &value) { + return parser.parseFloat(value); + } +}; +} // namespace + template void DenseArrayAttr::print(AsmPrinter &printer) const { print(printer.getStream()); @@ -757,20 +821,9 @@ template void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const { - llvm::interleaveComma(asArrayRef(), os); -} - -/// Specialization for bool to print `true` or `false`. -template <> -void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const { - llvm::interleaveComma(asArrayRef(), os, - [&](bool v) { os << (v ? "true" : "false"); }); -} - -/// Specialization for int8_t for forcing printing as number instead of chars. -template <> -void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const { - llvm::interleaveComma(asArrayRef(), os, [&](int64_t v) { os << v; }); + llvm::interleaveComma(asArrayRef(), os, [&](T value) { + DenseArrayAttrUtil::printElement(os, value); + }); } template @@ -780,27 +833,6 @@ os << "]"; } -/// Parse a single element: generic template for int types, specialized for -/// floating point and boolean values below.
-template -static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) { - return parser.parseInteger(value); -} - -template <> -ParseResult parseDenseArrayAttrElt(AsmParser &parser, float &value) { - double doubleVal; - if (parser.parseFloat(doubleVal)) - return failure(); - value = doubleVal; - return success(); -} - -template <> -ParseResult parseDenseArrayAttrElt(AsmParser &parser, double &value) { - return parser.parseFloat(value); -} - /// Parse a DenseArrayAttr without the braces: `1, 2, 3` template Attribute DenseArrayAttr::parseWithoutBraces(AsmParser &parser, @@ -808,7 +840,7 @@ SmallVector data; if (failed(parser.parseCommaSeparatedList([&]() { T value; - if (parseDenseArrayAttrElt(parser, value)) + if (DenseArrayAttrUtil::parseElement(parser, value)) return failure(); data.push_back(value); return success(); @@ -840,87 +872,23 @@ raw.size() / sizeof(T)); } -namespace { -/// Mapping from C++ element type to MLIR DenseArrayAttr internals. -template -struct denseArrayAttrEltTypeBuilder; -template <> -struct denseArrayAttrEltTypeBuilder { - constexpr static auto eltType = DenseArrayBaseAttr::EltType::I1; - static ShapedType getShapedType(MLIRContext *context, - ArrayRef shape) { - return RankedTensorType::get(shape, IntegerType::get(context, 1)); - } -}; -template <> -struct denseArrayAttrEltTypeBuilder { - constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8; - static ShapedType getShapedType(MLIRContext *context, - ArrayRef shape) { - return RankedTensorType::get(shape, IntegerType::get(context, 8)); - } -}; -template <> -struct denseArrayAttrEltTypeBuilder { - constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16; - static ShapedType getShapedType(MLIRContext *context, - ArrayRef shape) { - return RankedTensorType::get(shape, IntegerType::get(context, 16)); - } -}; -template <> -struct denseArrayAttrEltTypeBuilder { - constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32; - static ShapedType 
getShapedType(MLIRContext *context, - ArrayRef shape) { - return RankedTensorType::get(shape, IntegerType::get(context, 32)); - } -}; -template <> -struct denseArrayAttrEltTypeBuilder { - constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64; - static ShapedType getShapedType(MLIRContext *context, - ArrayRef shape) { - return RankedTensorType::get(shape, IntegerType::get(context, 64)); - } -}; -template <> -struct denseArrayAttrEltTypeBuilder { - constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32; - static ShapedType getShapedType(MLIRContext *context, - ArrayRef shape) { - return RankedTensorType::get(shape, Float32Type::get(context)); - } -}; -template <> -struct denseArrayAttrEltTypeBuilder { - constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64; - static ShapedType getShapedType(MLIRContext *context, - ArrayRef shape) { - return RankedTensorType::get(shape, Float64Type::get(context)); - } -}; -} // namespace - /// Builds a DenseArrayAttr from an ArrayRef.
template DenseArrayAttr DenseArrayAttr::get(MLIRContext *context, ArrayRef content) { - auto size = static_cast(content.size()); - auto shapedType = - denseArrayAttrEltTypeBuilder::getShapedType(context, size); - auto eltType = denseArrayAttrEltTypeBuilder::eltType; + auto shapedType = RankedTensorType::get( + content.size(), DenseArrayAttrUtil::getElementType(context)); auto rawArray = ArrayRef(reinterpret_cast(content.data()), content.size() * sizeof(T)); - return Base::get(context, shapedType, eltType, rawArray) + return Base::get(context, shapedType, rawArray) .template cast>(); } template bool DenseArrayAttr::classof(Attribute attr) { - return attr.isa() && - attr.cast().getElementType() == - denseArrayAttrEltTypeBuilder::eltType; + if (auto denseArray = attr.dyn_cast()) + return DenseArrayAttrUtil::checkElementType(denseArray.getElementType()); + return false; } namespace mlir { diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp --- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp +++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp @@ -9,6 +9,7 @@ #include "TestAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" using namespace mlir; @@ -42,29 +43,28 @@ continue; if (auto concreteAttr = attr.getValue().dyn_cast()) { - switch (concreteAttr.getElementType()) { - case DenseArrayBaseAttr::EltType::I1: - testElementsAttrIteration(op, elementsAttr, "bool"); - break; - case DenseArrayBaseAttr::EltType::I8: - testElementsAttrIteration(op, elementsAttr, "int8_t"); - break; - case DenseArrayBaseAttr::EltType::I16: - testElementsAttrIteration(op, elementsAttr, "int16_t"); - break; - case DenseArrayBaseAttr::EltType::I32: - testElementsAttrIteration(op, elementsAttr, "int32_t"); - break; - case DenseArrayBaseAttr::EltType::I64: - testElementsAttrIteration(op, elementsAttr, "int64_t"); - break; - 
case DenseArrayBaseAttr::EltType::F32: - testElementsAttrIteration(op, elementsAttr, "float"); - break; - case DenseArrayBaseAttr::EltType::F64: - testElementsAttrIteration(op, elementsAttr, "double"); - break; - } + llvm::TypeSwitch(concreteAttr) + .Case([&](DenseBoolArrayAttr attr) { + testElementsAttrIteration(op, attr, "bool"); + }) + .Case([&](DenseI8ArrayAttr attr) { + testElementsAttrIteration(op, attr, "int8_t"); + }) + .Case([&](DenseI16ArrayAttr attr) { + testElementsAttrIteration(op, attr, "int16_t"); + }) + .Case([&](DenseI32ArrayAttr attr) { + testElementsAttrIteration(op, attr, "int32_t"); + }) + .Case([&](DenseI64ArrayAttr attr) { + testElementsAttrIteration(op, attr, "int64_t"); + }) + .Case([&](DenseF32ArrayAttr attr) { + testElementsAttrIteration(op, attr, "float"); + }) + .Case([&](DenseF64ArrayAttr attr) { + testElementsAttrIteration(op, attr, "double"); + }); continue; } testElementsAttrIteration(op, elementsAttr, "int64_t");