diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h
@@ -12,8 +12,17 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
 #include "mlir/IR/OpImplementation.h"
 
+// Forward declaration: the generated op class names tensor::EmptyOp in its
+// applyToOne signature; the full definition is only needed in the .cpp.
+namespace mlir {
+namespace tensor {
+class EmptyOp;
+} // namespace tensor
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // Bufferization Transform Operations
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
--- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td
@@ -13,6 +13,7 @@
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
 include "mlir/Dialect/Transform/IR/TransformEffects.td"
 include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
 include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
@@ -60,4 +61,42 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// EmptyTensorToAllocTensorOp
+//===----------------------------------------------------------------------===//
+
+def Transform_EmptyOp : Transform_ConcreteOpType<"tensor.empty">;
+def Transform_AllocTensorOp
+    : Transform_ConcreteOpType<"bufferization.alloc_tensor">;
+
+def EmptyTensorToAllocTensorOp
+    : Op<Transform_Dialect, "bufferization.empty_tensor_to_alloc_tensor",
+        [FunctionalStyleTransformOpTrait,
+         MemoryEffectsOpInterface,
+         TransformEachOpTrait,
+         TransformOpInterface]> {
+  let description = [{
+    Replace a `tensor.empty` with a `bufferization.alloc_tensor` that has the
+    same result type and the same dynamic sizes.
+
+    ### Return modes
+
+    This operation consumes the `target` handle and produces the `transformed`
+    handle. `target` is expected to be a `tensor.empty` operation. The transform
+    always succeeds.
+  }];
+
+  let arguments = (ins Transform_EmptyOp:$target);
+  let results = (outs Transform_AllocTensorOp:$transformed);
+
+  let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::tensor::EmptyOp target,
+        ::llvm::SmallVector<::mlir::Operation *> &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // BUFFERIZATION_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
--- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 
 using namespace mlir;
@@ -67,6 +68,24 @@
   effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
                        TransformMappingResource::get());
 }
+
+//===----------------------------------------------------------------------===//
+// EmptyTensorToAllocTensorOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target,
+                                       SmallVector<Operation *> &results,
+                                       transform::TransformState &state) {
+  IRRewriter rewriter(target->getContext());
+  rewriter.setInsertionPoint(target);
+  // Build the alloc_tensor in place of `target`, reusing its result type and
+  // dynamic sizes; `target` is replaced and erased by the rewriter.
+  auto alloc = rewriter.replaceOpWithNewOp<bufferization::AllocTensorOp>(
+      target, target.getType(), target.getDynamicSizes());
+  results.push_back(alloc);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -118,3 +118,24 @@
   // CHECK: return %[[C]] : memref<12x6xf32>
   return %D : tensor<12x6xf32>
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+  ^bb0(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.empty"]} in %arg1
+    %1 = transform.cast %0 : !pdl.operation to !transform.op<"tensor.empty">
+    transform.bufferization.empty_tensor_to_alloc_tensor %1 : (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor">
+}
+
+// Expect `transform.bufferization.empty_tensor_to_alloc_tensor` to replace the
+// tensor.empty. CHECK-LABEL skips past the transform IR above, which also
+// contains the text "bufferization.alloc_tensor" in its printed types.
+// CHECK-LABEL: func @empty_to_tensor_alloc
+func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
+  // CHECK-NOT: tensor.empty
+  // CHECK: %[[ALLOC:.*]] = bufferization.alloc_tensor()
+  // CHECK: return %[[ALLOC]]
+  %0 = tensor.empty() : tensor<2x2xf32>
+  return %0 : tensor<2x2xf32>
+}