diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -159,6 +159,42 @@ "$target attr-dict `:` functional-type(operands, results)"; } +def MemRefEraseDeadAllocAndStoresOp + : Op, + ReportTrackingListenerFailuresOpTrait + ]> { + let description = [{ + This applies memory optimizations on memrefs. In particular, it performs + store-to-load forwarding, dead store elimination and dead alloc elimination. + + #### Return modes + + This operation applies a set of memory optimizations on the whole region of + the operand. + + The transformation does not consume the target handle. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = "$target attr-dict `:` functional-type($target, results)"; + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "Value":$target)> + ]; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + def MemRefMakeLoopIndependentOp : Op sizes); +// Track temporary allocations that are never read from. If this is the case +// it means both the allocations and associated stores can be removed. 
+void eraseDeadAllocAndStores(Operation *parentOp); + } // namespace memref } // namespace mlir diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -14,11 +14,13 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Transforms.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/Debug.h" @@ -132,6 +134,32 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// MemRefEraseDeadAllocAndStoresOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::MemRefEraseDeadAllocAndStoresOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + // Apply store to load forwarding and dead store elimination. 
+ vector::transferOpflowOpt(rewriter, target); + memref::eraseDeadAllocAndStores(target); + return DiagnosedSilenceableFailure::success(); +} + +void transform::MemRefEraseDeadAllocAndStoresOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} +void transform::MemRefEraseDeadAllocAndStoresOp::build(OpBuilder &builder, + OperationState &result, + Value target) { + result.addOperands(target); +} + //===----------------------------------------------------------------------===// // MemRefMakeLoopIndependentOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" namespace mlir { namespace memref { @@ -120,5 +121,37 @@ return linearizedMemRefInfo; } +/// Returns true if all the uses of op are either Store/transfer_write. +/// There can be SubviewOp users as long as all its users are also +/// StoreOp/transfer_write. If it returns true, it also fills out `uses`; if it +/// returns false, `uses` is unchanged. 
+static bool allUsesAreStores(Operation *op, std::vector &uses) { + std::vector opUses; + for (OpOperand &use : op->getUses()) { + Operation *useOp = use.getOwner(); + if (isa( + useOp) || + (isa(useOp) && allUsesAreStores(useOp, opUses))) { + opUses.push_back(useOp); + continue; + } + return false; + } + uses.insert(uses.end(), opUses.begin(), opUses.end()); + return true; +} + +void eraseDeadAllocAndStores(Operation *parentOp) { + std::vector opToErase; + parentOp->walk([&](memref::AllocOp op) { + if (allUsesAreStores(op, opToErase)) { + opToErase.push_back(op.getOperation()); + } + }); + for (Operation *op : opToErase) { + op->erase(); + } +} + } // namespace memref } // namespace mlir diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -259,6 +259,50 @@ // ----- +// CHECK-LABEL: func.func @dead_alloc +func.func @dead_alloc() { + // CHECK-NOT: %{{.+}} = memref.alloc + %0 = memref.alloc() : memref<8x64xf32, 3> + %1 = memref.subview %0[0, 0] [8, 4] [1, 1] : memref<8x64xf32, 3> to + memref<8x4xf32, affine_map<(d0, d1) -> (d0 * 64 + d1)>, 3> + %c0 = arith.constant 0 : index + %cst_0 = arith.constant dense<0.000000e+00> : vector<1x4xf32> + vector.transfer_write %cst_0, %1[%c0, %c0] {in_bounds = [true, true]} : + vector<1x4xf32>, memref<8x4xf32, affine_map<(d0, d1) -> (d0 * 64 + d1)>, 3> + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () +} + +// ----- + +// CHECK-LABEL: @store_to_load +// CHECK-SAME: (%[[ARG:.+]]: vector<4xf32>) +// CHECK-NOT: memref.alloc() +// CHECK-NOT: vector.transfer_write +// CHECK-NOT: vector.transfer_read +// CHECK: return %[[ARG]] : vector<4xf32> +func.func 
@store_to_load(%arg: vector<4xf32>) -> vector<4xf32> { + %c0 = arith.constant 0 : index + %cst_1 = arith.constant 0.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<64xf32> + vector.transfer_write %arg, %alloc[%c0] {in_bounds = [true]} : vector<4xf32>, memref<64xf32> + %r = vector.transfer_read %alloc[%c0], %cst_1 {in_bounds = [true]} : memref<64xf32>, vector<4xf32> + return %r : vector<4xf32> +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.memref.erase_dead_alloc_and_stores %0 : (!transform.any_op) -> () +} + +// ----- + // CHECK-LABEL: func @lower_to_llvm // CHECK-NOT: memref.alloc // CHECK: llvm.call @malloc diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -11491,6 +11491,7 @@ ":AffineDialect", ":ArithUtils", ":MemRefDialect", + ":VectorDialect", ], ) @@ -11594,11 +11595,13 @@ ":MemRefDialect", ":MemRefTransformOpsIncGen", ":MemRefTransforms", + ":MemRefUtils", ":NVGPUDialect", ":SCFDialect", ":TransformDialect", ":TransformUtils", ":VectorDialect", + ":VectorTransforms", "//llvm:Support", ], )