diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -164,6 +164,74 @@ }]; } +//===----------------------------------------------------------------------===// +// EliminateLinalgOpAnchoredEmptyTensorsOp +//===----------------------------------------------------------------------===// + +def EliminateLinalgOpAnchoredEmptyTensorsOp + : Op, + DeclareOpInterfaceMethods]> { + let description = [{ + Try to eliminate all `tensor.empty` op uses that are anchored on a LinalgOp + within the targeted op. + + This op is similar to `bufferization.eliminate_empty_tensors`, but specific + to LinalgOps. + + `tensor.empty` ops cannot be bufferized. They can either be converted to + `bufferization.alloc_tensor` or replaced with another tensor (via this + transform). `tensor.empty` does not specify the contents of the returned + tensor so their results can be replaced with arbitrary tensor values as long + as the dimensions match. + + This transform looks for `tensor.empty` ops where the SSA use-def chain of + the result ends in a supported LinalgOp (always following the aliasing + OpOperand/OpResult chain). The following LinalgOps are supported: + - Only parallel iterator types. + - The use-def chain ends in an input operand of the LinalgOp. + - The LinalgOp has an unused output operand with the same shape and + indexing map. + + Example: + + ``` + %0 = tensor.empty() + %1 = linalg.matmul ins(...) outs(%0) + %2 = linalg.generic ins(%1) outs(%dest) { + ^bb0(%in: f32, %out: f32): + // out not used + } + ``` + + Is rewritten with: + ``` + %0 = tensor.empty() + %1 = linalg.matmul ins(...) 
outs(%dest) + %2 = linalg.generic ins(%0) outs(%1) { + ^bb0(%in: f32, %out: f32): + // Use %out instead of %in + } + ``` + + After this transformation, the "ins" operand has no uses inside the body of + the LinalgOp and can be folded away with existing cleanup patterns. + Afterwards, the tensor::EmptyOp can also fold away, so that the example can + bufferize without an allocation (in the absence of other conflicts). + + #### Return modes + + This transform reads the target handle and modifies the payload. It does + not produce any handle. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + + let results = (outs); + + let assemblyFormat = "$target attr-dict `:` type($target)"; +} + //===----------------------------------------------------------------------===// // FuseOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -27,6 +27,10 @@ #include "llvm/ADT/SmallSet.h" namespace mlir { +namespace bufferization { +class OneShotAnalysisState; +} // namespace bufferization + namespace linalg { class LinalgOp; @@ -38,6 +42,68 @@ /// Return vector::CombiningKind for the given op. std::optional getCombinerOpKind(Operation *combinerOp); +//===----------------------------------------------------------------------===// +// Bufferization-related transforms. +//===----------------------------------------------------------------------===// + +/// Materialize a buffer allocation for the given tensor.pad op and lower the +/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.: +/// +/// %0 = tensor.pad low[%l] high[%h] %t ... +/// +/// is lowered to: +/// +/// %alloc = memref.alloc +/// linalg.fill ... outs(%alloc) +/// %subview = memref.subview %alloc [%l] [...] 
[1] +/// memref.tensor_store %t, %subview +/// %0 = bufferization.to_tensor %alloc restrict writable +/// +/// In addition to rewriting the IR as shown above, the result of the +/// bufferization.to_tensor op is returned. +Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp, + Attribute memorySpace = {}); + +/// Materialize a buffer allocation for the given tensor value. E.g.: +/// +/// %alloc = memref.alloc +/// memref.tensor_store %value, %alloc +/// %0 = bufferization.to_tensor %alloc restrict writable +/// +/// In case `value` is a tensor.pad result, the corresponding overload is used +/// internally to produce a better bufferization. +Value bufferizeToAllocation(RewriterBase &rewriter, Value value, + Attribute memorySpace = {}); + +/// Try to eliminate tensor::EmptyOps inside `op` that are anchored on a +/// LinalgOp. This transform looks for LinalgOps that have an unused output +/// operand and an input operand that is rooted in a tensor::EmptyOp. The +/// tensor::EmptyOp uses are replaced with the output operand and the two +/// operands of the LinalgOp are swapped. +/// +/// Example: +/// %0 = tensor.empty() +/// %1 = linalg.matmul ins(...) outs(%0) +/// %2 = linalg.generic ins(%1) outs(%dest) { +/// ^bb0(%in: f32, %out: f32): +/// // out not used +/// } +/// +/// The IR is transformed as follows: +/// %0 = tensor.empty() +/// %1 = linalg.matmul ins(...) outs(%dest) +/// %2 = linalg.generic ins(%0) outs(%1) { +/// ^bb0(%in: f32, %out: f32): +/// // Use %out instead of %in +/// } +/// +/// The "ins" operand has no uses inside the body of the LinalgOp and can be +/// folded away with existing cleanup patterns. Afterwards, the tensor::EmptyOp +/// can also fold away.
+LogicalResult linalgOpAnchoredEmptyTensorEliminationStep( + RewriterBase &rewriter, Operation *op, + bufferization::OneShotAnalysisState &state); + //===----------------------------------------------------------------------===// // Structs that configure the behavior of various transformations. //===----------------------------------------------------------------------===// @@ -308,35 +374,6 @@ using LinalgLoops = SmallVector; -/// Materialize a buffer allocation for the given tensor.pad op and lower the -/// op to linalg.fill/linalg.generic + memref.tensor_store. E.g.: -/// -/// %0 = tensor.pad low[%l] high[%h] %t ... -/// -/// is lowered to: -/// -/// %alloc = memref.alloc -/// linalg.fill ... outs(%alloc) -/// %subview = memref.subview %alloc [%l] [...] [1] -/// memref.tensor_store %t, %subview -/// %0 = bufferization.to_tensor %alloc restrict writable -/// -/// In addition to rewriting the IR as shown above, the result of the -/// bufferization.to_tensor op is returned. -Value bufferizeToAllocation(RewriterBase &rewriter, tensor::PadOp padOp, - Attribute memorySpace = {}); - -/// Materialize a buffer allocation for the given tensor value. E.g.: -/// -/// %alloc = memref.alloc -/// memref.tensor_store %value, %alloc -/// %0 = bufferization.to_tensor %alloc restrict writable -/// -/// In case `value` is a tensor.pad result, the corresponding overload is used -/// internally to produce a better bufferization. -Value bufferizeToAllocation(RewriterBase &rewriter, Value value, - Attribute memorySpace = {}); - /// Fuse two `linalg.generic` operations that have a producer-consumer /// relationship captured through `fusedOperand`. The method expects /// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`. 
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -15,6 +15,7 @@ LINK_LIBS PUBLIC MLIRAffineDialect MLIRArithDialect + MLIRBufferizationTransforms MLIRFuncDialect MLIRIR MLIRLinalgDialect diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/TransformOps/Syntax.h" @@ -226,6 +227,38 @@ #undef DOWNSCALE return emitDefaultSilenceableFailure(target); } + +//===----------------------------------------------------------------------===// +// EliminateLinalgOpAnchoredEmptyTensorsOp +//===----------------------------------------------------------------------===// + +void transform::EliminateLinalgOpAnchoredEmptyTensorsOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getTarget(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +transform::EliminateLinalgOpAnchoredEmptyTensorsOp::apply( + transform::TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { + bufferization::OneShotBufferizationOptions options; + options.allowReturnAllocs = true; + + for (Operation *target : state.getPayloadOps(getTarget())) { + // Per-target analysis; queried by the elimination step below. + bufferization::OneShotAnalysisState analysisState(target, options); + if (failed(analyzeOp(target, analysisState))) + return mlir::emitSilenceableFailure(target->getLoc()) + << "failed to analyze op"; + if
(failed(linalg::linalgOpAnchoredEmptyTensorEliminationStep( + rewriter, target, analysisState))) + return mlir::emitSilenceableFailure(target->getLoc()) + << "failed to eliminate LinalgOp anchored tensor.empty ops"; + } + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // FuseOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -11,6 +11,7 @@ DropUnitDims.cpp ElementwiseOpFusion.cpp ElementwiseToLinalg.cpp + EliminateEmptyTensors.cpp EraseUnusedOperandsAndResults.cpp FusePadOpWithLinalgProducer.cpp Fusion.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/EliminateEmptyTensors.cpp @@ -0,0 +1,107 @@ +//===- EliminateEmptyTensors.cpp - tensor.empty op elimination ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Bufferization/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" + +using namespace mlir; +using namespace mlir::bufferization; +using namespace mlir::linalg; + +/// Get an output operand that matches the given input operand and can be used +/// to eliminate a tensor.empty op. +static OpOperand *getUnusedOutOperand(LinalgOp op, OpOperand *in) { + for (OpOperand *operand : op.getDpsInitOperands()) { + // Operand must be unused. + if (op.payloadUsesValueFromOperand(operand)) + continue; + // Types must match. + if (operand->get().getType() != in->get().getType()) + continue; + // Indexing maps must match. + if (op.getMatchingIndexingMap(operand) != op.getMatchingIndexingMap(in)) + continue; + return operand; + } + return nullptr; +} + +LogicalResult linalg::linalgOpAnchoredEmptyTensorEliminationStep( + RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { + OpBuilder::InsertionGuard g(rewriter); + DominanceInfo domInfo; + + op->walk([&](LinalgOp op) { + // Only ops with all "parallel" iterator types are supported. + if (op.getNumParallelLoops() != op.getNumLoops()) + return WalkResult::skip(); + + for (OpOperand *in : op.getDpsInputOperands()) { + // Skip non-tensor operands. + if (!in->get().getType().isa<TensorType>()) + continue; + + // Find tensor.empty ops on the reverse SSA use-def chain. Only follow + // equivalent tensors. I.e., stop when there are ops such as extract_slice + // on the path.
+ TraversalConfig config; + config.followEquivalentOnly = true; + config.alwaysIncludeLeaves = false; + SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( + in->get(), /*condition=*/ + [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, + config); + if (emptyTensors.empty()) + continue; + + // Find matching out operand. + OpOperand *out = getUnusedOutOperand(op, in); + if (!out) + continue; + + // Check if this transform would violate dominance. + if (!llvm::all_of(emptyTensors, [&](Value v) { + return domInfo.properlyDominates(out->get(), v.getDefiningOp()); + })) + continue; + + // Replace all uses of the tensor.empty, but do not delete it yet. It will + // fold away later (to not invalidate DominanceInfo). + for (Value v : emptyTensors) { + assert(v.getDefiningOp<tensor::EmptyOp>() && "expected tensor.empty"); + rewriter.replaceAllUsesWith(v, out->get()); + } + + // Turn the "in" into an "out". + rewriter.updateRootInPlace(op, [&]() { + out->set(in->get()); + // The original "in" could be removed entirely here (because it will no + // longer have any uses in the payload), but we delegate this to + // existing cleanup patterns that remove unused operands.
+ in->set(emptyTensors.front()); + BlockArgument outArg = op.getMatchingBlockArgument(out); + assert(outArg.getUses().empty() && "expected that out has no uses"); + BlockArgument inArg = op.getMatchingBlockArgument(in); + rewriter.replaceAllUsesWith(inArg, outArg); + assert(!op.payloadUsesValueFromOperand(in) && + "expected that the in operand is now unused"); + }); + + state.resetCache(); + } + + return WalkResult::advance(); + }); + return success(); +} diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-empty-tensor-elimination.mlir @@ -0,0 +1,42 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @eliminate_tensor_empty( +// CHECK-SAME: %[[arg0:.*]]: tensor<50x91xf32>, +// CHECK-NOT: tensor.empty +// CHECK: %[[filled:.*]] = linalg.fill {{.*}} outs(%[[arg0]] +// CHECK: %[[matmul:.*]] = linalg.matmul {{.*}} outs(%[[filled]] +// CHECK: %[[generic:.*]] = linalg.generic {{.*}} outs(%[[matmul]] +// CHECK: return %[[generic]] +func.func @eliminate_tensor_empty( + %arg0: tensor<50x91xf32>, %arg1: tensor<91xf32>, %arg2: tensor<50x1280xf32>, + %arg3: tensor<1280x91xf32>) -> tensor<50x91xf32> +{ + %cst = arith.constant 0.0 : f32 + %0 = tensor.empty() : tensor<50x91xf32> + %1 = linalg.fill ins(%cst : f32) + outs(%0 : tensor<50x91xf32>) -> tensor<50x91xf32> + %2 = linalg.matmul + ins(%arg2, %arg3 : tensor<50x1280xf32>, tensor<1280x91xf32>) + outs(%1 : tensor<50x91xf32>) -> tensor<50x91xf32> + %3 = linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg1, %2 : tensor<91xf32>, tensor<50x91xf32>) + outs(%arg0 : tensor<50x91xf32>) { + ^bb0(%in: f32,
%in_0: f32, %out: f32): + %16 = arith.addf %in, %in_0 : f32 + linalg.yield %16 : f32 + } -> tensor<50x91xf32> + return %3 : tensor<50x91xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.structured.eliminate_empty_tensors %0 : !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.linalg.erase_unnecessary_inputs + } : !transform.any_op +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9295,6 +9295,7 @@ ":Analysis", ":ArithDialect", ":AsmParser", + ":BufferizationTransforms", ":DialectUtils", ":FuncDialect", ":GPUDialect",