diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -108,12 +108,27 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
+    RankedTensorType tensorResultType = collapseShapeOp.getResultType();
     Value buffer =
         *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
-    Type resultType =
-        getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
+
+    if (tensorResultType.getRank() == 0) {
+      // 0-d collapses must go through a different op builder.
+      auto bufferType = buffer.getType().cast<MemRefType>();
+      // Assume identity layout: No offset.
+      assert(bufferType.getLayout().isIdentity() &&
+             "non-zero offset for 0-d collapse not supported");
+      MemRefLayoutAttrInterface layout;
+      auto resultType = MemRefType::get({}, tensorResultType.getElementType(),
+                                        layout, bufferType.getMemorySpace());
+      replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
+          rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
+      return success();
+    }
+
+    // Result type is inferred by the builder.
     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
-        rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
+        rewriter, op, buffer, collapseShapeOp.getReassociationIndices());
     return success();
   }
 };
@@ -175,12 +190,15 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+    auto tensorResultType = expandShapeOp.getResultType();
     Value buffer =
         *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
-    Type resultType =
-        getMemRefType(expandShapeOp.getResultType(), state.getOptions());
+
+    // Memref result type is inferred by the builder based on reassociation
+    // indices and result shape.
     replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
-        rewriter, op, resultType, buffer, expandShapeOp.reassociation());
+        rewriter, op, tensorResultType.getShape(), buffer,
+        expandShapeOp.getReassociationIndices());
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -1,6 +1,8 @@
 // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 140 + d1 * 20 + d2 * 5 + d3 + s0)>
 
 // CHECK-LABEL: func @dim(
 // CHECK-SAME: %[[TENSOR:.*]]: tensor<f32>,
@@ -242,7 +244,7 @@
 func @tensor.extract_slice(
     %t1: tensor<?x?xf32>, %idx1: index, %idx2: index) -> tensor<?x10xf32> {
   // CHECK: %[[m:.*]] = bufferization.to_memref %[[t1]] : memref<?x?xf32>
-  // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP]]>
+  // CHECK: %[[r:.*]] = memref.subview %[[m]][5, %[[idx2]]] [%[[idx1]], 10] [1, 1] : memref<?x?xf32> to memref<?x10xf32, #[[$MAP0]]>
   %0 = tensor.extract_slice %t1[5, %idx2][%idx1, 10][1, 1] :
       tensor<?x?xf32> to tensor<?x10xf32>
   // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
@@ -256,7 +258,7 @@
 func @tensor.extract_slice_rank_reducing(
     %t1: tensor<?x10x?xf32>, %idx1: index, %idx2: index) -> tensor<?x15xf32> {
   // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10x?xf32>
-  // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP]]>
+  // CHECK: %[[r:.*]] = memref.subview %[[m1]][5, %[[idx1]], 10] [%[[idx2]], 1, 15] [1, 1, 1] : memref<?x10x?xf32> to memref<?x15xf32, #[[$MAP0]]>
   %0 = tensor.extract_slice %t1[5, %idx1, 10][%idx2, 1, 15][1, 1, 1] :
       tensor<?x10x?xf32> to tensor<?x15xf32>
   // CHECK: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
@@ -316,6 +318,23 @@
   return %0 : tensor<2x?x10xf32>
 }
 
+// CHECK-LABEL: func @tensor.expand_shape_of_slice(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
+func @tensor.expand_shape_of_slice(
+    %t1: tensor<?x20xf32>, %o1: index, %s1: index) -> tensor<?x7x2x5xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x20xf32>
+  // CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, #[[$MAP1]]>
+  %0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
+      tensor<?x20xf32> to tensor<?x10xf32>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] [
+  // CHECK-SAME: [0, 1], [2, 3]] : memref<?x10xf32, #[[$MAP1]]> into memref<?x7x2x5xf32, #[[$MAP2]]>
+  %1 = tensor.expand_shape %0 [[0, 1], [2, 3]] :
+      tensor<?x10xf32> into tensor<?x7x2x5xf32>
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+  // CHECK: return %[[r]]
+  return %1 : tensor<?x7x2x5xf32>
+}
+
 // CHECK-LABEL: func @tensor.collapse_shape(
 // CHECK-SAME: %[[t1:.*]]: tensor<2x?x?xf32>
 func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?xf32> {
@@ -329,3 +348,16 @@
   // CHECK: return %[[r]]
   return %0 : tensor<?xf32>
 }
+
+// CHECK-LABEL: func @tensor.collapse_shape_to_scalar(
+// CHECK-SAME: %[[t1:.*]]: tensor<1x1x1xf32>
+func @tensor.collapse_shape_to_scalar(%t1: tensor<1x1x1xf32>) -> tensor<f32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<1x1x1xf32>
+  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [] : memref<1x1x1xf32> into memref<f32>
+  %0 = tensor.collapse_shape %t1 []
+      : tensor<1x1x1xf32> into tensor<f32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<f32>
+}