Review notes on this patch (affine loop-invariant code motion).
Summary: the per-call definedOps set, which was filled while ops were being
visited, is replaced by a pass member, opsWithUsers. runOnFunction() clears
and refills it for every affine.for with the ops in that loop whose results
have users, and the helpers now take it by const reference. An op whose
operand is produced by a member of that set that is not in opsToHoist is
treated as loop dependent. Two regression tests are added:
affine_if_not_invariant and affine_for_not_invariant. They check that ops
using the results of an affine.if and of a nested affine.for stay in the loop.
NOTE(review): this copy appears to have lost its template arguments (for
example "SmallPtrSet opsWithUsers;" and "isa(op)") as well as its line breaks
and blank lines, so it will not apply or compile as shown. Regenerate it from
the source tree before use.
NOTE(review): the @@ -82 hunk removes "definedOps.insert(&op);" but keeps the
context comment above it ("Register op in the set of ops defined inside the
loop..."). That comment now describes a set that no longer exists; consider
deleting or rewording it.
NOTE(review): in "op->getNumResults() > 0 && !op->use_empty()" the first test
looks redundant, because an op without results has no uses. Confirm this and
simplify the condition to "!op->use_empty()".
NOTE(review): operandSrc is presumably the operand's defining op. If so, a
block-argument operand such as an iter_args value has no defining op. An op at
the top level of the loop body that uses an iter_arg directly would then not
be caught by the opsWithUsers check. Confirm that this case is rejected
elsewhere, and consider adding a test for it.
NOTE(review): the .mlir test file still ends without a trailing newline.
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -45,30 +45,31 @@ : public AffineLoopInvariantCodeMotionBase { void runOnFunction() override; void runOnAffineForOp(AffineForOp forOp); + SmallPtrSet opsWithUsers; }; } // end anonymous namespace static bool checkInvarianceOfNestedIfOps(Operation *op, Value indVar, - SmallPtrSetImpl &definedOps, + const SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist); static bool isOpLoopInvariant(Operation &op, Value indVar, - SmallPtrSetImpl &definedOps, + const SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist); -static bool -areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar, - SmallPtrSetImpl &definedOps, - SmallPtrSetImpl &opsToHoist); +static bool areAllOpsInTheBlockListInvariant( + Region &blockList, Value indVar, + const SmallPtrSetImpl &opsWithUsers, + SmallPtrSetImpl &opsToHoist); // Returns true if the individual op is loop invariant. bool isOpLoopInvariant(Operation &op, Value indVar, - SmallPtrSetImpl &definedOps, + const SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist) { LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;); if (isa(op)) { - if (!checkInvarianceOfNestedIfOps(&op, indVar, definedOps, opsToHoist)) { + if (!checkInvarianceOfNestedIfOps(&op, indVar, opsWithUsers, opsToHoist)) { return false; } } else if (isa(op)) { @@ -82,7 +83,6 @@ // Register op in the set of ops defined inside the loop. This set is used // to prevent hoisting ops that depend on other ops defined inside the loop // which are themselves not being hoisted.
- definedOps.insert(&op); if (isa(op)) { Value memref = isa(op) @@ -135,7 +135,8 @@ // If the value was defined in the loop (outside of the // if/else region), and that operation itself wasn't meant to // be hoisted, then mark this operation loop dependent. - if (definedOps.count(operandSrc) && opsToHoist.count(operandSrc) == 0) { + if (opsWithUsers.count(operandSrc) && + opsToHoist.count(operandSrc) == 0) { return false; } } @@ -149,12 +150,13 @@ // Checks if all ops in a region (i.e. list of blocks) are loop invariant. bool areAllOpsInTheBlockListInvariant( - Region &blockList, Value indVar, SmallPtrSetImpl &definedOps, + Region &blockList, Value indVar, + const SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist) { for (auto &b : blockList) { for (auto &op : b) { - if (!isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) { + if (!isOpLoopInvariant(op, indVar, opsWithUsers, opsToHoist)) { return false; } } @@ -164,18 +166,19 @@ } // Returns true if the affine.if op can be hoisted. -bool checkInvarianceOfNestedIfOps(Operation *op, Value indVar, - SmallPtrSetImpl &definedOps, - SmallPtrSetImpl &opsToHoist) { +bool checkInvarianceOfNestedIfOps( + Operation *op, Value indVar, + const SmallPtrSetImpl &opsWithUsers, + SmallPtrSetImpl &opsToHoist) { assert(isa(op)); auto ifOp = cast(op); - if (!areAllOpsInTheBlockListInvariant(ifOp.thenRegion(), indVar, definedOps, + if (!areAllOpsInTheBlockListInvariant(ifOp.thenRegion(), indVar, opsWithUsers, opsToHoist)) { return false; } - if (!areAllOpsInTheBlockListInvariant(ifOp.elseRegion(), indVar, definedOps, + if (!areAllOpsInTheBlockListInvariant(ifOp.elseRegion(), indVar, opsWithUsers, opsToHoist)) { return false; } @@ -187,7 +190,6 @@ auto *loopBody = forOp.getBody(); auto indVar = forOp.getInductionVar(); - SmallPtrSet definedOps; // This is the place where hoisted instructions would reside. OpBuilder b(forOp.getOperation()); @@ -198,7 +200,7 @@ // We don't hoist for loops.
if (!isa(op)) { if (!isa(op)) { - if (isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) { + if (isOpLoopInvariant(op, indVar, opsWithUsers, opsToHoist)) { opsToMove.push_back(&op); } } @@ -215,12 +217,21 @@ } void LoopInvariantCodeMotion::runOnFunction() { - // Walk through all loops in a function in innermost-loop-first order. This - // way, we first LICM from the inner loop, and place the ops in - // the outer loop, which in turn can be further LICM'ed. - getFunction().walk([&](AffineForOp op) { - LLVM_DEBUG(op->print(llvm::dbgs() << "\nOriginal loop\n")); - runOnAffineForOp(op); + + // Walk through all loops in a function in innermost-loop-first order. This + // way, we first LICM from the inner loop, and place the ops in + // the outer loop, which in turn can be further LICM'ed. + getFunction().walk([&](AffineForOp forOp) { + LLVM_DEBUG(forOp->print(llvm::dbgs() << "\nOriginal loop\n")); + opsWithUsers.clear(); + // Collect the ops within forOp whose results have users. An op using + // such a result stays in the loop unless its producer is hoisted too. + forOp->walk([&](Operation *op) { + if (op->getNumResults() > 0 && !op->use_empty()) { + opsWithUsers.insert(op); + } + });
+ runOnAffineForOp(forOp); }); } diff --git a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir --- a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir +++ b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir @@ -613,3 +613,70 @@ // CHECK-NEXT: addf // CHECK-NEXT: affine.vector_store // CHECK-NEXT: affine.for + +// ----- + +#set = affine_set<(d0): (d0 - 10 >= 0)> +// CHECK-LABEL: func @affine_if_not_invariant( +func @affine_if_not_invariant(%buffer: memref<1024xf32>) -> f32 { + %sum_init_0 = constant 0.0 : f32 + %sum_init_1 = constant 1.0 : f32 + %res = affine.for %i = 0 to 10 step 2 iter_args(%sum_iter = %sum_init_0) -> f32 { + %t = affine.load %buffer[%i] : memref<1024xf32> + %sum_next = affine.if #set(%i) -> (f32) { + %new_sum = addf %sum_iter, %t : f32 + affine.yield %new_sum : f32 + } else { + affine.yield %sum_iter : f32 + } + %modified_sum = addf %sum_next, %sum_init_1 : f32 + affine.yield %modified_sum : f32 + } + return %res : f32 +} + +// CHECK: constant 0.000000e+00 : f32 +// CHECK-NEXT: constant 1.000000e+00 : f32 +// CHECK-NEXT: affine.for +// CHECK-NEXT: affine.load +// CHECK-NEXT: affine.if +// CHECK-NEXT: addf +// CHECK-NEXT: affine.yield +// CHECK-NEXT: } else { +// CHECK-NEXT: affine.yield +// CHECK-NEXT: } +// CHECK-NEXT: addf +// CHECK-NEXT: affine.yield +// CHECK-NEXT: } + +// ----- + +// CHECK-LABEL: func @affine_for_not_invariant( +func @affine_for_not_invariant(%in : memref<30x512xf32, 1>, + %out : memref<30x1xf32, 1>) { + %sum_0 = constant 0.0 : f32 + %cst_0 = constant 1.1 : f32 + affine.for %j = 0 to 30 { + %sum = affine.for %i = 0 to 512 iter_args(%sum_iter = %sum_0) -> (f32) { + %t = affine.load %in[%j,%i] : memref<30x512xf32,1> + %sum_next = addf %sum_iter, %t : f32 + affine.yield %sum_next : f32 + } + %mod_sum = mulf %sum, %cst_0 : f32 + affine.store %mod_sum, %out[%j, 0] : memref<30x1xf32, 1> + } + return +} + +// CHECK: constant
0.000000e+00 : f32 +// CHECK-NEXT: constant 1.100000e+00 : f32 +// CHECK-NEXT: affine.for +// CHECK-NEXT: affine.for +// CHECK-NEXT: affine.load +// CHECK-NEXT: addf +// CHECK-NEXT: affine.yield +// CHECK-NEXT: } +// CHECK-NEXT: mulf +// CHECK-NEXT: affine.store + +// ----- \ No newline at end of file