diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -80,6 +80,12 @@
 /// no IndexAttr and that IndexType have no bitwidth.
 bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
 
+/// Helper function to convert a vector of `OpFoldResult`s into a vector of
+/// `Value`s. Each element is converted with `getValueOrCreateConstantIndexOp`
+/// using builder `b` and location `loc`; the result preserves the input order.
+SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
+                               ArrayRef<OpFoldResult> valueOrAttrVec);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -8,8 +8,8 @@
 
 #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 
 using namespace mlir;
@@ -134,16 +134,6 @@
       builder, loc, src, dstStaticShape, reassocation);
 }
 
-/// Helper function to convert a vector of `OpFoldResult`s into a vector of
-/// `Value`s.
-static SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
-                                      ArrayRef<OpFoldResult> valueOrAttrVec) {
-  return llvm::to_vector<4>(
-      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
-        return getValueOrCreateConstantIndexOp(b, loc, value);
-      }));
-}
-
 template <typename OpTy>
 struct ReifyExpandOrCollapseShapeOp
     : public ReifyRankedShapedTypeOpInterface::ExternalModel<
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APSInt.h"
@@ -124,4 +125,15 @@
   auto v1 = ofr1.dyn_cast<Value>(), v2 = ofr2.dyn_cast<Value>();
   return v1 && v1 == v2;
 }
+
+/// Helper function to convert a vector of `OpFoldResult`s into a vector of
+/// `Value`s. Each element is converted with `getValueOrCreateConstantIndexOp`
+/// using builder `b` and location `loc`; the result preserves the input order.
+SmallVector<Value> getAsValues(OpBuilder &b, Location loc,
+                               ArrayRef<OpFoldResult> valueOrAttrVec) {
+  return llvm::to_vector<4>(
+      llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
+        return getValueOrCreateConstantIndexOp(b, loc, value);
+      }));
+}
 } // namespace mlir