diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -115,12 +115,27 @@
     if (tensorResultType.getRank() == 0) {
       // 0-d collapses must go through a different op builder.
       auto bufferType = buffer.getType().cast<MemRefType>();
-      // Assume identity layout: No offset.
-      assert(bufferType.getLayout().isIdentity() &&
-             "non-zero offset for 0-d collapse not supported");
-      MemRefLayoutAttrInterface layout;
-      auto resultType = MemRefType::get({}, tensorResultType.getElementType(),
-                                        layout, bufferType.getMemorySpace());
+      MemRefType resultType;
+
+      if (bufferType.getLayout().isIdentity()) {
+        // Standard layout: result type has no offset.
+        MemRefLayoutAttrInterface layout;
+        resultType = MemRefType::get({}, tensorResultType.getElementType(),
+                                     layout, bufferType.getMemorySpace());
+      } else {
+        // Source memref has a layout map: result type has the same offset as
+        // the source type.
+        SmallVector<int64_t> strides;
+        int64_t offset;
+        if (failed(getStridesAndOffset(bufferType, strides, offset)))
+          return failure();
+        AffineMap resultLayout =
+            makeStridedLinearLayoutMap({}, offset, op->getContext());
+        resultType =
+            MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
+                            bufferType.getMemorySpaceAsInt());
+      }
+
       replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
           rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
       return success();
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -3,6 +3,8 @@
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
 // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
 // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0) -> (d0 + 1)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<() -> (1)>

 // CHECK-LABEL: func @dim(
 //  CHECK-SAME:     %[[TENSOR:.*]]: tensor<f32>,
@@ -361,3 +363,12 @@
 //       CHECK:   return %[[r]]
   return %0 : tensor<f32>
 }
+
+// CHECK-LABEL: func @tensor.collapse_shape_of_slice(
+func @tensor.collapse_shape_of_slice(%arg0: tensor<2xi32>) -> tensor<i32> {
+  // CHECK: memref.subview %{{.*}}[1] [1] [1] : memref<2xi32> to memref<1xi32, #[[$MAP3]]>
+  %0 = tensor.extract_slice %arg0[1] [1] [1] : tensor<2xi32> to tensor<1xi32>
+  // CHECK: memref.collapse_shape %{{.*}} [] : memref<1xi32, #[[$MAP3]]> into memref<i32, #[[$MAP4]]>
+  %1 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
+  return %1 : tensor<i32>
+}