diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -675,6 +675,21 @@ public: using ElementsAttr::ElementsAttr; + /// Type trait used to check if the given type T is a potentially valid C++ + /// floating point type that can be used to access the underlying element + /// types of a DenseElementsAttr. + // TODO: Use std::disjunction when C++17 is supported. + template struct is_valid_cpp_fp_type { + /// The type is a valid floating point type if it is a builtin floating + /// point type, or is a potentially user defined floating point type. The + /// latter allows for supporting users that have custom types defined for + /// bfloat16/half/etc. + static inline constexpr bool value = + llvm::is_one_of::value || + (std::numeric_limits::is_specialized && + !std::numeric_limits::is_integer); + }; + /// Method for support type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr); @@ -690,7 +705,7 @@ /// static shape. template ::is_integer || - llvm::is_one_of::value>::type> + is_valid_cpp_fp_type::value>::type> static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { const char *data = reinterpret_cast(values.data()); return getRawIntOrFloat( @@ -701,7 +716,7 @@ /// Constructs a dense integer elements attribute from a single element. template ::is_integer || - llvm::is_one_of::value || + is_valid_cpp_fp_type::value || detail::is_complex_t::value>::type> static DenseElementsAttr get(const ShapedType &type, T value) { return get(type, llvm::makeArrayRef(value)); @@ -714,7 +729,7 @@ typename = typename std::enable_if< detail::is_complex_t::value && (std::numeric_limits::is_integer || - llvm::is_one_of::value)>::type> + is_valid_cpp_fp_type::value)>::type> static DenseElementsAttr get(const ShapedType &type, ArrayRef values) { const char *data = reinterpret_cast(values.data()); return getRawComplex(type, ArrayRef(data, values.size() * sizeof(T)), @@ -944,7 +959,7 @@ template ::value && std::numeric_limits::is_integer) || - llvm::is_one_of::value>::type> + is_valid_cpp_fp_type::value>::type> llvm::iterator_range> getValues() const { assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); @@ -959,7 +974,7 @@ typename = typename std::enable_if< detail::is_complex_t::value && (std::numeric_limits::is_integer || - llvm::is_one_of::value)>::type> + is_valid_cpp_fp_type::value)>::type> llvm::iterator_range> getValues() const { assert(isValidComplex(sizeof(T), std::numeric_limits::is_integer, std::numeric_limits::is_signed)); @@ -1411,7 +1426,8 @@ template typename std::enable_if< std::numeric_limits::is_integer || - llvm::is_one_of::value || + DenseElementsAttr::is_valid_cpp_fp_type::value || + std::is_same::value || (detail::is_complex_t::value && !llvm::is_one_of, std::complex>::value),