diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h --- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h @@ -39,14 +39,41 @@ //===----------------------------------------------------------------------===// // Codegen-compatible structures for Vector type. //===----------------------------------------------------------------------===// -template -struct Vector { +namespace detail { + template constexpr unsigned nextPowerOf2(); + template <> constexpr unsigned nextPowerOf2<0>() { return 1; } + template <> constexpr unsigned nextPowerOf2<1>() { return 1; } + template constexpr unsigned nextPowerOf2() { + return (!(N & (N - 1))) ? N : 2 * nextPowerOf2<(N + 1) / 2>(); + } +} // end namespace detail + +// N-D vectors recurse down to 1-D. +template struct Vector { + constexpr Vector &operator[](unsigned i) { return vector[i]; } + constexpr const Vector &operator[](unsigned i) const { + return vector[i]; + } + +private: Vector vector[Dim]; }; -template -struct Vector { +// 1-D vectors in LLVM are automatically padded to the next power of 2. +// We insert explicit padding to account for this. 
+template struct Vector { + Vector() { + static_assert(detail::nextPowerOf2() >= sizeof(T[Dim]), + "size error"); + static_assert(detail::nextPowerOf2() < 2 * sizeof(T[Dim]), + "size error"); + } + constexpr T &operator[](unsigned i) { return vector[i]; } + constexpr const T &operator[](unsigned i) const { return vector[i]; } + +private: T vector[Dim]; + char padding[detail::nextPowerOf2() - sizeof(T[Dim])]; }; template diff --git a/mlir/include/mlir/ExecutionEngine/RunnerUtils.h b/mlir/include/mlir/ExecutionEngine/RunnerUtils.h --- a/mlir/include/mlir/ExecutionEngine/RunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/RunnerUtils.h @@ -92,7 +92,7 @@ static_assert(sizeof(val) == M * StaticSizeMult::value * sizeof(T), "Incorrect vector size!"); // First - os << "(" << val.vector[0]; + os << "(" << val[0]; if (M > 1) os << ", "; if (sizeof...(Dims) > 1) @@ -100,14 +100,14 @@ // Kernel for (unsigned i = 1; i + 1 < M; ++i) { printSpace(os, 2 * sizeof...(Dims)); - os << val.vector[i] << ", "; + os << val[i] << ", "; if (sizeof...(Dims) > 1) os << "\n"; } // Last if (M > 1) { printSpace(os, sizeof...(Dims)); - os << val.vector[M - 1]; + os << val[M - 1]; } os << ")"; }