diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h
@@ -23,6 +23,7 @@
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 
 namespace mlir {
 namespace transform {
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -223,6 +223,35 @@
   let assemblyFormat = "attr-dict";
 }
 
+def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm",
+    [TransformOpInterface, TransformEachOpTrait,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let summary = "Move loop-invariant code out of a loop-like op";
+  let description = [{
+    This transform moves side-effect free, loop invariant code out of the
+    targeted loop-like op. The targeted op must implement the
+    `LoopLikeOpInterface`.
+
+    Note: To move invariant ops from a loop nest, this transform must be applied
+    to each loop of the loop nest, starting with the inner-most loop.
+
+    This transform reads the target handle and modifies the payload.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs);
+  let assemblyFormat = "`to` $target attr-dict `:` type($target)";
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::LoopLikeOpInterface target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
     [TransformOpInterface, TransformEachOpTrait,
      FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt
@@ -14,6 +14,7 @@
   LINK_LIBS PUBLIC
   MLIRCastInterfaces
   MLIRIR
+  MLIRLoopLikeInterface
   MLIRParser
   MLIRPass
   MLIRRewrite
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Pass/PassRegistry.h"
 #include "mlir/Transforms/CSE.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -391,6 +392,28 @@
     op.getCanonicalizationPatterns(patterns, ctx);
 }
 
+//===----------------------------------------------------------------------===//
+// ApplyLoopInvariantCodeMotionOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::ApplyLoopInvariantCodeMotionOp::applyToOne(
+    transform::TransformRewriter &rewriter, LoopLikeOpInterface target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  // Currently, LICM does not remove operations, so we don't need tracking.
+  // If this ever changes, add a LICM entry point that takes a rewriter.
+  moveLoopInvariantCode(target);
+  return DiagnosedSilenceableFailure::success();
+}
+
+// The target handle is only read (not consumed), and the payload IR is modified.
+void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::onlyReadsHandle(getTarget(), effects);
+  transform::modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // ApplyRegisteredPassOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -1790,3 +1790,39 @@
   // expected-remark @below{{1}}
   test_print_number_of_associated_payload_ir_ops %elim_second : !transform.any_op
 }
+
+// -----
+
+// CHECK-LABEL: func @test_licm(
+//       CHECK:   arith.muli
+//       CHECK:   scf.for {{.*}} {
+//       CHECK:     vector.print
+//       CHECK:   }
+func.func @test_licm(%arg0: index, %arg1: index, %arg2: index) {
+  scf.for %iv = %arg0 to %arg1 step %arg2 {
+    %0 = arith.muli %arg0, %arg1 : index
+    vector.print %0 : index
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.for"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  transform.apply_licm to %0 : !transform.any_op
+}
+
+// -----
+
+// expected-note @below{{when applied to this op}}
+module {
+  func.func @test_licm_invalid() {
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    // expected-error @below{{transform applied to the wrong op kind}}
+    transform.apply_licm to %arg1 : !transform.any_op
+  }
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10237,6 +10237,7 @@
         ":CastInterfaces",
         ":ControlFlowInterfaces",
         ":IR",
+        ":LoopLikeInterface",
         ":Pass",
         ":Rewrite",
         ":SideEffectInterfaces",