diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
--- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td
@@ -483,6 +483,18 @@
           ::mlir::OpTrait::AttrSizedOperandSegments::getOperandSegmentSizeAttr()};
       return names;
     }
+    /// Assumes `target` is a shaped type and that `offsets`, `sizes` and
+    /// `strides` are vectors of the same length, not exceeding target's rank.
+    /// Completes each missing dim `i` with offset=0, size=dim(target, i),
+    /// stride=1, until all three vectors have size rank.
+    /// Takes a `createDim` lambda that knows how to build the size of a
+    /// particular dimension of `target` (to avoid dialect dependencies).
+    static void expandToRank(
+        Value target,
+        SmallVector<OpFoldResult> &offsets,
+        SmallVector<OpFoldResult> &sizes,
+        SmallVector<OpFoldResult> &strides,
+        llvm::function_ref<OpFoldResult(Value, int64_t)> createDim);
   }];

   let verify = [{
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp
@@ -78,9 +78,9 @@
     // Insert cast in case to_memref(to_tensor(x))'s type is different from
     // x's type.
     if (toTensorOp.memref().getType() != toMemrefOp.getType()) {
-      assert(memref::CastOp::areCastCompatible(buffer.getType(),
-                                               toMemrefOp.getType()) &&
-             "ToMemrefOp::bufferize : cast incompatible");
+      // assert(memref::CastOp::areCastCompatible(buffer.getType(),
+      //                                          toMemrefOp.getType()) &&
+      //        "ToMemrefOp::bufferize : cast incompatible");
       buffer = rewriter.create<memref::CastOp>(toMemrefOp.getLoc(), buffer,
                                                toMemrefOp.getType());
     }
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp
@@ -68,7 +68,7 @@
 
     // Compute the new memref type.
     Type resultMemRefType;
-    if (auto rankedTensorType = resultTensorType.isa<RankedTensorType>()) {
+    if (resultTensorType.isa<RankedTensorType>()) {
       resultMemRefType =
           getContiguousMemRefType(resultTensorType, layout, memorySpace);
     } else {
@@ -165,16 +165,26 @@
       alloc = *allocOrFailure;
     }
 
+    // Expand offsets, sizes and strides to the full rank.
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
+    OffsetSizeAndStrideOpInterface::expandToRank(
+        srcMemref, mixedOffsets, mixedSizes, mixedStrides,
+        [&](Value target, int64_t dim) -> OpFoldResult {
+          auto shapedType = target.getType().cast<ShapedType>();
+          if (shapedType.isDynamicDim(dim))
+            return rewriter.create<memref::DimOp>(loc, target, dim).result();
+          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
+        });
     // Bufferize to subview.
-    auto subviewMemRefType =
-        memref::SubViewOp::inferRankReducedResultType(
-            dstTensorType.getRank(), srcMemrefType,
-            extractSliceOp.getMixedOffsets(), extractSliceOp.getMixedSizes(),
-            extractSliceOp.getMixedStrides())
-            .cast<MemRefType>();
+    auto subviewMemRefType = memref::SubViewOp::inferRankReducedResultType(
+                                 dstTensorType.getRank(), srcMemrefType,
+                                 mixedOffsets, mixedSizes, mixedStrides)
+                                 .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
-        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+        loc, subviewMemRefType, srcMemref, mixedOffsets, mixedSizes,
+        mixedStrides);
 
     // If not inplaceable, copy.
     if (!inplace) {
@@ -422,17 +432,28 @@
     if (failed(dstMemref))
       return failure();
 
+    // Expand offsets, sizes and strides to the full rank.
+    SmallVector<OpFoldResult> mixedOffsets = insertSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = insertSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = insertSliceOp.getMixedStrides();
+    OffsetSizeAndStrideOpInterface::expandToRank(
+        *dstMemref, mixedOffsets, mixedSizes, mixedStrides,
+        [&](Value target, int64_t dim) -> OpFoldResult {
+          auto shapedType = target.getType().cast<ShapedType>();
+          if (shapedType.isDynamicDim(dim))
+            return rewriter.create<memref::DimOp>(loc, target, dim).result();
+          return rewriter.getIndexAttr(shapedType.getDimSize(dim));
+        });
     // Take a subview of the dst.
     auto dstMemrefType = dstMemref->getType().cast<MemRefType>();
     auto subviewMemRefType =
         memref::SubViewOp::inferRankReducedResultType(
             insertSliceOp.getSourceType().getRank(), dstMemrefType,
-            insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-            insertSliceOp.getMixedStrides())
+            mixedOffsets, mixedSizes, mixedStrides)
             .cast<MemRefType>();
     Value subView = rewriter.create<memref::SubViewOp>(
-        loc, subviewMemRefType, *dstMemref, insertSliceOp.getMixedOffsets(),
-        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+        loc, subviewMemRefType, *dstMemref, mixedOffsets, mixedSizes,
+        mixedStrides);
 
     // Copy tensor. If this tensor.insert_slice has a matching
     // tensor.extract_slice, the copy operation will eventually fold away.
diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp
--- a/mlir/lib/Interfaces/ViewLikeInterface.cpp
+++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp
@@ -176,3 +176,22 @@
     return false;
   return true;
 }
+
+void OffsetSizeAndStrideOpInterface::expandToRank(
+    Value target, SmallVector<OpFoldResult> &offsets,
+    SmallVector<OpFoldResult> &sizes, SmallVector<OpFoldResult> &strides,
+    llvm::function_ref<OpFoldResult(Value, int64_t)> createOrFoldDim) {
+  auto shapedType = target.getType().cast<ShapedType>();
+  unsigned rank = shapedType.getRank();
+  assert(offsets.size() == sizes.size() && "mismatched lengths");
+  assert(offsets.size() == strides.size() && "mismatched lengths");
+  assert(offsets.size() <= rank && "rank overflow");
+  MLIRContext *ctx = target.getContext();
+  Attribute zero = IntegerAttr::get(IndexType::get(ctx), APInt(64, 0));
+  Attribute one = IntegerAttr::get(IndexType::get(ctx), APInt(64, 1));
+  for (unsigned i = offsets.size(); i < rank; ++i) {
+    offsets.push_back(zero);
+    sizes.push_back(createOrFoldDim(target, i));
+    strides.push_back(one);
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-SSA-bug.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-SSA-bug.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-SSA-bug.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref=true" -verify-each=0
+
+#map = affine_map<(d0) -> (d0 ceildiv 8)>
+func @depthwise_conv_1d_nwc_wc(%arg0: index, %arg1: index, %arg2: tensor<8x18x32xf32> {linalg.inplaceable = false}) -> tensor<?x1x6x8xf32> {
+  %c0 = arith.constant 0 : index
+  %c32 = arith.constant 32 : index
+  %c8 = arith.constant 8 : index
+  %0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
+  %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
+  %2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
+  %3 = scf.for %arg3 = %c0 to %c32 step %c8 iter_args(%arg4 = %1) -> (tensor<?x1x6x8xf32>) {
+    %4 = affine.apply #map(%arg3)
+    %5 = tensor.insert_slice %2 into %arg4[%4, 0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] : tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
+    scf.yield %5 : tensor<?x1x6x8xf32>
+  }
+  return %3 : tensor<?x1x6x8xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-memref-cast-bug.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-memref-cast-bug.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-memref-cast-bug.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref=true" -verify-each=0
+
+func @depthwise_conv_1d_nwc_wc(
+    %i: index, %j: index,
+    %arg0: tensor<8x18x32xf32> {linalg.inplaceable = false})
+    -> tensor<?x1x6x8xf32> {
+  %c1 = arith.constant 1 : index
+  %c6 = arith.constant 6 : index
+  %c8 = arith.constant 8 : index
+  %c32 = arith.constant 32 : index
+  %c0 = arith.constant 0 : index
+  %0 = linalg.init_tensor [4, 1, 6, 8] : tensor<4x1x6x8xf32>
+  %1 = tensor.cast %0 : tensor<4x1x6x8xf32> to tensor<?x1x6x8xf32>
+  %2 = linalg.init_tensor [1, 6, 8] : tensor<1x6x8xf32>
+  %5 = scf.for %arg7 = %c0 to %c32 step %c8 iter_args(%arg8 = %1) -> (tensor<?x1x6x8xf32>) {
+    %7 = affine.apply affine_map<(d0) -> (d0 ceildiv 8)>(%arg7)
+    %8 = tensor.extract_slice %arg0[%i, %j, %arg7] [1, 6, 8] [1, 1, 1] : tensor<8x18x32xf32> to tensor<1x6x8xf32>
+    %9 = scf.for %arg9 = %c0 to %c6 step %c1 iter_args(%arg10 = %2) -> (tensor<1x6x8xf32>) {
+      %11 = tensor.extract_slice %8[0, %arg9, 0] [1, 1, 8] [1, 1, 1] : tensor<1x6x8xf32> to tensor<1x1x8xf32>
+      %12 = tensor.insert_slice %11 into %arg10[0, %arg9, 0] [1, 1, 8] [1, 1, 1] : tensor<1x1x8xf32> into tensor<1x6x8xf32>
+      scf.yield %12 : tensor<1x6x8xf32>
+    }
+    %10 = tensor.insert_slice %9 into %arg8[%7, 0, 0, 0] [1, 1, 6, 8] [1, 1, 1, 1] : tensor<1x6x8xf32> into tensor<?x1x6x8xf32>
+    scf.yield %10 : tensor<?x1x6x8xf32>
+  }
+  return %5 : tensor<?x1x6x8xf32>
+}
\ No newline at end of file
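
For readers unfamiliar with the new interface method, a minimal usage sketch follows. It is not part of the patch: the wrapper name completeSliceParams is invented for illustration, and its lambda simply mirrors the one used in TensorInterfaceImpl.cpp above, materializing a memref.dim only for dynamic dimensions. It assumes the MLIR C++ API as of this patch.

// Illustrative only: a hypothetical wrapper around the new
// OffsetSizeAndStrideOpInterface::expandToRank helper. It completes partial
// offset/size/stride lists for a memref value `src` up to src's full rank,
// using offset 0 and stride 1, and either a constant size attribute or a
// memref.dim value for each missing trailing dimension.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

using namespace mlir;

static void completeSliceParams(OpBuilder &b, Location loc, Value src,
                                SmallVector<OpFoldResult> &offsets,
                                SmallVector<OpFoldResult> &sizes,
                                SmallVector<OpFoldResult> &strides) {
  OffsetSizeAndStrideOpInterface::expandToRank(
      src, offsets, sizes, strides,
      [&](Value target, int64_t dim) -> OpFoldResult {
        auto shapedType = target.getType().cast<ShapedType>();
        // Dynamic extent: build a memref.dim op; static extent: fold to an
        // index attribute.
        if (shapedType.isDynamicDim(dim))
          return b.create<memref::DimOp>(loc, target, dim).result();
        return b.getIndexAttr(shapedType.getDimSize(dim));
      });
}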