diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -80,6 +80,46 @@
   }
 };
 
+/// Bufferization of tensor.collapse_shape. Replace with memref.collapse_shape.
+struct CollapseShapeOpInterface
+    : public BufferizableOpInterface::ExternalModel<CollapseShapeOpInterface,
+                                                    tensor::CollapseShapeOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    if (&opOperand == &op->getOpOperand(0) /*src*/)
+      return {op->getOpResult(0)};
+    return {};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto collapseShapeOp = cast<tensor::CollapseShapeOp>(op);
+    Value buffer =
+        *state.getBuffer(rewriter, collapseShapeOp->getOpOperand(0) /*src*/);
+    Type resultType =
+        getMemRefType(collapseShapeOp.getResultType(), state.getOptions());
+    replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
+        rewriter, op, resultType, buffer, collapseShapeOp.reassociation());
+    return success();
+  }
+};
+
 /// Bufferization of tensor.dim. Replace with memref.dim.
 struct DimOpInterface
     : public BufferizableOpInterface::ExternalModel<DimOpInterface,
@@ ... @@
   }
 };
 
+/// Bufferization of tensor.expand_shape. Replace with memref.expand_shape.
+struct ExpandShapeOpInterface
+    : public BufferizableOpInterface::ExternalModel<ExpandShapeOpInterface,
+                                                    tensor::ExpandShapeOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult>
+  getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                      const BufferizationState &state) const {
+    if (&opOperand == &op->getOpOperand(0) /*src*/)
+      return {op->getOpResult(0)};
+    return {};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const BufferizationState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto expandShapeOp = cast<tensor::ExpandShapeOp>(op);
+    Value buffer =
+        *state.getBuffer(rewriter, expandShapeOp->getOpOperand(0) /*src*/);
+    Type resultType =
+        getMemRefType(expandShapeOp.getResultType(), state.getOptions());
+    replaceOpWithNewBufferizedOp<memref::ExpandShapeOp>(
+        rewriter, op, resultType, buffer, expandShapeOp.reassociation());
+    return success();
+  }
+};
+
 /// Bufferization of tensor.extract_slice. Replace with memref.subview.
 struct ExtractSliceOpInterface
     : public BufferizableOpInterface::ExternalModel<ExtractSliceOpInterface,
@@ ... @@
 void mlir::tensor::registerBufferizableOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addOpInterface<CastOp, CastOpInterface>();
+  registry.addOpInterface<CollapseShapeOp, CollapseShapeOpInterface>();
   registry.addOpInterface<DimOp, DimOpInterface>();
+  registry.addOpInterface<ExpandShapeOp, ExpandShapeOpInterface>();
   registry.addOpInterface<ExtractSliceOp, ExtractSliceOpInterface>();
   registry.addOpInterface<ExtractOp, ExtractOpInterface>();
   registry.addOpInterface<FromElementsOp, FromElementsOpInterface>();
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -301,3 +301,31 @@
   // CHECK: return %[[r]]
   return %0 : tensor<5xf32>
 }
+
+// CHECK-LABEL: func @tensor.expand_shape(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+func @tensor.expand_shape(%t1: tensor<?x10xf32>) -> tensor<2x?x10xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
+  // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] [
+  // CHECK-SAME: [0, 1], [2]] : memref<?x10xf32> into memref<2x?x10xf32>
+  %0 = tensor.expand_shape %t1 [[0, 1], [2]]
+      : tensor<?x10xf32> into tensor<2x?x10xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<2x?x10xf32>
+}
+
+// CHECK-LABEL: func @tensor.collapse_shape(
+// CHECK-SAME: %[[t1:.*]]: tensor<2x?x?xf32>
+func @tensor.collapse_shape(%t1: tensor<2x?x?xf32>) -> tensor<?x?xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<2x?x?xf32>
+  // CHECK: %[[collapsed:.*]] = memref.collapse_shape %[[m1]] [
+  // CHECK-SAME: [0, 1], [2]] : memref<2x?x?xf32> into memref<?x?xf32>
+  %0 = tensor.collapse_shape %t1 [[0, 1], [2]]
+      : tensor<2x?x?xf32> into tensor<?x?xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[collapsed]]
+  // CHECK: return %[[r]]
+  return %0 : tensor<?x?xf32>
+}
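
Note on how this hooks in: the two external models above only take effect once they are registered into a DialectRegistry, which is what the registerBufferizableOpInterfaceExternalModels() hunk does. The sketch below shows how a downstream tool might pull that registration in before running the BufferizableOpInterface-based bufferization passes; the header path mirrors the .cpp file patched above, while the helper name registerTensorBufferizationModels and the surrounding setup are hypothetical and not part of this patch.

#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/IR/DialectRegistry.h"

// Hypothetical helper: attaches the tensor dialect's BufferizableOpInterface
// external models (now including CollapseShapeOpInterface and
// ExpandShapeOpInterface) to their ops, so bufferization can rewrite
// tensor.collapse_shape / tensor.expand_shape into the memref counterparts.
static void registerTensorBufferizationModels(mlir::DialectRegistry &registry) {
  mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
}

With the models registered, both ops bufferize onto an equivalent, aliasing buffer (neither reading nor writing memory), which is exactly what the bufferize.mlir tests above check.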