Add copy/move assignment and const accessors to
TFModelEvaluator::EvaluationResult.

* Copy assignment is deleted (copy construction already was); move
  assignment transfers ownership of the evaluation's output tensors.
* getUntypedTensorValue becomes public and gains a const overload.
* The const getTensorValue<T> overload casts to `const T *`: static_cast
  cannot drop the const qualifier from the `const void *` returned by the
  const getUntypedTensorValue, so casting to `T *` would only fail to
  compile once the overload is first instantiated.
* The unit test now also calls the const overloads so such a mistake is
  caught at build time.

diff --git a/llvm/include/llvm/Analysis/Utils/TFUtils.h b/llvm/include/llvm/Analysis/Utils/TFUtils.h
--- a/llvm/include/llvm/Analysis/Utils/TFUtils.h
+++ b/llvm/include/llvm/Analysis/Utils/TFUtils.h
@@ -101,18 +101,29 @@
   class EvaluationResult {
   public:
     EvaluationResult(const EvaluationResult &) = delete;
+    EvaluationResult &operator=(const EvaluationResult &Other) = delete;
+
     EvaluationResult(EvaluationResult &&Other);
+    EvaluationResult &operator=(EvaluationResult &&Other);
+
     ~EvaluationResult();
 
-    /// Get a pointer to the first element of the tensor at Index.
+    /// Get a (const) pointer to the first element of the tensor at Index.
     template <typename T> T *getTensorValue(size_t Index) {
       return static_cast<T *>(getUntypedTensorValue(Index));
     }
 
+    template <typename T> const T *getTensorValue(size_t Index) const {
+      return static_cast<const T *>(getUntypedTensorValue(Index));
+    }
+
+    /// Get a (const) pointer to the untyped data of the tensor.
+    void *getUntypedTensorValue(size_t Index);
+    const void *getUntypedTensorValue(size_t Index) const;
+
   private:
     friend class TFModelEvaluator;
     EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl);
-    void *getUntypedTensorValue(size_t Index);
     std::unique_ptr<EvaluationResultImpl> Impl;
   };
 
diff --git a/llvm/lib/Analysis/TFUtils.cpp b/llvm/lib/Analysis/TFUtils.cpp
--- a/llvm/lib/Analysis/TFUtils.cpp
+++ b/llvm/lib/Analysis/TFUtils.cpp
@@ -292,10 +292,21 @@
 TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
     : Impl(std::move(Other.Impl)) {}
 
+TFModelEvaluator::EvaluationResult &
+TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
+  Impl = std::move(Other.Impl);
+  return *this;
+}
+
 void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
   return TF_TensorData(Impl->getOutput()[Index]);
 }
 
+const void *
+TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
+  return TF_TensorData(Impl->getOutput()[Index]);
+}
+
 #define TFUTILS_GETDATATYPE_IMPL(T, S, E)                                      \
   template <> int TensorSpec::getDataType<T>() { return TF_##E; }
 
diff --git a/llvm/unittests/Analysis/TFUtilsTest.cpp b/llvm/unittests/Analysis/TFUtilsTest.cpp
--- a/llvm/unittests/Analysis/TFUtilsTest.cpp
+++ b/llvm/unittests/Analysis/TFUtilsTest.cpp
@@ -56,6 +56,12 @@
     EXPECT_TRUE(ER.hasValue());
     float Ret = *ER->getTensorValue<float>(0);
     EXPECT_EQ(static_cast<size_t>(Ret), 80);
+    EXPECT_EQ(ER->getUntypedTensorValue(0),
+              reinterpret_cast<const void *>(ER->getTensorValue<float>(0)));
+    // Also exercise the const overloads, so a broken one fails to compile.
+    const auto &ConstER = *ER;
+    EXPECT_EQ(ConstER.getUntypedTensorValue(0),
+              static_cast<const void *>(ConstER.getTensorValue<float>(0)));
   }
   // The input vector should be unchanged
   for (auto I = 0; I < KnownSize; ++I) {
@@ -137,4 +143,4 @@
   EXPECT_EQ(Spec3DLarge.getElementCount(), 80);
   EXPECT_EQ(Spec3DLarge.getElementByteSize(), sizeof(float));
   EXPECT_EQ(Spec1D.getElementByteSize(), sizeof(int16_t));
-}
\ No newline at end of file
+}