diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -110,6 +110,10 @@
 /// that can be folded.
 LogicalResult foldTensorCast(Operation *op);
 
+/// Return the dimensions of the given tensor value.
+SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
+                                        Value value);
+
 /// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
 /// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor`
 /// to that of `targetType`.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -36,6 +36,21 @@
   return nullptr;
 }
 
+SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
+                                                Location loc, Value value) {
+  auto tensorType = value.getType().cast<RankedTensorType>();
+  SmallVector<OpFoldResult> result;
+  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
+    if (tensorType.isDynamicDim(i)) {
+      Value size = builder.create<tensor::DimOp>(loc, value, i);
+      result.push_back(size);
+    } else {
+      result.push_back(builder.getIndexAttr(tensorType.getDimSize(i)));
+    }
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
@@ -1465,18 +1480,8 @@
     OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
   auto rankedTensorType = tensor.getType().cast<RankedTensorType>();
   unsigned rank = rankedTensorType.getRank();
-  auto shape = rankedTensorType.getShape();
   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
-  SmallVector<OpFoldResult> sizes;
-  for (unsigned i = 0, e = rank; i < e; ++i) {
-    OpFoldResult dim;
-    if (rankedTensorType.isDynamicDim(i))
-      dim = b.createOrFold<tensor::DimOp>(
-          loc, tensor, b.create<arith::ConstantIndexOp>(loc, i));
-    else
-      dim = b.getIndexAttr(shape[i]);
-    sizes.push_back(dim);
-  }
+  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
   return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                 offsets, sizes, strides);
@@ -1818,18 +1823,8 @@
                                                              Value dest) {
   auto rankedTensorType = dest.getType().cast<RankedTensorType>();
   unsigned rank = rankedTensorType.getRank();
-  auto shape = rankedTensorType.getShape();
   SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
-  SmallVector<OpFoldResult> sizes;
-  for (unsigned i = 0, e = rank; i < e; ++i) {
-    OpFoldResult dim;
-    if (rankedTensorType.isDynamicDim(i))
-      dim = b.createOrFold<tensor::DimOp>(
-          loc, dest, b.create<arith::ConstantIndexOp>(loc, i));
-    else
-      dim = b.getIndexAttr(shape[i]);
-    sizes.push_back(dim);
-  }
+  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
   return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                                sizes, strides);
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -812,16 +812,10 @@
                                generateOp.getBody().begin());
 
     // Create tensor::InsertSliceOp.
-    SmallVector<OpFoldResult> sliceSizes, sliceStrides;
-    for (int64_t i = 0; i < resultType.getRank(); ++i) {
-      sliceStrides.push_back(rewriter.getIndexAttr(1));
-      if (srcType.isDynamicDim(i)) {
-        Value size = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
-        sliceSizes.push_back(size);
-      } else {
-        sliceSizes.push_back(rewriter.getIndexAttr(srcType.getDimSize(i)));
-      }
-    }
+    SmallVector<OpFoldResult> sliceSizes =
+        getMixedSizes(rewriter, loc, padOp.getSource());
+    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
+                                           rewriter.getIndexAttr(1));
     rewriter.replaceOpWithNewOp<InsertSliceOp>(
         padOp, padOp.getSource(), generateOp.getResult(),
         /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);
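
Note: `getMixedSizes` returns one `OpFoldResult` per dimension, an index attribute for static extents and a materialized `tensor.dim` op for dynamic ones, which is exactly the mixed static/dynamic form the slice-op builders consume. A minimal caller-side sketch (hypothetical helper, not part of this patch; assumes an `OpBuilder b` positioned at the desired insertion point):

// Hypothetical example, not part of this patch: build a full-tensor
// extract_slice over `src` using offsets [0 .. 0], the mixed sizes of
// `src`, and strides [1 .. 1].
static Value makeIdentitySlice(OpBuilder &b, Location loc, Value src) {
  auto srcType = src.getType().cast<RankedTensorType>();
  unsigned rank = srcType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(b, loc, src);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  // This ExtractSliceOp builder infers the result type from the mixed
  // offsets/sizes/strides.
  return b.create<tensor::ExtractSliceOp>(loc, src, offsets, sizes, strides);
}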