diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -1724,17 +1724,6 @@ mlir::Region &fir::IterWhileOp::getLoopBody() { return getRegion(); } -bool fir::IterWhileOp::isDefinedOutsideOfLoop(mlir::Value value) { - return !getRegion().isAncestor(value.getParentRegion()); -} - -mlir::LogicalResult -fir::IterWhileOp::moveOutOfLoop(llvm::ArrayRef ops) { - for (auto *op : ops) - op->moveBefore(*this); - return success(); -} - mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) { for (auto i : llvm::enumerate(getInitArgs())) if (iterArg == i.value()) @@ -2022,17 +2011,6 @@ mlir::Region &fir::DoLoopOp::getLoopBody() { return getRegion(); } -bool fir::DoLoopOp::isDefinedOutsideOfLoop(mlir::Value value) { - return !getRegion().isAncestor(value.getParentRegion()); -} - -mlir::LogicalResult -fir::DoLoopOp::moveOutOfLoop(llvm::ArrayRef ops) { - for (auto op : ops) - op->moveBefore(*this); - return success(); -} - /// Translate a value passed as an iter_arg to the corresponding block /// argument in the body of the loop. mlir::BlockArgument fir::DoLoopOp::iterArgToBlockArg(mlir::Value iterArg) { diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h @@ -28,7 +28,7 @@ namespace mlir { /// Move loop invariant code out of a `looplike` operation.
-LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike); +void moveLoopInvariantCode(LoopLikeOpInterface looplike); } // namespace mlir #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -30,7 +30,9 @@ explicit capture of dependencies, an implementation could check whether the value corresponds to a captured dependency. }], - "bool", "isDefinedOutsideOfLoop", (ins "::mlir::Value ":$value) + "bool", "isDefinedOutsideOfLoop", (ins "::mlir::Value ":$value), [{}], [{ + return value.getParentRegion()->isProperAncestor(&$_op.getLoopBody()); + }] >, InterfaceMethod<[{ Returns the region that makes up the body of the loop and should be @@ -39,10 +41,12 @@ "::mlir::Region &", "getLoopBody" >, InterfaceMethod<[{ - Moves the given vector of operations out of the loop. The vector is - sorted topologically. + Moves the given loop-invariant operation out of the loop.
}], - "::mlir::LogicalResult", "moveOutOfLoop", (ins "::mlir::ArrayRef<::mlir::Operation *>":$ops) + "void", "moveOutOfLoop", + (ins "::mlir::Operation *":$op), [{}], [{ + op->moveBefore($_op); + }] >, InterfaceMethod<[{ If there is a single induction variable return it, otherwise return diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1859,10 +1859,6 @@ Region &AffineForOp::getLoopBody() { return region(); } -bool AffineForOp::isDefinedOutsideOfLoop(Value value) { - return !region().isAncestor(value.getParentRegion()); -} - Optional AffineForOp::getSingleInductionVar() { return getInductionVar(); } @@ -1879,12 +1875,6 @@ return OpFoldResult(b.getI64IntegerAttr(getStep())); } -LogicalResult AffineForOp::moveOutOfLoop(ArrayRef ops) { - for (auto *op : ops) - op->moveBefore(*this); - return success(); -} - /// Returns true if the provided value is the induction variable of a /// AffineForOp. bool mlir::isForInductionVar(Value val) { @@ -3057,16 +3047,6 @@ Region &AffineParallelOp::getLoopBody() { return region(); } -bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) { - return !region().isAncestor(value.getParentRegion()); -} - -LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef ops) { - for (Operation *op : ops) - op->moveBefore(*this); - return success(); -} - unsigned AffineParallelOp::getNumDims() { return steps().size(); } AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -271,12 +271,11 @@ << "\nInvolving: " << tensorBBArg << "\n"); // If a read slice is present, hoist it.
- if (read.extractSliceOp && failed(forOp.moveOutOfLoop({read.extractSliceOp}))) - llvm_unreachable("Unexpected failure moving extract_slice out of loop"); + if (read.extractSliceOp) + forOp.moveOutOfLoop(read.extractSliceOp); // Hoist the transfer_read op. - if (failed(forOp.moveOutOfLoop({read.transferReadOp}))) - llvm_unreachable("Unexpected failure moving transfer read out of loop"); + forOp.moveOutOfLoop(read.transferReadOp); // TODO: don't hardcode /*numIvs=*/1. assert(tensorBBArg.getArgNumber() >= /*numIvs=*/1); @@ -396,11 +395,8 @@ changed = false; // First move loop invariant ops outside of their loop. This needs to be // done before as we cannot move ops without interputing the function walk. - func.walk([&](LoopLikeOpInterface loopLike) { - if (failed(moveLoopInvariantCode(loopLike))) - llvm_unreachable( - "Unexpected failure to move invariant code out of loop"); - }); + func.walk( + [&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); }); func.walk([&](vector::TransferReadOp transferRead) { if (!transferRead.getShapedType().isa()) @@ -484,9 +480,7 @@ } // Hoist read before. - if (failed(loop.moveOutOfLoop({transferRead}))) - llvm_unreachable( - "Unexpected failure to move transfer read out of loop"); + loop.moveOutOfLoop(transferRead); // Hoist write after.
transferWrite->moveAfter(loop); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -338,14 +338,9 @@ return signalPassFailure(); if (options.licm) { - if (funcOp - ->walk([&](LoopLikeOpInterface loopLike) { - if (failed(moveLoopInvariantCode(loopLike))) - return WalkResult::interrupt(); - return WalkResult::advance(); - }) - .wasInterrupted()) - return signalPassFailure(); + funcOp->walk([&](LoopLikeOpInterface loopLike) { + moveLoopInvariantCode(loopLike); + }); } // Gathers all innermost loops through a post order pruned walk. diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -449,16 +449,6 @@ Region &ForOp::getLoopBody() { return getRegion(); } -bool ForOp::isDefinedOutsideOfLoop(Value value) { - return !getRegion().isAncestor(value.getParentRegion()); -} - -LogicalResult ForOp::moveOutOfLoop(ArrayRef ops) { - for (auto *op : ops) - op->moveBefore(*this); - return success(); -} - ForOp mlir::scf::getForInductionVarOwner(Value val) { auto ivArg = val.dyn_cast(); if (!ivArg) @@ -2061,16 +2051,6 @@ Region &ParallelOp::getLoopBody() { return getRegion(); } -bool ParallelOp::isDefinedOutsideOfLoop(Value value) { - return !getRegion().isAncestor(value.getParentRegion()); -} - -LogicalResult ParallelOp::moveOutOfLoop(ArrayRef ops) { - for (auto *op : ops) - op->moveBefore(*this); - return success(); -} - ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) { auto ivArg = val.dyn_cast(); if (!ivArg) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -68,21 +68,6 @@ /// Returns the while loop body.
Region &tosa::WhileOp::getLoopBody() { return body(); } -bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) { - return !body().isAncestor(value.getParentRegion()); -} - -LogicalResult WhileOp::moveOutOfLoop(ArrayRef ops) { - if (ops.empty()) - return success(); - - Operation *tosaWhileOp = this->getOperation(); - for (auto *op : ops) - op->moveBefore(tosaWhileOp); - - return success(); -} - //===----------------------------------------------------------------------===// // Tosa dialect initialization. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp --- a/mlir/lib/Interfaces/LoopLikeInterface.cpp +++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp @@ -26,76 +26,72 @@ // LoopLike Utilities //===----------------------------------------------------------------------===// -// Checks whether the given op can be hoisted by checking that -// - the op and any of its contained operations do not depend on SSA values -// defined inside of the loop (by means of calling definedOutside). -// - the op has no side-effects. If sideEffecting is Never, sideeffects of this -// op and its nested ops are ignored. -static bool canBeHoisted(Operation *op, - function_ref definedOutside) { - // Check that dependencies are defined outside of loop. - if (!llvm::all_of(op->getOperands(), definedOutside)) - return false; - // Check whether this op is side-effect free. If we already know that there - // can be no side-effects because the surrounding op has claimed so, we can - // (and have to) skip this step. +/// Returns true if the given operation and all of its nested operations are +/// side-effect free. +/// +/// TODO: There is a duplicate function in ControlFlowSink. Move +/// `moveLoopInvariantCode` to TransformUtils and then factor out this function.
+static bool isSideEffectFree(Operation *op) { if (auto memInterface = dyn_cast(op)) { + // If the op has side-effects, it cannot be moved. if (!memInterface.hasNoEffect()) return false; - // If the operation doesn't have side effects and it doesn't recursively - // have side effects, it can always be hoisted. + // If the op does not have recursive side effects, then it can be moved. if (!op->hasTrait()) return true; - - // Otherwise, if the operation doesn't provide the memory effect interface - // and it doesn't have recursive side effects we treat it conservatively as - // side-effecting. } else if (!op->hasTrait()) { + // Otherwise, if the op does not implement the memory effect interface and + // it does not have recursive side effects, conservatively assume that it + // has side effects. return false; } - // Recurse into the regions for this op and check whether the contained ops - // can be hoisted. - for (auto &region : op->getRegions()) { - for (auto &block : region) { - for (auto &innerOp : block) - if (!canBeHoisted(&innerOp, definedOutside)) - return false; - } - } + // Recurse into the regions and ensure that all nested ops can also be moved. + for (Region &region : op->getRegions()) + for (Operation &nestedOp : region.getOps()) + if (!isSideEffectFree(&nestedOp)) + return false; return true; } -LogicalResult mlir::moveLoopInvariantCode(LoopLikeOpInterface looplike) { - auto &loopBody = looplike.getLoopBody(); - - // We use two collections here as we need to preserve the order for insertion - // and this is easiest. - SmallPtrSet willBeMovedSet; - SmallVector opsToMove; +/// Checks whether the given op can be hoisted by checking that +/// - the op and its nested operations only use values defined outside of the +/// loop (as determined by calling definedOutside) or inside the op itself; +/// - the op has no side-effects.
+static bool canBeHoisted(Operation *op, + function_ref definedOutside) { + if (!isSideEffectFree(op)) + return false; - // Helper to check whether an operation is loop invariant wrt. SSA properties. - auto isDefinedOutsideOfBody = [&](Value value) { - auto *definingOp = value.getDefiningOp(); - return (definingOp && !!willBeMovedSet.count(definingOp)) || - looplike.isDefinedOutsideOfLoop(value); + // Walk `op` and its nested operations and check that every used value is + // either defined outside of the loop or defined within `op` itself (e.g. in + // one of its regions), but never elsewhere inside the loop body. + auto walkFn = [&](Operation *child) { + for (Value operand : child->getOperands()) { + // Ignore values defined inside `op` (e.g. in one of its regions). + if (op->isAncestor(operand.getParentRegion()->getParentOp())) + continue; + if (!definedOutside(operand)) + return WalkResult::interrupt(); + } + return WalkResult::advance(); }; + return !op->walk(walkFn).wasInterrupted(); +} + +void mlir::moveLoopInvariantCode(LoopLikeOpInterface looplike) { + auto &loopBody = looplike.getLoopBody(); // Do not use walk here, as we do not want to go into nested regions and hoist // operations from there. These regions might have semantics unknown to this // rewriting. If the nested regions are loops, they will have been processed. - for (auto &block : loopBody) { - for (auto &op : block.without_terminator()) { - if (canBeHoisted(&op, isDefinedOutsideOfBody)) { - opsToMove.push_back(&op); - willBeMovedSet.insert(&op); - } + for (Block &block : loopBody) { + for (Operation &op : + llvm::make_early_inc_range(block.without_terminator())) { + if (canBeHoisted(&op, [&](Value value) { + return looplike.isDefinedOutsideOfLoop(value); + })) + looplike.moveOutOfLoop(&op); + } } - - // For all instructions that we found to be invariant, move outside of the - // loop.
- LogicalResult result = looplike.moveOutOfLoop(opsToMove); - LLVM_DEBUG(looplike.print(llvm::dbgs() << "\n\nModified loop:\n")); - return result; } diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -37,8 +37,7 @@ // the outer loop, which in turn can be further LICM'ed. getOperation()->walk([&](LoopLikeOpInterface loopLike) { LLVM_DEBUG(loopLike.print(llvm::dbgs() << "\nOriginal loop:\n")); - if (failed(moveLoopInvariantCode(loopLike))) - signalPassFailure(); + moveLoopInvariantCode(loopLike); }); } diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -157,10 +157,10 @@ affine.for %arg0 = 0 to 10 { affine.for %arg1 = 0 to 10 { affine.if affine_set<(d0, d1) : (d1 - d0 >= 0)> (%arg0, %arg0) { - %cf9 = arith.addf %cf8, %cf8 : f32 - affine.if affine_set<(d0, d1) : (d1 - d0 >= 0)> (%arg0, %arg0) { - %cf10 = arith.addf %cf9, %cf9 : f32 - } + %cf9 = arith.addf %cf8, %cf8 : f32 + affine.if affine_set<(d0, d1) : (d1 - d0 >= 0)> (%arg0, %arg0) { + %cf10 = arith.addf %cf9, %cf9 : f32 + } } } } @@ -168,6 +168,7 @@ // CHECK: memref.alloc // CHECK-NEXT: arith.constant // CHECK-NEXT: affine.for + // CHECK-NEXT: } // CHECK-NEXT: affine.for // CHECK-NEXT: affine.if // CHECK-NEXT: arith.addf @@ -175,7 +176,6 @@ // CHECK-NEXT: arith.addf // CHECK-NEXT: } // CHECK-NEXT: } - // CHECK-NEXT: } return @@ -319,6 +319,32 @@ scf.yield %val2: index } } - return + return } +// ----- + +// CHECK-LABEL: func @test_subgraph_move +func @test_subgraph_move() { + // CHECK: %[[C:.*]] = arith.constant + %0 = arith.constant 5 : i32 + // CHECK: %[[V0:.*]] = arith.addi %[[C]], %[[C]] + // CHECK-NEXT: %[[V1:.*]] = arith.addi %[[V0]], %[[C]] + //
CHECK-NEXT: test.loop + // CHECK-NEXT: ^bb0(%[[ARG0:.*]]: i32) + // CHECK-NEXT: %[[V2:.*]] = arith.subi %[[ARG0]], %[[ARG0]] + // CHECK-NEXT: test.region_yield %[[V2]] + // CHECK: test.loop + // CHECK-NEXT: test.region_yield %[[V1]] + test.loop { + %1 = arith.addi %0, %0 : i32 + %2 = arith.addi %1, %0 : i32 + test.loop { + ^bb0(%arg0: i32): + %3 = arith.subi %arg0, %arg0 : i32 + test.region_yield %3 : i32 + } : () -> () + test.region_yield %2 : i32 + } : () -> () + return +} diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -44,7 +44,7 @@ auto loop = fakeRead->getParentOfType(); OpBuilder b(loop); - (void)loop.moveOutOfLoop({fakeRead}); + loop.moveOutOfLoop(fakeRead); fakeWrite->moveAfter(loop); auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0), fakeCompute->getResult(0)); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1299,6 +1299,12 @@ /*printBlockTerminators=*/false); } +//===----------------------------------------------------------------------===// +// TestLoopOp +//===----------------------------------------------------------------------===// + +Region &TestLoopOp::getLoopBody() { return getBody(); } + #include "TestOpEnums.cpp.inc" #include "TestOpInterfaces.cpp.inc" #include "TestOpStructs.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -20,6 +20,7 @@ include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/DataLayoutInterfaces.td" include
"mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -2743,17 +2744,33 @@ def TestResource : Resource<"TestResource">; def TestEffectsOpA : TEST_Op<"op_with_effects_a"> { - let arguments = (ins - Arg, "", [MemRead]>, - Arg:$first, - Arg:$second, - Arg, "", [MemRead]>:$optional_symbol - ); + let arguments = (ins + Arg, "", [MemRead]>, + Arg:$first, + Arg:$second, + Arg, "", [MemRead]>:$optional_symbol + ); - let results = (outs Res]>); + let results = (outs Res]>); } def TestEffectsOpB : TEST_Op<"op_with_effects_b", [MemoryEffects<[MemWrite]>]>; +//===----------------------------------------------------------------------===// +// Test Loop Op +//===----------------------------------------------------------------------===// + +def TestLoopOp : TEST_Op<"loop", + [DeclareOpInterfaceMethods, + NoSideEffect, RecursiveSideEffects]> { + let arguments = (ins Variadic:$args); + let results = (outs Variadic:$rets); + let regions = (region SizedRegion<1>:$body); + + let assemblyFormat = [{ $args $body attr-dict `:` + functional-type(operands, results) }]; +} + + #endif // TEST_OPS