diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
+#include <queue>
 
 using namespace mlir;
 
@@ -26,75 +27,90 @@
 // LoopLike Utilities
 //===----------------------------------------------------------------------===//
 
-// Checks whether the given op can be hoisted by checking that
-// - the op and any of its contained operations do not depend on SSA values
-//   defined inside of the loop (by means of calling definedOutside).
-// - the op has no side-effects. If sideEffecting is Never, sideeffects of this
-//   op and its nested ops are ignored.
-static bool canBeHoisted(Operation *op,
-                         function_ref<bool(Value)> definedOutside) {
-  // Check that dependencies are defined outside of loop.
-  if (!llvm::all_of(op->getOperands(), definedOutside))
-    return false;
-  // Check whether this op is side-effect free. If we already know that there
-  // can be no side-effects because the surrounding op has claimed so, we can
-  // (and have to) skip this step.
+/// Returns true if the given operation is side-effect free as are all of its
+/// nested operations.
+///
+/// TODO: There is a duplicate function in ControlFlowSink. Move
+/// `moveLoopInvariantCode` to TransformUtils and then factor out this function.
+static bool isSideEffectFree(Operation *op) {
   if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
+    // If the op has side-effects, it cannot be moved.
     if (!memInterface.hasNoEffect())
       return false;
-    // If the operation doesn't have side effects and it doesn't recursively
-    // have side effects, it can always be hoisted.
+    // If the op does not have recursive side effects, then it can be moved.
     if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>())
       return true;
-
-    // Otherwise, if the operation doesn't provide the memory effect interface
-    // and it doesn't have recursive side effects we treat it conservatively as
-    // side-effecting.
   } else if (!op->hasTrait<OpTrait::HasRecursiveSideEffects>()) {
+    // Otherwise, if the op does not implement the memory effect interface and
+    // it does not have recursive side effects, then it cannot be known that the
+    // op is moveable.
     return false;
   }
 
-  // Recurse into the regions for this op and check whether the contained ops
-  // can be hoisted.
-  for (auto &region : op->getRegions()) {
-    for (auto &block : region) {
-      for (auto &innerOp : block)
-        if (!canBeHoisted(&innerOp, definedOutside))
-          return false;
-    }
-  }
+  // Recurse into the regions and ensure that all nested ops can also be moved.
+  for (Region &region : op->getRegions())
+    for (Operation &nestedOp : region.getOps())
+      if (!isSideEffectFree(&nestedOp))
+        return false;
   return true;
 }
 
+// Checks whether the given op can be hoisted by checking that
+// - the op and none of its contained operations depend on values inside of the
+//   loop (by means of calling definedOutside).
+// - the op has no side-effects.
+// - the op is not a terminator.
+static bool canBeHoisted(Operation *op,
+                         function_ref<bool(Value)> definedOutside) {
+  // Do not move terminators. Query the trait rather than comparing against
+  // Block::getTerminator(), which asserts when the block does not end in a
+  // terminator (e.g. a loop body in a region with the NoTerminator trait).
+  if (op->hasTrait<OpTrait::IsTerminator>())
+    return false;
+
+  if (!isSideEffectFree(op))
+    return false;
+
+  // Walk the nested operations and check that all used values are either
+  // defined outside of the loop or in a nested region, but not at the level of
+  // the loop body.
+  auto walkFn = [&](Operation *child) {
+    for (Value operand : child->getOperands()) {
+      // Ignore values defined in a nested region.
+      if (op->isAncestor(operand.getParentRegion()->getParentOp()))
+        continue;
+      if (!definedOutside(operand))
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  };
+  return !op->walk(walkFn).wasInterrupted();
+}
+
 void mlir::moveLoopInvariantCode(LoopLikeOpInterface looplike) {
-  auto &loopBody = looplike.getLoopBody();
-
-  // We use two collections here as we need to preserve the order for insertion
-  // and this is easiest.
-  SmallPtrSet<Operation *, 8> willBeMovedSet;
-  SmallVector<Operation *, 8> opsToMove;
-
-  // Helper to check whether an operation is loop invariant wrt. SSA properties.
-  auto isDefinedOutsideOfBody = [&](Value value) {
-    auto *definingOp = value.getDefiningOp();
-    return (definingOp && !!willBeMovedSet.count(definingOp)) ||
-           looplike.isDefinedOutsideOfLoop(value);
+  Region *loopBody = &looplike.getLoopBody();
+
+  std::queue<Operation *> worklist;
+  // Add top-level operations in the loop body to the worklist.
+  for (Operation &op : loopBody->getOps())
+    worklist.push(&op);
+
+  auto definedOutside = [&](Value value) {
+    return looplike.isDefinedOutsideOfLoop(value);
   };
 
-  // Do not use walk here, as we do not want to go into nested regions and hoist
-  // operations from there. These regions might have semantics unknown to this
-  // rewriting. If the nested regions are loops, they will have been processed.
-  for (auto &block : loopBody) {
-    for (auto &op : block.without_terminator()) {
-      if (canBeHoisted(&op, isDefinedOutsideOfBody)) {
-        opsToMove.push_back(&op);
-        willBeMovedSet.insert(&op);
-      }
-    }
-  }
+  while (!worklist.empty()) {
+    Operation *op = worklist.front();
+    worklist.pop();
+    // Skip ops that have already been moved. Check if the op can be hoisted.
+    if (op->getParentRegion() != loopBody || !canBeHoisted(op, definedOutside))
+      continue;
 
-  // For all instructions that we found to be invariant, move outside of the
-  // loop.
-  for (Operation *op : opsToMove)
     looplike.moveOutOfLoop(op);
+    // Since the op has been moved, we need to check its users within the
+    // top-level of the loop body.
+    for (Operation *user : op->getUsers())
+      if (user->getParentRegion() == loopBody)
+        worklist.push(user);
+  }
 }
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir --- a/mlir/test/Transforms/loop-invariant-code-motion.mlir +++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir @@ -157,10 +157,10 @@ affine.for %arg0 = 0 to 10 { affine.for %arg1 = 0 to 10 { affine.if affine_set<(d0, d1) : (d1 - d0 >= 0)> (%arg0, %arg0) { - %cf9 = arith.addf %cf8, %cf8 : f32 - affine.if affine_set<(d0, d1) : (d1 - d0 >= 0)> (%arg0, %arg0) { - %cf10 = arith.addf %cf9, %cf9 : f32 - } + %cf9 = arith.addf %cf8, %cf8 : f32 + affine.if affine_set<(d0, d1) : (d1 - d0 >= 0)> (%arg0, %arg0) { + %cf10 = arith.addf %cf9, %cf9 : f32 + } } } } @@ -168,6 +168,7 @@ // CHECK: memref.alloc // CHECK-NEXT: arith.constant // CHECK-NEXT: affine.for + // CHECK-NEXT: } // CHECK-NEXT: affine.for // CHECK-NEXT: affine.if // CHECK-NEXT: arith.addf @@ -175,7 +176,6 @@ // CHECK-NEXT: arith.addf // CHECK-NEXT: } // CHECK-NEXT: } - // CHECK-NEXT: } return @@ -319,6 +319,54 @@ scf.yield %val2: index } } - return + return } +// ----- + +// Test invariant nested loop is hoisted.
+// CHECK-LABEL: func @test_invariant_nested_loop +func @test_invariant_nested_loop() { + // CHECK: %[[C:.*]] = arith.constant + %0 = arith.constant 5 : i32 + // CHECK: %[[V0:.*]] = arith.addi %[[C]], %[[C]] + // CHECK-NEXT: %[[V1:.*]] = arith.addi %[[V0]], %[[C]] + // CHECK-NEXT: test.graph_loop + // CHECK-NEXT: ^bb0(%[[ARG0:.*]]: i32) + // CHECK-NEXT: %[[V2:.*]] = arith.subi %[[ARG0]], %[[ARG0]] + // CHECK-NEXT: test.region_yield %[[V2]] + // CHECK: test.graph_loop + // CHECK-NEXT: test.region_yield %[[V1]] + test.graph_loop { + %1 = arith.addi %0, %0 : i32 + %2 = arith.addi %1, %0 : i32 + test.graph_loop { + ^bb0(%arg0: i32): + %3 = arith.subi %arg0, %arg0 : i32 + test.region_yield %3 : i32 + } : () -> () + test.region_yield %2 : i32 + } : () -> () + return +} + + +// ----- + +// Test that invariant ops in a graph region are hoisted, including ops used before their definition. +// CHECK-LABEL: func @test_invariants_in_graph_region +func @test_invariants_in_graph_region() { + // CHECK: test.single_no_terminator_op + test.single_no_terminator_op : { + // CHECK-NEXT: %[[C:.*]] = arith.constant + // CHECK-NEXT: %[[V1:.*]] = arith.addi %[[C]], %[[C]] + // CHECK-NEXT: %[[V0:.*]] = arith.addi %[[C]], %[[V1]] + test.graph_loop { + %v0 = arith.addi %c0, %v1 : i32 + %v1 = arith.addi %c0, %c0 : i32 + %c0 = arith.constant 5 : i32 + test.region_yield %v0 : i32 + } : () -> () + } + return +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -10,7 +10,9 @@ #define TEST_OPS include "TestDialect.td" +include "TestInterfaces.td" include "mlir/Dialect/DLTI/DLTIBase.td" +include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" include "mlir/IR/OpAsmInterface.td" @@ -22,9 +24,8 @@ include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/DataLayoutInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include 
"mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" -include "TestInterfaces.td" // Include the attribute definitions. @@ -2748,14 +2749,14 @@ def TestResource : Resource<"TestResource">; def TestEffectsOpA : TEST_Op<"op_with_effects_a"> { - let arguments = (ins - Arg, "", [MemRead]>, - Arg:$first, - Arg:$second, - Arg, "", [MemRead]>:$optional_symbol - ); + let arguments = (ins + Arg, "", [MemRead]>, + Arg:$first, + Arg:$second, + Arg, "", [MemRead]>:$optional_symbol + ); - let results = (outs Res]>); + let results = (outs Res]>); } def TestEffectsOpB : TEST_Op<"op_with_effects_b", @@ -2769,4 +2770,25 @@ def TestEffectsWrite : TEST_Op<"op_with_memwrite", [MemoryEffects<[MemWrite]>]>; +//===----------------------------------------------------------------------===// +// Test Loop Op with a graph region +//===----------------------------------------------------------------------===// + +// Loop-like test op whose single-block body is a graph region; used to test LICM on graph regions. +def TestGraphLoopOp : TEST_Op<"graph_loop", + [LoopLikeOpInterface, NoSideEffect, + RecursiveSideEffects, SingleBlock, + RegionKindInterface, HasOnlyGraphRegion]> { + let arguments = (ins Variadic:$args); + let results = (outs Variadic:$rets); + let regions = (region SizedRegion<1>:$body); + + let assemblyFormat = [{ $args $body attr-dict `:` + functional-type(operands, results) }]; + + let extraClassDeclaration = [{ + mlir::Region &getLoopBody() { return getBody(); } + }]; +} + #endif // TEST_OPS