diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -46,23 +46,47 @@
 ///
 /// TensorSpec is used to set up a TFModelEvaluator by describing the expected
 /// inputs and outputs.
+
+/// Known tensor types. Each entry pairs a C++ element type with the name of
+/// the TensorType enumerator that identifies it: M(C++ type, enumerator name).
+#define SUPPORTED_TENSOR_TYPES(M)                                              \
+  M(float, Float)                                                              \
+  M(double, Double)                                                            \
+  M(int8_t, Int8)                                                              \
+  M(uint8_t, UInt8)                                                            \
+  M(int16_t, Int16)                                                            \
+  M(uint16_t, UInt16)                                                          \
+  M(int32_t, Int32)                                                            \
+  M(uint32_t, UInt32)                                                          \
+  M(int64_t, Int64)                                                            \
+  M(uint64_t, UInt64)
+
+/// Element type of a tensor. Invalid is what TensorSpec's Type member defaults
+/// to; it names no element type.
+enum class TensorType {
+  Invalid,
+#define TENSOR_TYPE_ENUM_MEMBER(CppType, Name) Name,
+  SUPPORTED_TENSOR_TYPES(TENSOR_TYPE_ENUM_MEMBER)
+#undef TENSOR_TYPE_ENUM_MEMBER
+};
+
 class TensorSpec final {
 public:
   template <typename T>
   static TensorSpec createSpec(const std::string &Name,
                                const std::vector<int64_t> &Shape,
                                int Port = 0) {
-    return TensorSpec(Name, Port, getDataType<T>(), Shape);
+    return TensorSpec(Name, Port, getDataType<T>(), sizeof(T), Shape);
   }
 
   const std::string &name() const { return Name; }
   int port() const { return Port; }
-  int typeIndex() const { return TypeIndex; }
+  TensorType type() const { return Type; }
   const std::vector<int64_t> &shape() const { return Shape; }
 
   bool operator==(const TensorSpec &Other) const {
-    return Name == Other.Name && Port == Other.Port &&
-           TypeIndex == Other.TypeIndex && Shape == Other.Shape;
+    return Name == Other.Name && Port == Other.Port && Type == Other.Type &&
+           Shape == Other.Shape;
   }
 
   bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
@@ -70,25 +94,26 @@
   /// Get the number of elements in a tensor with this shape.
   size_t getElementCount() const { return ElementCount; }
   /// Get the size, in bytes, of one element.
-  size_t getElementByteSize() const;
+  size_t getElementByteSize() const { return ElementSize; }
 
   template <typename T> bool isElementType() const {
-    return getDataType<T>() == TypeIndex;
+    return getDataType<T>() == Type;
   }
 
 private:
-  TensorSpec(const std::string &Name, int Port, int TypeIndex,
-             const std::vector<int64_t> &Shape);
+  TensorSpec(const std::string &Name, int Port, TensorType Type,
+             size_t ElementSize, const std::vector<int64_t> &Shape);
 
-  template <typename T> static int getDataType() {
-    llvm_unreachable("Undefined tensor type");
-  }
+  /// Declared only; explicitly specialized for each SUPPORTED_TENSOR_TYPES
+  /// entry, so any other T is rejected at link time.
+  template <typename T> static TensorType getDataType();
 
   std::string Name;
   int Port = 0;
-  int TypeIndex = 0;
+  TensorType Type = TensorType::Invalid;
   std::vector<int64_t> Shape;
   size_t ElementCount = 0;
+  size_t ElementSize = 0;
 };
 
 /// Construct a TensorSpec from a JSON dictionary of the form:
@@ -262,25 +287,10 @@
   std::unique_ptr<TFModelEvaluatorImpl> Impl;
 };
 
-/// List of supported types, as a pair:
-/// - C++ type
-/// - enum name (implementation-specific)
-#define TFUTILS_SUPPORTED_TYPES(M)                                             \
-  M(float, TF_FLOAT)                                                           \
-  M(double, TF_DOUBLE)                                                         \
-  M(int8_t, TF_INT8)                                                           \
-  M(uint8_t, TF_UINT8)                                                         \
-  M(int16_t, TF_INT16)                                                         \
-  M(uint16_t, TF_UINT16)                                                       \
-  M(int32_t, TF_INT32)                                                         \
-  M(uint32_t, TF_UINT32)                                                       \
-  M(int64_t, TF_INT64)                                                         \
-  M(uint64_t, TF_UINT64)
-
-#define TFUTILS_GETDATATYPE_DEF(T, E)                                          \
-  template <> int TensorSpec::getDataType<T>();
-
-TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
+/// Declare the getDataType specialization of every supported tensor type.
+#define TFUTILS_GETDATATYPE_DEF(T, Name)                                       \
+  template <> TensorType TensorSpec::getDataType<T>();
+SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF)
 
 #undef TFUTILS_GETDATATYPE_DEF
 } // namespace llvm
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -9,7 +9,8 @@
 // This file implements utilities for interfacing with tensorflow C APIs.
 //
 //===----------------------------------------------------------------------===//
 #include "llvm/Config/config.h"
+#include "llvm/Support/ErrorHandling.h"
 #if defined(LLVM_HAVE_TF_API)
 
 #include "llvm/ADT/Twine.h"
@@ -82,6 +83,38 @@
     *OutStr = SE.SerializeAsString();
   }
 }
+
+/// Map a TensorType to the matching TF_DataType enumerator, returned as an
+/// int. TensorType::Invalid has no TF counterpart and must not be passed.
+int getTFTypeIndex(TensorType TType) {
+  switch (TType) {
+  case TensorType::Double:
+    return TF_DOUBLE;
+  case TensorType::Float:
+    return TF_FLOAT;
+  case TensorType::Int8:
+    return TF_INT8;
+  case TensorType::UInt8:
+    return TF_UINT8;
+  case TensorType::Int16:
+    return TF_INT16;
+  case TensorType::UInt16:
+    return TF_UINT16;
+  case TensorType::Int32:
+    return TF_INT32;
+  case TensorType::UInt32:
+    return TF_UINT32;
+  case TensorType::Int64:
+    return TF_INT64;
+  case TensorType::UInt64:
+    return TF_UINT64;
+  case TensorType::Invalid:
+    break;
+  }
+  // Kept after the switch so every path out of the function terminates, even
+  // for compilers that do not treat the switch as exhaustive.
+  llvm_unreachable("Unknown tensor type");
+}
 } // namespace
 
 namespace llvm {
@@ -105,15 +138,14 @@
   std::vector<TF_Tensor *> Output;
 };
 
-size_t TensorSpec::getElementByteSize() const {
-  return TF_DataTypeSize(static_cast<TF_DataType>(TypeIndex));
-}
-
-TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
-                       const std::vector<int64_t> &Shape)
-    : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
-      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
-                                   std::multiplies<int64_t>())) {}
+TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
+                       size_t ElementSize, const std::vector<int64_t> &Shape)
+    : Name(Name), Port(Port), Type(Type), Shape(Shape),
+      // Seed with int64_t: std::accumulate accumulates in the type of its
+      // initial value, and an int seed would truncate large element counts.
+      ElementCount(std::accumulate(Shape.begin(), Shape.end(), int64_t{1},
+                                   std::multiplies<int64_t>())),
+      ElementSize(ElementSize) {}
 
 Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                            const json::Value &Value) {
@@ -147,7 +179,7 @@
 #define PARSE_TYPE(T, E)                                                       \
   if (TensorType == #T)                                                        \
     return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
-  TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
+  SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
 #undef PARSE_TYPE
   return None;
 }
@@ -390,7 +422,7 @@
         InputSpec.port()};
     if (!checkReportAndInvalidate(InputFeed[I], InputSpec))
       return;
-    initInput(I, static_cast<TF_DataType>(InputSpec.typeIndex()),
+    initInput(I, static_cast<TF_DataType>(getTFTypeIndex(InputSpec.type())),
               InputSpec.shape());
   }
   for (size_t I = 0; I < OutputSpecsSize; ++I) {
@@ -496,8 +528,8 @@
 }
 
 #define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
-  template <> int TensorSpec::getDataType<T>() { return E; }
+  template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
 
-TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
+SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)
 
 #undef TFUTILS_GETDATATYPE_IMPL