diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -90,6 +90,13 @@
   size_t ElementCount = 0;
 };
 
+/// Construct a TensorSpec from a JSON dictionary of the form:
+/// { "name": <string>,
+///   "port": <int>,
+///   "type": <string>,
+///   "shape": <array of ints> }
+/// For the "type" field, see the C++ primitive types used in
+/// TFUTILS_SUPPORTED_TYPES.
 Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                            const json::Value &Value);
 
@@ -155,23 +162,22 @@
   std::unique_ptr<TFModelEvaluatorImpl> Impl;
 };
 
-/// List of supported types, as a triple:
-/// C++ type
-/// short name (for strings, for instance)
-/// capitalized short name (for enums, for instance)
+/// List of supported types, as a pair:
+/// - C++ type
+/// - enum name (implementation-specific)
 #define TFUTILS_SUPPORTED_TYPES(M)                                             \
-  M(float, float, FLOAT)                                                       \
-  M(double, double, DOUBLE)                                                    \
-  M(int8_t, int8, INT8)                                                        \
-  M(uint8_t, uint8, UINT8)                                                     \
-  M(int16_t, int16, INT16)                                                     \
-  M(uint16_t, uint16, UINT16)                                                  \
-  M(int32_t, int32, INT32)                                                     \
-  M(uint32_t, uint32, UINT32)                                                  \
-  M(int64_t, int64, INT64)                                                     \
-  M(uint64_t, uint64, UINT64)
-
-#define TFUTILS_GETDATATYPE_DEF(T, S, C)                                       \
+  M(float, TF_FLOAT)                                                           \
+  M(double, TF_DOUBLE)                                                         \
+  M(int8_t, TF_INT8)                                                           \
+  M(uint8_t, TF_UINT8)                                                         \
+  M(int16_t, TF_INT16)                                                         \
+  M(uint16_t, TF_UINT16)                                                       \
+  M(int32_t, TF_INT32)                                                         \
+  M(uint32_t, TF_UINT32)                                                       \
+  M(int64_t, TF_INT64)                                                         \
+  M(uint64_t, TF_UINT64)
+
+#define TFUTILS_GETDATATYPE_DEF(T, E)                                          \
   template <> int TensorSpec::getDataType<T>();
 
 TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_DEF)
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -122,8 +122,8 @@
   if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
     return EmitError("'shape' property not present or not an int array");
 
-#define PARSE_TYPE(T, S, E)                                                    \
-  if (TensorType == #S)                                                        \
+#define PARSE_TYPE(T, E)                                                       \
+  if (TensorType == #T)                                                        \
     return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
   TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
 #undef PARSE_TYPE
@@ -307,8 +307,8 @@
   return TF_TensorData(Impl->getOutput()[Index]);
 }
 
-#define TFUTILS_GETDATATYPE_IMPL(T, S, E)                                      \
-  template <> int TensorSpec::getDataType<T>() { return TF_##E; }
+#define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
+  template <> int TensorSpec::getDataType<T>() { return E; }
 
 TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
diff --git a/llvm/lib/Analysis/models/inliner/output_spec.json b/llvm/lib/Analysis/models/inliner/output_spec.json
--- a/llvm/lib/Analysis/models/inliner/output_spec.json
+++ b/llvm/lib/Analysis/models/inliner/output_spec.json
@@ -4,7 +4,7 @@
         "tensor_spec": {
             "name": "StatefulPartitionedCall",
             "port": 0,
-            "type": "int64",
+            "type": "int64_t",
             "shape": [
                 1
             ]
diff --git a/llvm/test/Transforms/Inline/ML/Inputs/test_output_spec.json b/llvm/test/Transforms/Inline/ML/Inputs/test_output_spec.json
--- a/llvm/test/Transforms/Inline/ML/Inputs/test_output_spec.json
+++ b/llvm/test/Transforms/Inline/ML/Inputs/test_output_spec.json
@@ -4,7 +4,7 @@
         "tensor_spec": {
             "name": "StatefulPartitionedCall",
             "port": 0,
-            "type": "int64",
+            "type": "int64_t",
             "shape": [
                 1
             ]
@@ -15,7 +15,7 @@
         "tensor_spec": {
             "name": "StatefulPartitionedCall",
             "port": 0,
-            "type": "int64",
+            "type": "int64_t",
             "shape": [
                 1
             ]
diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -103,7 +103,7 @@
   auto Value = json::parse(
       R"({"name": "tensor_name",
         "port": 2,
-        "type": "int32",
+        "type": "int32_t",
        "shape":[1,4]
        })");
  EXPECT_TRUE(!!Value);