diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -28,11 +28,17 @@
                         ArrayRef<OpFoldResult> low, ArrayRef<OpFoldResult> high,
                         bool nofold, Location loc, OpBuilder &builder);
 
-// Creates dim ops for each dynamic dimension of the raked tensor argument and
+// Creates dim ops for each dynamic dimension of the ranked tensor argument and
 // returns these as values.
 SmallVector<Value> createDynamicDimValues(OpBuilder &b, Location loc,
                                           Value rankedTensor);
 
+// Returns one index value per dimension of the ranked tensor argument, in
+// dimension order: a constant index op for each static dimension and a dim op
+// (folded when possible) for each dynamic dimension.
+SmallVector<Value> createDimValues(OpBuilder &b, Location loc,
+                                   Value rankedTensor);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -68,3 +68,21 @@
   }
   return dynamicDims;
 }
+
+SmallVector<Value> mlir::tensor::createDimValues(OpBuilder &b, Location loc,
+                                                 Value rankedTensor) {
+  auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
+  SmallVector<Value> dims;
+  dims.reserve(tensorTy.getRank());
+  for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
+    // Static extents are known at compile time: materialize them directly as
+    // index constants instead of building a DimOp only to have it folded away.
+    if (ShapedType::isDynamic(en.value())) {
+      dims.push_back(
+          b.createOrFold<tensor::DimOp>(loc, rankedTensor, en.index()));
+    } else {
+      dims.push_back(b.create<arith::ConstantIndexOp>(loc, en.value()));
+    }
+  }
+  return dims;
+}