diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h --- a/llvm/include/llvm/Analysis/Utils/TFUtils.h +++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h @@ -66,10 +66,18 @@ bool operator!=(const TensorSpec &Other) const { return !(*this == Other); } + /// Get the number of elements in a tensor with this shape. + size_t getElementCount() const { return ElementCount; } + /// Get the size, in bytes, of one element. + size_t getElementByteSize() const; + + template bool isElementType() const { + return getDataType() == TypeIndex; + } + private: TensorSpec(const std::string &Name, int Port, int TypeIndex, - const std::vector &Shape) - : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape) {} + const std::vector &Shape); template static int getDataType() { llvm_unreachable("Undefined tensor type"); @@ -79,6 +87,7 @@ int Port = 0; int TypeIndex = 0; std::vector Shape; + size_t ElementCount = 0; }; Optional getTensorSpecFromJSON(LLVMContext &Ctx, diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp --- a/llvm/lib/Analysis/TFUtils.cpp +++ b/llvm/lib/Analysis/TFUtils.cpp @@ -24,6 +24,7 @@ #include "tensorflow/c/c_api_experimental.h" #include +#include using namespace llvm; @@ -84,6 +85,17 @@ std::vector Output; }; +size_t TensorSpec::getElementByteSize() const { + return TF_DataTypeSize(static_cast(TypeIndex)); +} + +TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex, + const std::vector &Shape) + : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape), + // std::accumulate works in the seed's type; an int seed would truncate. + ElementCount(std::accumulate(Shape.begin(), Shape.end(), int64_t{1}, + std::multiplies())) {} + Optional getTensorSpecFromJSON(LLVMContext &Ctx, const json::Value &Value) { auto EmitError = [&](const llvm::Twine &Message) -> Optional { diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp --- a/llvm/unittests/Analysis/TFUtilsTest.cpp +++ b/llvm/unittests/Analysis/TFUtilsTest.cpp @@ -123,3 
+123,18 @@ auto Spec = getTensorSpecFromJSON(Ctx, *Value); EXPECT_FALSE(Spec.hasValue()); } + +TEST(TFUtilsTest, TensorSpecSizesAndTypes) { + auto Spec1D = TensorSpec::createSpec("Hi1", {1}); + auto Spec2D = TensorSpec::createSpec("Hi2", {1, 1}); + auto Spec1DLarge = TensorSpec::createSpec("Hi3", {10}); + auto Spec3DLarge = TensorSpec::createSpec("Hi3", {2, 4, 10}); + EXPECT_TRUE(Spec1D.isElementType()); + EXPECT_FALSE(Spec3DLarge.isElementType()); + EXPECT_EQ(Spec1D.getElementCount(), 1U); + EXPECT_EQ(Spec2D.getElementCount(), 1U); + EXPECT_EQ(Spec1DLarge.getElementCount(), 10U); + EXPECT_EQ(Spec3DLarge.getElementCount(), 80U); + EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float)); + EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t)); +}