diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -127,6 +127,35 @@ "`:` type($target) (`,` type($param)^)?"; } +def ApplyCommonSubexpressionEliminationOp : TransformDialectOp<"apply_cse", + [TransformOpInterface, TransformEachOpTrait, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + ReportTrackingListenerFailuresOpTrait]> { + let summary = "Eliminate common subexpressions in the body of the target op"; + let description = [{ + This transform applies common subexpression elimination (CSE) to the body + of the targeted op. + + This transform reads the target handle and modifies the payload. Existing + handles to operations inside of the targeted op are retained and updated if + necessary. Note that this can lead to situations where a handle that was + previously mapped to multiple distinct (but equivalent) operations is now + mapped to the same operation multiple times. 
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + let assemblyFormat = "`to` $target attr-dict `:` type($target)"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + def ApplyPatternsOp : TransformDialectOp<"apply_patterns", [TransformOpInterface, TransformEachOpTrait, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -14,12 +14,14 @@ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassRegistry.h" +#include "mlir/Transforms/CSE.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" @@ -257,6 +259,32 @@ modifiesPayload(effects); } +//===----------------------------------------------------------------------===// +// ApplyCommonSubexpressionEliminationOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::ApplyCommonSubexpressionEliminationOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, transform::TransformState &state) { + // Make sure that this transform is not applied to itself. Modifying the + // transform IR while it is being interpreted is generally dangerous. 
+ DiagnosedSilenceableFailure payloadCheck = + ensurePayloadIsSeparateFromTransform(*this, target); + if (!payloadCheck.succeeded()) + return payloadCheck; + + DominanceInfo domInfo; + mlir::eliminateCommonSubExpressions(rewriter, domInfo, target); + return DiagnosedSilenceableFailure::success(); +} + +void transform::ApplyCommonSubexpressionEliminationOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + transform::onlyReadsHandle(getTarget(), effects); + transform::modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // ApplyPatternsOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1727,3 +1727,66 @@ test_notify_payload_op_replaced %0, %1 : (!transform.any_op, !transform.any_op) -> () test_print_remark_at_operand %0, "updated handle" : !transform.any_op } + +// ----- + +// CHECK-LABEL: func @test_apply_cse() +// CHECK: %[[const:.*]] = arith.constant 0 : index +// CHECK: %[[ex1:.*]] = scf.execute_region -> index { +// CHECK: scf.yield %[[const]] +// CHECK: } +// CHECK: %[[ex2:.*]] = scf.execute_region -> index { +// CHECK: scf.yield %[[const]] +// CHECK: } +// CHECK: return %[[const]], %[[ex1]], %[[ex2]] +func.func @test_apply_cse() -> (index, index, index) { + // expected-remark @below{{eliminated 1}} + // expected-remark @below{{eliminated 2}} + %0 = arith.constant 0 : index + %1 = scf.execute_region -> index { + %2 = arith.constant 0 : index + scf.yield %2 : index + } {first} + %3 = scf.execute_region -> index { + %4 = arith.constant 0 : index + scf.yield %4 : index + } {second} + return %0, %1, %3 : index, index, index +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match 
ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %first = transform.structured.match attributes{first} in %0 : (!transform.any_op) -> !transform.any_op + %elim_first = transform.structured.match ops{["arith.constant"]} in %first : (!transform.any_op) -> !transform.any_op + %second = transform.structured.match attributes{second} in %0 : (!transform.any_op) -> !transform.any_op + %elim_second = transform.structured.match ops{["arith.constant"]} in %second : (!transform.any_op) -> !transform.any_op + + // There are 3 arith.constant ops. + %all = transform.structured.match ops{["arith.constant"]} in %0 : (!transform.any_op) -> !transform.any_op + // expected-remark @below{{3}} + test_print_number_of_associated_payload_ir_ops %all : !transform.any_op + // "deduplicate" has no effect because these are 3 different ops. + %merged_before = transform.merge_handles deduplicate %all : !transform.any_op + // expected-remark @below{{3}} + test_print_number_of_associated_payload_ir_ops %merged_before : !transform.any_op + + // Apply CSE: the constants in both regions are replaced by the dominating %0. + transform.apply_cse to %0 : !transform.any_op + + // The handle is still mapped to 3 arith.constant ops. + // expected-remark @below{{3}} + test_print_number_of_associated_payload_ir_ops %all : !transform.any_op + // But they are all the same op. + %merged_after = transform.merge_handles deduplicate %all : !transform.any_op + // expected-remark @below{{1}} + test_print_number_of_associated_payload_ir_ops %merged_after : !transform.any_op + + // The handles to the region-local constants were updated to point to %0. + test_print_remark_at_operand %elim_first, "eliminated 1" : !transform.any_op + // expected-remark @below{{1}} + test_print_number_of_associated_payload_ir_ops %elim_first : !transform.any_op + test_print_remark_at_operand %elim_second, "eliminated 2" : !transform.any_op + // expected-remark @below{{1}} + test_print_number_of_associated_payload_ir_ops %elim_second : !transform.any_op +}