diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -60,4 +60,34 @@
 }];
 }
 
+def EmptyTensorToAllocTensorOp
+    : Op<Transform_Dialect, "bufferization.empty_tensor_to_alloc_tensor",
+        [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+         TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Replace a tensor.empty with a bufferization.alloc_tensor.
+
+    ### Return modes
+
+    This operation does not return any results. It produces a silenceable
+    failure for operations that are not tensor.empty. The transformation
+    succeeds if all operations referred to by `target` are rewritten.
+
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs);
+
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::llvm::SmallVector<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // BUFFERIZATION_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -14,7 +14,9 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformUtils.h"
 
 using namespace mlir;
 using namespace mlir::bufferization;
@@ -67,6 +69,25 @@
   effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
                        TransformMappingResource::get());
 }
+
+//===----------------------------------------------------------------------===//
+// EmptyTensorToAllocTensorOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+EmptyTensorToAllocTensorOp::applyToOne(Operation *target,
+                                       SmallVector<Operation *> &results,
+                                       transform::TransformState &state) {
+  if (auto emptyOp = dyn_cast_or_null<tensor::EmptyOp>(target)) {
+    TrivialPatternRewriter rewriter(target->getContext());
+    rewriter.setInsertionPoint(target);
+    rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
+        emptyOp, emptyOp.getType(), emptyOp.getDynamicSizes());
+    return DiagnosedSilenceableFailure::success();
+  }
+  return emitDefaultSilenceableFailure(target);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -118,3 +118,49 @@
 // CHECK: return %[[C]] : memref<12x6xf32>
 return %D : tensor<12x6xf32>
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.empty"]} in %arg1
+    transform.bufferization.empty_tensor_to_alloc_tensor %0
+}
+
+// Expect `bufferization.empty_tensor_to_alloc_tensor` to replace the tensor.empty.
+func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
+  // CHECK: bufferization.alloc_tensor
+  %0 = tensor.empty() : tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.empty"]} in %arg1
+    transform.bufferization.empty_tensor_to_alloc_tensor %0
+}
+
+// Expect `bufferization.empty_tensor_to_alloc_tensor` to simply return success.
+func.func @tensor_alloc_to_tensor_alloc() -> tensor<2x2xf32> {
+  // CHECK: bufferization.alloc_tensor
+  %0 = bufferization.alloc_tensor() : tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["arith.addf"]} in %arg1
+    // expected-error @below {{failed to apply}}
+    transform.bufferization.empty_tensor_to_alloc_tensor %0
+}
+
+// Expect `bufferization.empty_tensor_to_alloc_tensor` to emit an error.
+func.func @empty_to_tensor_alloc(%t1: tensor<2x2xf32>, %t2: tensor<2x2xf32>) -> tensor<2x2xf32> {
+  // expected-note @below {{when applied to this op}}
+  %0 = arith.addf %t1, %t2 : tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}
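
The tests above only cover statically shaped tensors. Since the rewrite forwards `getDynamicSizes()` to the new `bufferization.alloc_tensor`, the sketch below illustrates the expected before/after IR for a dynamically sized `tensor.empty`; the function and value names are illustrative and not part of the patch.

```mlir
// Before (illustrative input): a dynamically sized empty tensor.
func.func @dyn_empty(%d0: index) -> tensor<?x4xf32> {
  %0 = tensor.empty(%d0) : tensor<?x4xf32>
  return %0 : tensor<?x4xf32>
}

// After applying transform.bufferization.empty_tensor_to_alloc_tensor:
// the dynamic size operand %d0 is carried over unchanged.
func.func @dyn_empty(%d0: index) -> tensor<?x4xf32> {
  %0 = bufferization.alloc_tensor(%d0) : tensor<?x4xf32>
  return %0 : tensor<?x4xf32>
}
```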