diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -110,6 +110,13 @@
 /// that can be folded.
 LogicalResult foldTensorCast(Operation *op);
 
+/// Return the dimensions of the given tensor value as OpFoldResults: static
+/// dimensions are returned as index attributes, and a `tensor.dim` op is
+/// created at `loc` for each dynamic dimension. `value` must have a ranked
+/// tensor type.
+SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
+                                        Value value);
+
 /// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
 /// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor`
 /// to that of `targetType`.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -36,6 +36,21 @@
   return nullptr;
 }
 
+SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
+                                                Location loc, Value value) {
+  auto tensorType = value.getType().cast<RankedTensorType>();
+  SmallVector<OpFoldResult> result;
+  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
+    if (tensorType.isDynamicDim(i)) {
+      Value size = builder.create<tensor::DimOp>(loc, value, i);
+      result.push_back(size);
+    } else {
+      result.push_back(builder.getIndexAttr(tensorType.getDimSize(i)));
+    }
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -808,16 +808,10 @@
                                 generateOp.getBody().begin());
 
     // Create tensor::InsertSliceOp.
-    SmallVector<OpFoldResult> sliceSizes, sliceStrides;
-    for (int64_t i = 0; i < resultType.getRank(); ++i) {
-      sliceStrides.push_back(rewriter.getIndexAttr(1));
-      if (srcType.isDynamicDim(i)) {
-        Value size = rewriter.create<tensor::DimOp>(loc, padOp.getSource(), i);
-        sliceSizes.push_back(size);
-      } else {
-        sliceSizes.push_back(rewriter.getIndexAttr(srcType.getDimSize(i)));
-      }
-    }
+    SmallVector<OpFoldResult> sliceSizes =
+        getMixedSizes(rewriter, loc, padOp.getSource());
+    SmallVector<OpFoldResult> sliceStrides(srcType.getRank(),
+                                           rewriter.getIndexAttr(1));
     rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
         padOp, padOp.getSource(), generateOp.getResult(),
         /*offsets=*/padOp.getMixedLowPad(), sliceSizes, sliceStrides);