diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -17,6 +17,13 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" +def Transform_EmptyOp : Transform_ConcreteOpType<"tensor.empty">; +def Transform_AllocTensorOp : Transform_ConcreteOpType<"bufferization.alloc_tensor">; + +//===----------------------------------------------------------------------===// +// OneShotBufferizeOp +//===----------------------------------------------------------------------===// + def OneShotBufferizeOp : Op, @@ -61,12 +68,65 @@ } //===----------------------------------------------------------------------===// -// EmptyTensorToAllocTensorOp +// EliminateEmptyTensorsOp //===----------------------------------------------------------------------===// +def EliminateEmptyTensorsOp + : Op, + DeclareOpInterfaceMethods]> { + let description = [{ + Try to eliminate all `tensor.empty` ops within the targeted op by replacing + them with a destination tensor. + + `tensor.empty` ops cannot be bufferized. They can either be converted to + `bufferization.alloc_tensor` or replaced with another tensor (via this + transform). `tensor.empty` does not specify the contents of the returned + tensor so their results can be replaced with arbitrary tensor values as long + as the dimensions match. -def Transform_EmptyOp : Transform_ConcreteOpType<"tensor.empty">; -def Transform_AllocTensorOp : Transform_ConcreteOpType<"bufferization.alloc_tensor">; + This transform looks for `tensor.empty` ops where the SSA use-def chain of + the result ends in a supported "anchor op" (always following the aliasing + OpOperand/OpResult chain). 
Currently supported anchor ops are: + - `tensor.insert_slice` + - `bufferization.yield` (inside `bufferization.alloc_tensor`) + + Example: + + ``` + %0 = tensor.empty() : tensor<5xf32> + %1 = linalg.fill ... outs(%0) + %2 = tensor.insert_slice %1 into %t[1][5][1] + ``` + + Is rewritten with: + ``` + %0 = tensor.extract_slice %t[1][5][1] + %1 = linalg.fill ... outs(%0) + %2 = tensor.insert_slice %1 into %t[1][5][1] + ``` + + The above example can bufferize without an allocation (in the absence of + other conflicts) because there is no longer a `tensor.empty` op. + + See `-eliminate-empty-tensors` for more details. + + #### Return modes + + This transform reads the target handle and modifies the payload. It does + not produce any handle. + }]; + + let arguments = (ins PDL_Operation:$target); + + let results = (outs); + + let assemblyFormat = "$target attr-dict"; +} + +//===----------------------------------------------------------------------===// +// EmptyTensorToAllocTensorOp +//===----------------------------------------------------------------------===// def EmptyTensorToAllocTensorOp : Op &effects) { + onlyReadsHandle(getTarget(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults, + TransformState &state) { + IRRewriter rewriter(getContext()); + OneShotBufferizationOptions options; + options.allowReturnAllocs = true; + + ArrayRef payloadOps = state.getPayloadOps(getTarget()); + for (Operation *target : payloadOps) { + OneShotAnalysisState state(target, options); + if (failed(analyzeOp(target, state))) + return mlir::emitSilenceableFailure(target->getLoc()) + << "failed to analyze op"; + if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep( + rewriter, target, state))) + return mlir::emitSilenceableFailure(target->getLoc()) + << "failed to eliminate insert_slice anchored tensor.empty ops"; + if 
(failed(bufferization::allocTensorAnchoredEmptyTensorEliminationStep( + rewriter, target, state))) + return mlir::emitSilenceableFailure(target->getLoc()) + << "failed to eliminate alloc_tensor anchored tensor.empty ops"; + } + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // EmptyTensorToAllocTensorOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir @@ -134,3 +134,24 @@ %0 = tensor.empty() : tensor<2x2xf32> return %0 : tensor<2x2xf32> } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.bufferization.eliminate_empty_tensors %0 +} + +// CHECK-LABEL: func @empty_tensor_elimination( +// CHECK: tensor.extract_slice +// CHECK: linalg.fill +// CHECK: tensor.insert_slice +func.func @empty_tensor_elimination( + %t: tensor<10xf32>, %f: f32) -> tensor<10xf32> { + %0 = tensor.empty() : tensor<5xf32> + %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> + %2 = tensor.insert_slice %1 into %t [1][5][1] + : tensor<5xf32> into tensor<10xf32> + return %2 : tensor<10xf32> +} diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir --- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir @@ -35,6 +35,35 @@ // ----- +// CHECK-LABEL: func @tensor_pad_wrapped_and_eliminated( +// CHECK: bufferization.alloc_tensor({{.*}}) init { +// CHECK-NEXT: 
^{{.*}}(%[[bbarg:.*]]: tensor): +// CHECK: %[[filled:.*]] = linalg.fill {{.*}} outs(%[[bbarg]] : tensor) +// CHECK: %[[inserted:.*]] = tensor.insert_slice %{{.*}} into %[[filled]] +// CHECK: bufferization.yield %[[inserted]] +func.func @tensor_pad_wrapped_and_eliminated( + %t: tensor, %l2: index, %h1: index, %h2: index) -> tensor { + %0 = tensor.pad %t low[5, %l2] high[%h1, %h2] { + ^bb0(%arg0: index, %arg1: index): + %c = arith.constant 50 : index + tensor.yield %c : index + } : tensor to tensor + return %0 : tensor +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %p0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %1 = transform.get_result %p0[0] : (!pdl.operation) -> !transform.any_value + %2 = transform.structured.bufferize_to_allocation %1 + + %p1 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation + %3 = transform.structured.rewrite_in_destination_passing_style %p1 : (!pdl.operation) -> !pdl.operation + transform.bufferization.eliminate_empty_tensors %arg1 +} + +// ----- + // CHECK-LABEL: func @materialization_of_bbarg( // CHECK-SAME: %[[t:.*]]: tensor // CHECK: %[[c0:.*]] = arith.constant 0 : index