diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h @@ -33,6 +33,11 @@ SmallVector createDynamicDimValues(OpBuilder &b, Location loc, Value rankedTensor); +// Creates dim ops or constant ops for each dimension of the ranked tensor +// argument and returns them as values. +SmallVector createDimValues(OpBuilder &b, Location loc, + Value rankedTensor); + } // namespace tensor } // namespace mlir diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -68,3 +68,14 @@ } return dynamicDims; } + +SmallVector mlir::tensor::createDimValues(OpBuilder &b, Location loc, + Value rankedTensor) { + auto tensorTy = rankedTensor.getType().cast(); + SmallVector dims; + for (const auto &en : llvm::enumerate(tensorTy.getShape())) { + dims.push_back( + b.createOrFold(loc, rankedTensor, en.index())); + } + return dims; +}