diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -17,6 +17,13 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" +def Transform_EmptyOp : Transform_ConcreteOpType<"tensor.empty">; +def Transform_AllocTensorOp : Transform_ConcreteOpType<"bufferization.alloc_tensor">; + +//===----------------------------------------------------------------------===// +// OneShotBufferizeOp +//===----------------------------------------------------------------------===// + def OneShotBufferizeOp : Op, + DeclareOpInterfaceMethods]> { + let description = [{ + Try to eliminate all `tensor.empty` ops within the targeted op by replacing + them with a destination tensor. -def Transform_EmptyOp : Transform_ConcreteOpType<"tensor.empty">; -def Transform_AllocTensorOp : Transform_ConcreteOpType<"bufferization.alloc_tensor">; + `tensor.empty` ops cannot be bufferized. They can either be converted to + `bufferization.alloc_tensor` or replaced with another tensor (via this + transform). `tensor.empty` does not specify the contents of the returned + tensor, so its result can be replaced with an arbitrary tensor value as long + as the dimensions match. + + This transform looks for `tensor.empty` ops where the SSA use-def chain of + the result ends in a supported "anchor op" (always following the aliasing + OpOperand/OpResult chain). Currently, the only supported anchor op is: + - `tensor.insert_slice` + + Example: + + ``` + %0 = tensor.empty() : tensor<5xf32> + %1 = linalg.fill ... 
outs(%0) + %2 = tensor.insert_slice %1 into %t[1][5][1] + ``` + + Is rewritten with: + ``` + %0 = tensor.extract_slice %t[1][5][1] + %1 = linalg.fill ... outs(%0) + %2 = tensor.insert_slice %1 into %t[1][5][1] + ``` + + The above example can bufferize without an allocation (in the absence of + other conflicts) because there is no longer a `tensor.empty` op. + + See `-eliminate-empty-tensors` for more details. + + #### Return modes + + This transform reads the target handle and modifies the payload. It does + not produce any handle. + + Payload ops are processed one at a time. If the analysis or the + elimination fails for a payload op, a silenceable failure is emitted and + the remaining payload ops are not processed; payload ops that were already + processed stay rewritten. + }]; + + let arguments = (ins PDL_Operation:$target); + + let results = (outs); + + let assemblyFormat = "$target attr-dict"; +} + +//===----------------------------------------------------------------------===// +// EmptyTensorToAllocTensorOp +//===----------------------------------------------------------------------===// def EmptyTensorToAllocTensorOp : Op &effects) { + onlyReadsHandle(getTarget(), effects); + modifiesPayload(effects); +} + +/// Runs One-Shot Bufferize analysis on each payload op associated with the +/// target handle, then eliminates the `tensor.empty` ops that are anchored on +/// a `tensor.insert_slice`. Emits a silenceable failure for the first payload +/// op on which either step fails; earlier payload ops remain rewritten. +DiagnosedSilenceableFailure +transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults, + TransformState &state) { + IRRewriter rewriter(getContext()); + OneShotBufferizationOptions options; + options.allowReturnAllocs = true; + + ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); + for (Operation *target : payloadOps) { + // Named `analysisState` (not `state`) so that it does not shadow the + // `TransformState &state` parameter used above. + OneShotAnalysisState analysisState(target, options); + if (failed(analyzeOp(target, analysisState))) + return mlir::emitSilenceableFailure(target->getLoc()) + << "failed to analyze op"; + if (failed(bufferization::insertSliceAnchoredEmptyTensorEliminationStep( + rewriter, target, analysisState))) + return mlir::emitSilenceableFailure(target->getLoc()) + << "failed to eliminate insert_slice anchored tensor.empty ops"; + } + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // EmptyTensorToAllocTensorOp //===----------------------------------------------------------------------===// diff --git 
a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir @@ -130,3 +130,27 @@ %0 = tensor.empty() : tensor<2x2xf32> return %0 : tensor<2x2xf32> } + +// ----- + +transform.sequence failures(propagate) { +^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!pdl.operation) -> !pdl.operation + transform.bufferization.eliminate_empty_tensors %0 +} + +// Verify that tensor.empty is replaced by tensor.extract_slice and that no tensor.empty is left. +// CHECK-LABEL: func @empty_tensor_elimination( +// CHECK-NOT: tensor.empty +// CHECK: tensor.extract_slice +// CHECK-NOT: tensor.empty +// CHECK: linalg.fill +// CHECK: tensor.insert_slice +func.func @empty_tensor_elimination( + %t: tensor<10xf32>, %f: f32) -> tensor<10xf32> { + %0 = tensor.empty() : tensor<5xf32> + %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32> + %2 = tensor.insert_slice %1 into %t [1][5][1] + : tensor<5xf32> into tensor<10xf32> + return %2 : tensor<10xf32> +}