diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -13,6 +13,7 @@
 #ifdef LLVM_HAVE_TF_API
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/JSON.h"
 
 #include <memory>
 #include <vector>
 
@@ -58,6 +59,13 @@
   int typeIndex() const { return TypeIndex; }
   const std::vector<int64_t> &shape() const { return Shape; }
 
+  bool operator==(const TensorSpec &Other) const {
+    return Name == Other.Name && Port == Other.Port &&
+           TypeIndex == Other.TypeIndex && Shape == Other.Shape;
+  }
+
+  bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
+
 private:
   TensorSpec(const std::string &Name, int Port, int TypeIndex,
              const std::vector<int64_t> &Shape)
@@ -73,6 +81,9 @@
   std::vector<int64_t> Shape;
 };
 
+Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
+                                           const json::Value &Value);
+
 class TFModelEvaluator final {
 public:
   /// The result of a model evaluation. Handles the lifetime of the output
@@ -124,17 +135,28 @@
   std::unique_ptr<TFModelEvaluatorImpl> Impl;
 };
 
-template <> int TensorSpec::getDataType<float>();
-template <> int TensorSpec::getDataType<double>();
-template <> int TensorSpec::getDataType<int8_t>();
-template <> int TensorSpec::getDataType<uint8_t>();
-template <> int TensorSpec::getDataType<int16_t>();
-template <> int TensorSpec::getDataType<uint16_t>();
-template <> int TensorSpec::getDataType<int32_t>();
-template <> int TensorSpec::getDataType<uint32_t>();
-template <> int TensorSpec::getDataType<int64_t>();
-template <> int TensorSpec::getDataType<uint64_t>();
-
+/// List of supported types, as a triple:
+/// C++ type
+/// short name (for strings, for instance)
+/// capitalized short name (for enums, for instance)
+#define TFUTILS_SUPPORTED_TYPES(M)                                             \
+  M(float, float, FLOAT)                                                       \
+  M(double, double, DOUBLE)                                                    \
+  M(int8_t, int8, INT8)                                                        \
+  M(uint8_t, uint8, UINT8)                                                     \
+  M(int16_t, int16, INT16)                                                     \
+  M(uint16_t, uint16, UINT16)                                                  \
+  M(int32_t, int32, INT32)                                                     \
+  M(uint32_t, uint32, UINT32)                                                  \
+  M(int64_t, int64, INT64)                                                     \
+  M(uint64_t, uint64, UINT64)
+
+#define TFUTILS_GETDATATYPE_DEF(T, S, C)                                       \
+  template <> int TensorSpec::getDataType<T>();
+
+TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
+
+#undef TFUTILS_GETDATATYPE_DEF
 } // namespace llvm
 
 #endif // LLVM_HAVE_TF_API
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -13,9 +13,10 @@
 #include "llvm/Config/config.h"
 #if defined(LLVM_HAVE_TF_API)
 
-#include "llvm/Analysis/Utils/TFUtils.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/Utils/TFUtils.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/JSON.h"
 #include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -83,6 +84,41 @@
   std::vector<TF_Tensor *> Output;
 };
 
+Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
+                                           const json::Value &Value) {
+  auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
+    std::string S;
+    llvm::raw_string_ostream OS(S);
+    OS << Value;
+    Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
+    return None;
+  };
+  json::ObjectMapper Mapper(Value);
+  if (!Mapper)
+    return EmitError("Value is not a dict");
+
+  std::string TensorName;
+  int TensorPort = -1;
+  std::string TensorType;
+  std::vector<int64_t> TensorShape;
+
+  if (!Mapper.map<std::string>("name", TensorName))
+    return EmitError("'name' property not present or not a string");
+  if (!Mapper.map<std::string>("type", TensorType))
+    return EmitError("'type' property not present or not a string");
+  if (!Mapper.map<int>("port", TensorPort))
+    return EmitError("'port' property not present or not an int");
+  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
+    return EmitError("'shape' property not present or not an int array");
+
+#define PARSE_TYPE(T, S, E)                                                    \
+  if (TensorType == #S)                                                        \
+    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
+  TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
+#undef PARSE_TYPE
+  return None;
+}
+
 class TFModelEvaluatorImpl {
 public:
   TFModelEvaluatorImpl(StringRef SavedModelPath,
@@ -249,25 +285,12 @@
   return TF_TensorData(Impl->getOutput()[Index]);
 }
 
-template <> int TensorSpec::getDataType<float>() { return TF_FLOAT; }
-
-template <> int TensorSpec::getDataType<double>() { return TF_DOUBLE; }
-
-template <> int TensorSpec::getDataType<int8_t>() { return TF_INT8; }
-
-template <> int TensorSpec::getDataType<uint8_t>() { return TF_UINT8; }
-
-template <> int TensorSpec::getDataType<int16_t>() { return TF_INT16; }
-
-template <> int TensorSpec::getDataType<uint16_t>() { return TF_UINT16; }
-
-template <> int TensorSpec::getDataType<int32_t>() { return TF_INT32; }
-
-template <> int TensorSpec::getDataType<uint32_t>() { return TF_UINT32; }
+#define TFUTILS_GETDATATYPE_IMPL(T, S, E)                                      \
+  template <> int TensorSpec::getDataType<T>() { return TF_##E; }
 
-template <> int TensorSpec::getDataType<int64_t>() { return TF_INT64; }
+TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
 
-template <> int TensorSpec::getDataType<uint64_t>() { return TF_UINT64; }
+#undef TFUTILS_GETDATATYPE_IMPL
 
 TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
 TFModelEvaluator::~TFModelEvaluator() {}
diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -94,3 +94,32 @@
   EXPECT_FALSE(ER.hasValue());
   EXPECT_FALSE(Evaluator.isValid());
 }
+
+TEST(TFUtilsTest, JSONParsing) {
+  auto Value = json::parse(
+      R"({"name": "tensor_name",
+        "port": 2,
+        "type": "int32",
+        "shape":[1,4]
+        })");
+  EXPECT_TRUE(!!Value);
+  LLVMContext Ctx;
+  Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value);
+  EXPECT_TRUE(Spec.hasValue());
+  EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2));
+}
+
+TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
+  auto Value = json::parse(
+      R"(
+        {"name": "tensor_name",
+        "port": 2,
+        "type": "no such type",
+        "shape":[1,4]
+        }
+      )");
+  EXPECT_TRUE(!!Value);
+  LLVMContext Ctx;
+  auto Spec = getTensorSpecFromJSON(Ctx, *Value);
+  EXPECT_FALSE(Spec.hasValue());
+}