diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -156,6 +156,37 @@
   }];
 }
 
+def ApplyDeadCodeEliminationOp : TransformDialectOp<"apply_dce",
+    [TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Eliminate dead operations in the body of the target op";
+  let description = [{
+    This transform applies dead code elimination (DCE) to the body of the
+    targeted op. Only ops nested inside the target are considered; the target
+    op itself is never erased.
+
+    Note: "transform.apply_patterns" with an empty region can also be used to
+    remove dead ops. However, that op applies additional simplifications such as
+    op folding and region simplification.
+
+    This transform reads the target handle and modifies the payload. Note that
+    this transform may silently remove payload ops from handles.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = "`to` $target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def ApplyPatternsOp : TransformDialectOp<"apply_patterns",
     [TransformOpInterface, TransformEachOpTrait,
      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -287,6 +287,71 @@
   transform::modifiesPayload(effects);
 }
 
+//===----------------------------------------------------------------------===//
+// ApplyDeadCodeEliminationOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ApplyDeadCodeEliminationOp::applyToOne(
+    transform::TransformRewriter &rewriter, Operation *target,
+    ApplyToEachResultList &results, transform::TransformState &state) {
+  // Make sure that this transform is not applied to itself. Modifying the
+  // transform IR while it is being interpreted is generally dangerous.
+  DiagnosedSilenceableFailure payloadCheck =
+      ensurePayloadIsSeparateFromTransform(*this, target);
+  if (!payloadCheck.succeeded())
+    return payloadCheck;
+
+  // Maintain a worklist of potentially dead ops.
+  SetVector<Operation *> worklist;
+
+  // Helper function that adds all defining ops of used values (operands and
+  // operands of nested ops) that are properly nested inside the target.
+  auto addDefiningOpsToWorklist = [&](Operation *op) {
+    op->walk([&](Operation *nestedOp) {
+      for (Value v : nestedOp->getOperands())
+        if (Operation *defOp = v.getDefiningOp())
+          if (target->isProperAncestor(defOp))
+            worklist.insert(defOp);
+    });
+  };
+
+  // Helper function that erases an op. The op and all of its nested ops are
+  // removed from the worklist first, so that no dangling pointers remain.
+  // `SetVector::remove` is a no-op for ops that are not enqueued and consults
+  // the underlying set before scanning the vector (unlike `llvm::find`).
+  auto eraseOp = [&](Operation *op) {
+    op->walk([&](Operation *nestedOp) { worklist.remove(nestedOp); });
+    rewriter.eraseOp(op);
+  };
+
+  // Initial walk over the IR. Erasing the op that is currently being visited
+  // is safe in a post-order walk; the order is spelled out explicitly because
+  // the erasure below relies on it.
+  target->walk<WalkOrder::PostOrder>([&](Operation *op) {
+    if (op != target && isOpTriviallyDead(op)) {
+      addDefiningOpsToWorklist(op);
+      eraseOp(op);
+    }
+  });
+
+  // Erase all ops that have become dead.
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    if (!isOpTriviallyDead(op))
+      continue;
+    addDefiningOpsToWorklist(op);
+    eraseOp(op);
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::ApplyDeadCodeEliminationOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // ApplyPatternsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1974,3 +1974,26 @@
   %bar = transform.select "test.bar" in %0 : (!transform.any_op) -> !transform.any_op
   test_print_remark_at_operand %bar, "found bar" : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @apply_dce(
+// CHECK-NEXT: memref.store
+// CHECK-NEXT: return
+func.func @apply_dce(%f: f32, %m: memref<5xf32>, %idx: index) {
+  // Two dead ops, interleaved with a non-dead op.
+  %0 = tensor.empty() : tensor<5xf32>
+  memref.store %f, %m[%idx] : memref<5xf32>
+  %1 = tensor.insert %f into %0[%idx] : tensor<5xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  %func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+  %empty_op = transform.structured.match ops{["tensor.empty"]} in %func_op : (!transform.any_op) -> !transform.any_op
+  transform.apply_dce to %func_op : !transform.any_op
+
+  // expected-remark @below{{0}}
+  test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
+}