diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -156,4 +156,31 @@
   let hasCanonicalizer = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// WaitForBufferizationOp
+//===----------------------------------------------------------------------===//
+
+def Bufferization_WaitForBufferizationOp
+    : Bufferization_Op<"wait_for_bufferization", []> {
+  let summary = "wait for the operand to be bufferized";
+  let description = [{
+    This op canonicalizes away once the tensor operand has been bufferized. The
+    tensor operand is bufferized if it is the result of a `to_tensor` op.
+
+    This op is useful for partial bufferization of certain ops. E.g., such ops
+    may have a region that yields a tensor value, but their corresponding memref
+    variant may not be yielding anything.
+
+    This op is used internally during bufferization. It should not be created
+    outside of `BufferizableOpInterface::bufferize` implementations.
+  }];
+
+  let arguments = (ins AnyTensor:$tensor);
+  let results = (outs);
+  // This op is fully verified by traits.
+  let verifier = ?;
+  let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
+  let hasCanonicalizer = 1;
+}
+
 #endif // BUFFERIZATION_OPS
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -297,6 +297,36 @@
   return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult();
 }
 
+//===----------------------------------------------------------------------===//
+// WaitForBufferizationOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Erase a `wait_for_bufferization` op once its tensor operand is produced by
+/// a `to_tensor` op, i.e., once the operand has been bufferized.
+struct FoldWaitForBufferizationOp
+    : public OpRewritePattern<WaitForBufferizationOp> {
+  using OpRewritePattern<WaitForBufferizationOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(WaitForBufferizationOp waitOp,
+                                PatternRewriter &rewriter) const final {
+    // Only a `to_tensor` producer indicates that the operand was bufferized.
+    // Block arguments and results of any other op must keep this op alive.
+    if (waitOp.tensor().getDefiningOp<ToTensorOp>()) {
+      rewriter.eraseOp(waitOp);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void WaitForBufferizationOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.add<FoldWaitForBufferizationOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir
@@ -219,3 +219,34 @@
   // CHECK-SCF: return %[[r_tensor]], %[[pos]]
   return %r1, %r2 : tensor<?xf32>, index
 }
+
+// -----
+
+// CHECK-LABEL: func @wait_for_bufferization_folds_away(
+//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32
+func @wait_for_bufferization_folds_away(
+    %t1: tensor<?xf32> {linalg.inplaceable = true}) -> f32 {
+  // CHECK-NOT: wait_for_bufferization
+  // CHECK-NOT: to_tensor
+  // CHECK-NOT: to_memref
+  bufferization.wait_for_bufferization %t1 : tensor<?xf32>
+  %c0 = arith.constant 0 : index
+  %1 = tensor.extract %t1[%c0] : tensor<?xf32>
+  return %1 : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @wait_for_bufferization_does_not_fold_away(
+//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32
+func @wait_for_bufferization_does_not_fold_away(
+    %t1: tensor<?xf32> {linalg.inplaceable = true}) -> f32 {
+  // CHECK: %[[m1_tensor:.*]] = bufferization.to_tensor %[[m1]]
+  // CHECK: %[[dummy:.*]] = "test.dummy_op"(%[[m1_tensor]])
+  %0 = "test.dummy_op"(%t1) : (tensor<?xf32>) -> tensor<?xf32>
+  // CHECK: bufferization.wait_for_bufferization %[[dummy]]
+  bufferization.wait_for_bufferization %0 : tensor<?xf32>
+  %c0 = arith.constant 0 : index
+  %1 = tensor.extract %0[%c0] : tensor<?xf32>
+  return %1 : f32
+}