diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -28,8 +28,9 @@
 
-// Creates dim ops or constant ops for each dimension of the ranked tensor
-// argument and returns these as values.
-SmallVector<Value> createDimValues(OpBuilder &b, Location loc,
-                                   Value rankedTensor);
+// Returns one OpFoldResult per dimension of the ranked tensor argument: an
+// index attribute for a static size, or the result of creating-or-folding a
+// dim op for a dynamic size.
+SmallVector<OpFoldResult> createDimValues(OpBuilder &b, Location loc,
+                                          Value rankedTensor);
 
 } // namespace tensor
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -51,13 +51,19 @@
   return dynamicDims;
 }
 
-SmallVector<Value> mlir::tensor::createDimValues(OpBuilder &b, Location loc,
-                                                 Value rankedTensor) {
+SmallVector<OpFoldResult>
+mlir::tensor::createDimValues(OpBuilder &b, Location loc, Value rankedTensor) {
   auto tensorTy = rankedTensor.getType().cast<RankedTensorType>();
-  SmallVector<Value> dims;
+  SmallVector<OpFoldResult> dims;
   for (const auto &en : llvm::enumerate(tensorTy.getShape())) {
-    dims.push_back(
-        b.createOrFold<tensor::DimOp>(loc, rankedTensor, en.index()));
+    if (ShapedType::isDynamic(en.value())) {
+      // Dynamic size: only known at runtime, so query it with a dim op.
+      dims.push_back(
+          b.createOrFold<tensor::DimOp>(loc, rankedTensor, en.index()));
+    } else {
+      // Static size: return it as an index attribute without creating an op.
+      dims.push_back(b.getIndexAttr(en.value()));
+    }
   }
   return dims;
 }