diff --git a/llvm/include/llvm/Analysis/TensorSpec.h b/llvm/include/llvm/Analysis/TensorSpec.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Analysis/TensorSpec.h
@@ -0,0 +1,130 @@
+//===- TensorSpec.h - type descriptor for a tensor --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#ifndef LLVM_ANALYSIS_TENSORSPEC_H
+#define LLVM_ANALYSIS_TENSORSPEC_H
+
+#include "llvm/Config/llvm-config.h"
+
+#include "llvm/ADT/StringMap.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/Support/JSON.h"
+
+#include <memory>
+#include <vector>
+
+namespace llvm {
+/// TensorSpec encapsulates the specification of a tensor: its dimensions, or
+/// "shape" (row-major), its type (see TensorSpec::getDataType specializations
+/// for supported types), its name and port (see "TensorFlow: Large-Scale
+/// Machine Learning on Heterogeneous Distributed Systems", section 4.2, para 2:
+/// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf)
+///
+/// Known tensor types. The left part is the C type, the right is a name we
+/// can use to identify the type (to implement TensorSpec equality checks), and
+/// to use, if needed, when mapping to an underlying evaluator's type system.
+/// The main requirement is that the C type we use has the same size and
+/// encoding (e.g. endian-ness) as the one used by the evaluator.
+#define SUPPORTED_TENSOR_TYPES(M)                                              \
+  M(float, Float)                                                              \
+  M(double, Double)                                                            \
+  M(int8_t, Int8)                                                              \
+  M(uint8_t, UInt8)                                                            \
+  M(int16_t, Int16)                                                            \
+  M(uint16_t, UInt16)                                                          \
+  M(int32_t, Int32)                                                            \
+  M(uint32_t, UInt32)                                                          \
+  M(int64_t, Int64)                                                            \
+  M(uint64_t, UInt64)
+
+enum class TensorType {
+  Invalid,
+#define _TENSOR_TYPE_ENUM_MEMBERS(_, Name) Name,
+  SUPPORTED_TENSOR_TYPES(_TENSOR_TYPE_ENUM_MEMBERS)
+#undef _TENSOR_TYPE_ENUM_MEMBERS
+};
+
+class TensorSpec final {
+public:
+  template <typename T>
+  static TensorSpec createSpec(const std::string &Name,
+                               const std::vector<int64_t> &Shape,
+                               int Port = 0) {
+    return TensorSpec(Name, Port, getDataType<T>(), sizeof(T), Shape);
+  }
+
+  const std::string &name() const { return Name; }
+  int port() const { return Port; }
+  TensorType type() const { return Type; }
+  const std::vector<int64_t> &shape() const { return Shape; }
+
+  bool operator==(const TensorSpec &Other) const {
+    return Name == Other.Name && Port == Other.Port && Type == Other.Type &&
+           Shape == Other.Shape;
+  }
+
+  bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
+
+  /// Get the number of elements in a tensor with this shape.
+  size_t getElementCount() const { return ElementCount; }
+  /// Get the size, in bytes, of one element.
+  size_t getElementByteSize() const { return ElementSize; }
+
+  template <typename T> bool isElementType() const {
+    return getDataType<T>() == Type;
+  }
+
+private:
+  TensorSpec(const std::string &Name, int Port, TensorType Type,
+             size_t ElementSize, const std::vector<int64_t> &Shape);
+
+  template <typename T> static TensorType getDataType();
+
+  std::string Name;
+  int Port = 0;
+  TensorType Type = TensorType::Invalid;
+  std::vector<int64_t> Shape;
+  size_t ElementCount = 0;
+  size_t ElementSize = 0;
+};
+
+/// Construct a TensorSpec from a JSON dictionary of the form:
+/// { "name": <string>,
+///   "port": <int>,
+///   "type": <string>,
+///   "shape": <array of ints> }
+/// For the "type" field, see the C++ primitive types used in
+/// SUPPORTED_TENSOR_TYPES.
+Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
+                                           const json::Value &Value);
+
+struct LoggedFeatureSpec {
+  TensorSpec Spec;
+  Optional<std::string> LoggingName;
+  const std::string &getLoggingName() const {
+    return LoggingName ? *LoggingName : Spec.name();
+  }
+};
+
+/// Load the output specs. If SpecFileOverride is not empty, that path is used.
+/// Otherwise, the file is assumed to be called 'output_spec.json' and be found
+/// under ModelPath (the model directory).
+/// The first output tensor name must match ExpectedDecisionName.
+/// In case of error, the return is None and the error is logged.
+Optional<std::vector<LoggedFeatureSpec>>
+loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
+                StringRef ModelPath, StringRef SpecFileOverride = StringRef());
+
+#define TFUTILS_GETDATATYPE_DEF(T, Name)                                       \
+  template <> TensorType TensorSpec::getDataType<T>();
+SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF)
+
+#undef TFUTILS_GETDATATYPE_DEF
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_TENSORSPEC_H
diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -13,6 +13,7 @@
 #ifdef LLVM_HAVE_TF_API
 #include "llvm/ADT/StringMap.h"
+#include "llvm/Analysis/TensorSpec.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/Support/JSON.h"
 
@@ -38,110 +39,6 @@
 class TFModelEvaluatorImpl;
 class EvaluationResultImpl;
 
-/// TensorSpec encapsulates the specification of a tensor: its dimensions, or
-/// "shape" (row-major), its type (see TensorSpec::getDataType specializations
-/// for supported types), its name and port (see "TensorFlow: Large-Scale
-/// Machine Learning on Heterogeneous Distributed Systems", section 4.2, para 2:
-/// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf)
-///
-/// TensorSpec is used to set up a TFModelEvaluator by describing the expected
-/// inputs and outputs.
-
-/// Known tensor types. The left part is the C type, the right is a name we
-/// can use to identify the type (to implement TensorSpec equality checks), and
-/// to use, if needed, when mapping to an underlying evaluator's type system.
-/// The main requirement is that the C type we use has the same size and
-/// encoding (e.g. endian-ness) as the one used by the evaluator.
-#define SUPPORTED_TENSOR_TYPES(M)                                              \
-  M(float, Float)                                                              \
-  M(double, Double)                                                            \
-  M(int8_t, Int8)                                                              \
-  M(uint8_t, UInt8)                                                            \
-  M(int16_t, Int16)                                                            \
-  M(uint16_t, UInt16)                                                          \
-  M(int32_t, Int32)                                                            \
-  M(uint32_t, UInt32)                                                          \
-  M(int64_t, Int64)                                                            \
-  M(uint64_t, UInt64)
-
-enum class TensorType {
-  Invalid,
-#define _TENSOR_TYPE_ENUM_MEMBERS(_, Name) Name,
-  SUPPORTED_TENSOR_TYPES(_TENSOR_TYPE_ENUM_MEMBERS)
-#undef _TENSOR_TYPE_ENUM_MEMBERS
-};
-
-class TensorSpec final {
-public:
-  template <typename T>
-  static TensorSpec createSpec(const std::string &Name,
-                               const std::vector<int64_t> &Shape,
-                               int Port = 0) {
-    return TensorSpec(Name, Port, getDataType<T>(), sizeof(T), Shape);
-  }
-
-  const std::string &name() const { return Name; }
-  int port() const { return Port; }
-  TensorType type() const { return Type; }
-  const std::vector<int64_t> &shape() const { return Shape; }
-
-  bool operator==(const TensorSpec &Other) const {
-    return Name == Other.Name && Port == Other.Port && Type == Other.Type &&
-           Shape == Other.Shape;
-  }
-
-  bool operator!=(const TensorSpec &Other) const { return !(*this == Other); }
-
-  /// Get the number of elements in a tensor with this shape.
-  size_t getElementCount() const { return ElementCount; }
-  /// Get the size, in bytes, of one element.
-  size_t getElementByteSize() const { return ElementSize; }
-
-  template <typename T> bool isElementType() const {
-    return getDataType<T>() == Type;
-  }
-
-private:
-  TensorSpec(const std::string &Name, int Port, TensorType Type,
-             size_t ElementSize, const std::vector<int64_t> &Shape);
-
-  template <typename T> static TensorType getDataType();
-
-  std::string Name;
-  int Port = 0;
-  TensorType Type = TensorType::Invalid;
-  std::vector<int64_t> Shape;
-  size_t ElementCount = 0;
-  size_t ElementSize = 0;
-};
-
-/// Construct a TensorSpec from a JSON dictionary of the form:
-/// { "name": <string>,
-///   "port": <int>,
-///   "type": <string>,
-///   "shape": <array of ints> }
-/// For the "type" field, see the C++ primitive types used in
-/// TFUTILS_SUPPORTED_TYPES.
-Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
-                                           const json::Value &Value);
-
-struct LoggedFeatureSpec {
-  TensorSpec Spec;
-  Optional<std::string> LoggingName;
-  const std::string &getLoggingName() const {
-    return LoggingName ? *LoggingName : Spec.name();
-  }
-};
-
-/// Load the output specs. If SpecFileOverride is not empty, that path is used.
-/// Otherwise, the file is assumed to be called 'output_spec.json' and be found
-/// under ModelPath (the model directory).
-/// The first output tensor name must match ExpectedDecisionName.
-/// In case of error, the return is None and the error is logged.
-Optional<std::vector<LoggedFeatureSpec>>
-loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
-                StringRef ModelPath, StringRef SpecFileOverride = StringRef());
-
 /// Logging utility - given an ordered specification of features, and assuming
 /// a scalar reward, allow logging feature values and rewards, and then print
 /// as tf.train.SequenceExample text protobuf.
@@ -286,11 +183,6 @@
   std::unique_ptr<TFModelEvaluatorImpl> Impl;
 };
 
-#define TFUTILS_GETDATATYPE_DEF(T, Name)                                       \
-  template <> TensorType TensorSpec::getDataType<T>();
-SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_DEF)
-
-#undef TFUTILS_GETDATATYPE_DEF
 } // namespace llvm
 
 #endif // LLVM_HAVE_TF_API
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -131,6 +131,7 @@
   SyncDependenceAnalysis.cpp
   SyntheticCountsUtils.cpp
   TFUtils.cpp
+  TensorSpec.cpp
   TargetLibraryInfo.cpp
   TargetTransformInfo.cpp
   Trace.cpp
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -132,113 +132,6 @@
   std::vector<TF_Tensor *> Output;
 };
 
-TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
-                       size_t ElementSize, const std::vector<int64_t> &Shape)
-    : Name(Name), Port(Port), Type(Type), Shape(Shape),
-      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
-                                   std::multiplies<int64_t>())),
-      ElementSize(ElementSize) {}
-
-Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
-                                           const json::Value &Value) {
-  auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
-    std::string S;
-    llvm::raw_string_ostream OS(S);
-    OS << Value;
-    Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
-    return None;
-  };
-  // FIXME: accept a Path as a parameter, and use it for error reporting.
-  json::Path::Root Root("tensor_spec");
-  json::ObjectMapper Mapper(Value, Root);
-  if (!Mapper)
-    return EmitError("Value is not a dict");
-
-  std::string TensorName;
-  int TensorPort = -1;
-  std::string TensorType;
-  std::vector<int64_t> TensorShape;
-
-  if (!Mapper.map<std::string>("name", TensorName))
-    return EmitError("'name' property not present or not a string");
-  if (!Mapper.map<std::string>("type", TensorType))
-    return EmitError("'type' property not present or not a string");
-  if (!Mapper.map<int>("port", TensorPort))
-    return EmitError("'port' property not present or not an int");
-  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
-    return EmitError("'shape' property not present or not an int array");
-
-#define PARSE_TYPE(T, E)                                                       \
-  if (TensorType == #T)                                                        \
-    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
-  SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
-#undef PARSE_TYPE
-  return None;
-}
-
-Optional<std::vector<LoggedFeatureSpec>>
-loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
-                StringRef ModelPath, StringRef SpecFileOverride) {
-  SmallVector<char, 128> OutputSpecsPath;
-  StringRef FileName = SpecFileOverride;
-  if (FileName.empty()) {
-    llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
-    FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
-  }
-
-  auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
-  if (!BufferOrError) {
-    Ctx.emitError("Error opening output specs file: " + FileName + " : " +
-                  BufferOrError.getError().message());
-    return None;
-  }
-  auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
-  if (!ParsedJSONValues) {
-    Ctx.emitError("Could not parse specs file: " + FileName);
-    return None;
-  }
-  auto ValuesArray = ParsedJSONValues->getAsArray();
-  if (!ValuesArray) {
-    Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
-                  "logging_name:<name>} dictionaries");
-    return None;
-  }
-  std::vector<LoggedFeatureSpec> Ret;
-  for (const auto &Value : *ValuesArray)
-    if (const auto *Obj = Value.getAsObject())
-      if (const auto *SpecPart = Obj->get("tensor_spec"))
-        if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
Obj->getString("logging_name")) { - if (!TensorSpec->isElementType() && - !TensorSpec->isElementType() && - !TensorSpec->isElementType()) { - Ctx.emitError( - "Only int64, int32, and float tensors are supported. " - "Found unsupported type for tensor named " + - TensorSpec->name()); - return None; - } - Ret.push_back({*TensorSpec, LoggingName->str()}); - } - - if (ValuesArray->size() != Ret.size()) { - Ctx.emitError( - "Unable to parse output spec. It should be a json file containing an " - "array of dictionaries. Each dictionary must have a 'tensor_spec' key, " - "with a json object describing a TensorSpec; and a 'logging_name' key, " - "which is a string to use as name when logging this tensor in the " - "training log."); - return None; - } - if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) { - Ctx.emitError("The first output spec must describe the decision tensor, " - "and must have the logging_name " + - StringRef(ExpectedDecisionName)); - return None; - } - return Ret; -} - class TFModelEvaluatorImpl { public: TFModelEvaluatorImpl(StringRef SavedModelPath, @@ -519,13 +412,6 @@ return TF_TensorData(Impl->getOutput()[Index]); } -#define TFUTILS_GETDATATYPE_IMPL(T, E) \ - template <> TensorType TensorSpec::getDataType() { return TensorType::E; } - -SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL) - -#undef TFUTILS_GETDATATYPE_IMPL - TFModelEvaluator::EvaluationResult::~EvaluationResult() {} TFModelEvaluator::~TFModelEvaluator() {} diff --git a/llvm/lib/Analysis/TensorSpec.cpp b/llvm/lib/Analysis/TensorSpec.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Analysis/TensorSpec.cpp @@ -0,0 +1,144 @@ +//===- TensorSpec.cpp - tensor type abstraction ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implementation file for the abstraction of a tensor type, and JSON loading +// utils. +// +//===----------------------------------------------------------------------===// +#include "llvm/Config/config.h" + +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/TensorSpec.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/ManagedStatic.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" +#include +#include + +using namespace llvm; + +namespace llvm { + +#define TFUTILS_GETDATATYPE_IMPL(T, E) \ + template <> TensorType TensorSpec::getDataType() { return TensorType::E; } + +SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL) + +#undef TFUTILS_GETDATATYPE_IMPL + +TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type, + size_t ElementSize, const std::vector &Shape) + : Name(Name), Port(Port), Type(Type), Shape(Shape), + ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1, + std::multiplies())), + ElementSize(ElementSize) {} + +Optional getTensorSpecFromJSON(LLVMContext &Ctx, + const json::Value &Value) { + auto EmitError = [&](const llvm::Twine &Message) -> Optional { + std::string S; + llvm::raw_string_ostream OS(S); + OS << Value; + Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S); + return None; + }; + // FIXME: accept a Path as a parameter, and use it for error reporting. 
+  json::Path::Root Root("tensor_spec");
+  json::ObjectMapper Mapper(Value, Root);
+  if (!Mapper)
+    return EmitError("Value is not a dict");
+
+  std::string TensorName;
+  int TensorPort = -1;
+  std::string TensorType;
+  std::vector<int64_t> TensorShape;
+
+  if (!Mapper.map<std::string>("name", TensorName))
+    return EmitError("'name' property not present or not a string");
+  if (!Mapper.map<std::string>("type", TensorType))
+    return EmitError("'type' property not present or not a string");
+  if (!Mapper.map<int>("port", TensorPort))
+    return EmitError("'port' property not present or not an int");
+  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
+    return EmitError("'shape' property not present or not an int array");
+
+#define PARSE_TYPE(T, E)                                                       \
+  if (TensorType == #T)                                                        \
+    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
+  SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
+#undef PARSE_TYPE
+  return None;
+}
+
+Optional<std::vector<LoggedFeatureSpec>>
+loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
+                StringRef ModelPath, StringRef SpecFileOverride) {
+  SmallVector<char, 128> OutputSpecsPath;
+  StringRef FileName = SpecFileOverride;
+  if (FileName.empty()) {
+    llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
+    FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
+  }
+
+  auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
+  if (!BufferOrError) {
+    Ctx.emitError("Error opening output specs file: " + FileName + " : " +
+                  BufferOrError.getError().message());
+    return None;
+  }
+  auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
+  if (!ParsedJSONValues) {
+    Ctx.emitError("Could not parse specs file: " + FileName);
+    return None;
+  }
+  auto ValuesArray = ParsedJSONValues->getAsArray();
+  if (!ValuesArray) {
+    Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
+                  "logging_name:<name>} dictionaries");
+    return None;
+  }
+  std::vector<LoggedFeatureSpec> Ret;
+  for (const auto &Value : *ValuesArray)
+    if (const auto *Obj = Value.getAsObject())
+      if (const auto *SpecPart = Obj->get("tensor_spec"))
+        if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
+          if (auto LoggingName = Obj->getString("logging_name")) {
+            if (!TensorSpec->isElementType<int64_t>() &&
+                !TensorSpec->isElementType<int32_t>() &&
+                !TensorSpec->isElementType<float>()) {
+              Ctx.emitError(
+                  "Only int64, int32, and float tensors are supported. "
+                  "Found unsupported type for tensor named " +
+                  TensorSpec->name());
+              return None;
+            }
+            Ret.push_back({*TensorSpec, LoggingName->str()});
+          }
+
+  if (ValuesArray->size() != Ret.size()) {
+    Ctx.emitError(
+        "Unable to parse output spec. It should be a json file containing an "
+        "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
+        "with a json object describing a TensorSpec; and a 'logging_name' key, "
+        "which is a string to use as name when logging this tensor in the "
+        "training log.");
+    return None;
+  }
+  if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
+    Ctx.emitError("The first output spec must describe the decision tensor, "
+                  "and must have the logging_name " +
+                  StringRef(ExpectedDecisionName));
+    return None;
+  }
+  return Ret;
+}
+} // namespace llvm
diff --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -45,6 +45,7 @@
   VectorFunctionABITest.cpp
   SparsePropagation.cpp
   TargetLibraryInfoTest.cpp
+  TensorSpecTest.cpp
   TBAATest.cpp
   UnrollAnalyzerTest.cpp
   ValueLatticeTest.cpp
diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -102,50 +102,6 @@
   EXPECT_FALSE(Evaluator.isValid());
 }
 
-TEST(TFUtilsTest, JSONParsing) {
-  auto Value = json::parse(
-      R"({"name": "tensor_name",
-        "port": 2,
-        "type": "int32_t",
-        "shape":[1,4]
-        })");
-  EXPECT_TRUE(!!Value);
-  LLVMContext Ctx;
-  Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value);
-  EXPECT_TRUE(Spec.hasValue());
-  EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2));
-}
-
-TEST(TFUtilsTest, JSONParsingInvalidTensorType) {
-  auto Value = json::parse(
-      R"(
-        {"name": "tensor_name",
-        "port": 2,
-        "type": "no such type",
-        "shape":[1,4]
-        }
-      )");
-  EXPECT_TRUE(!!Value);
-  LLVMContext Ctx;
-  auto Spec = getTensorSpecFromJSON(Ctx, *Value);
-  EXPECT_FALSE(Spec.hasValue());
-}
-
-TEST(TFUtilsTest, TensorSpecSizesAndTypes) {
-  auto Spec1D = TensorSpec::createSpec<int16_t>("Hi1", {1});
-  auto Spec2D = TensorSpec::createSpec<int16_t>("Hi2", {1, 1});
-  auto Spec1DLarge = TensorSpec::createSpec<float>("Hi3", {10});
-  auto Spec3DLarge = TensorSpec::createSpec<float>("Hi3", {2, 4, 10});
-  EXPECT_TRUE(Spec1D.isElementType<int16_t>());
-  EXPECT_FALSE(Spec3DLarge.isElementType<double>());
-  EXPECT_EQ(Spec1D.getElementCount(), 1U);
-  EXPECT_EQ(Spec2D.getElementCount(), 1U);
-  EXPECT_EQ(Spec1DLarge.getElementCount(), 10U);
-  EXPECT_EQ(Spec3DLarge.getElementCount(), 80U);
-  EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
-  EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
-}
-
 #define PROTO_CHECKER(FNAME, TYPE, INDEX, EXP)                                 \
   do {                                                                         \
     const auto &V = Expected.feature_lists()                                   \
diff --git a/llvm/unittests/Analysis/TensorSpecTest.cpp b/llvm/unittests/Analysis/TensorSpecTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Analysis/TensorSpecTest.cpp
@@ -0,0 +1,61 @@
+//===- TensorSpecTest.cpp - test for TensorSpec ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/TensorSpec.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Testing/Support/SupportHelpers.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+extern const char *TestMainArgv0;
+
+TEST(TensorSpecTest, JSONParsing) {
+  auto Value = json::parse(
+      R"({"name": "tensor_name",
+        "port": 2,
+        "type": "int32_t",
+        "shape":[1,4]
+        })");
+  EXPECT_TRUE(!!Value);
+  LLVMContext Ctx;
+  Optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Value);
+  EXPECT_TRUE(Spec.hasValue());
+  EXPECT_EQ(*Spec, TensorSpec::createSpec<int32_t>("tensor_name", {1, 4}, 2));
+}
+
+TEST(TensorSpecTest, JSONParsingInvalidTensorType) {
+  auto Value = json::parse(
+      R"(
+        {"name": "tensor_name",
+        "port": 2,
+        "type": "no such type",
+        "shape":[1,4]
+        }
+      )");
+  EXPECT_TRUE(!!Value);
+  LLVMContext Ctx;
+  auto Spec = getTensorSpecFromJSON(Ctx, *Value);
+  EXPECT_FALSE(Spec.hasValue());
+}
+
+TEST(TensorSpecTest, TensorSpecSizesAndTypes) {
+  auto Spec1D = TensorSpec::createSpec<int16_t>("Hi1", {1});
+  auto Spec2D = TensorSpec::createSpec<int16_t>("Hi2", {1, 1});
+  auto Spec1DLarge = TensorSpec::createSpec<float>("Hi3", {10});
+  auto Spec3DLarge = TensorSpec::createSpec<float>("Hi3", {2, 4, 10});
+  EXPECT_TRUE(Spec1D.isElementType<int16_t>());
+  EXPECT_FALSE(Spec3DLarge.isElementType<double>());
+  EXPECT_EQ(Spec1D.getElementCount(), 1U);
+  EXPECT_EQ(Spec2D.getElementCount(), 1U);
+  EXPECT_EQ(Spec1DLarge.getElementCount(), 10U);
+  EXPECT_EQ(Spec3DLarge.getElementCount(), 80U);
+  EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
+  EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
+}
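
Reviewer note: the snippet below is a minimal standalone sketch of how the relocated API is typically exercised, mirroring the unit tests above. It is not part of the patch; the "feature" tensor name and the main() harness are illustrative assumptions only.

#include "llvm/ADT/Optional.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/JSON.h"
#include <cassert>

using namespace llvm;

int main() {
  // Describe a 2x8 float tensor on the default port (0).
  TensorSpec FeatureSpec = TensorSpec::createSpec<float>("feature", {2, 8});
  assert(FeatureSpec.getElementCount() == 16);               // 2 * 8 elements
  assert(FeatureSpec.getElementByteSize() == sizeof(float)); // 4 bytes each

  // Recover the same spec from its JSON form; TensorSpec::operator== compares
  // name, port, type, and shape, so it is a convenient round-trip check.
  LLVMContext Ctx;
  Expected<json::Value> Value = json::parse(
      R"({"name": "feature", "port": 0, "type": "float", "shape": [2, 8]})");
  if (!Value) {
    consumeError(Value.takeError());
    return 1;
  }
  Optional<TensorSpec> Parsed = getTensorSpecFromJSON(Ctx, *Value);
  assert(Parsed && *Parsed == FeatureSpec && "round-trip preserves equality");
  return 0;
}

Note that this sketch only needs TensorSpec.h: after this patch, TensorSpec and the JSON helpers build unconditionally (TensorSpec.cpp is added to the CMake list with no TF guard), whereas TFUtils.h remains behind LLVM_HAVE_TF_API.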