diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -49,30 +49,32 @@
 } // end anonymous namespace
 
 static bool
-checkInvarianceOfNestedIfOps(Operation *op, Value indVar,
+checkInvarianceOfNestedIfOps(Operation *op, Value indVar, ValueRange iterArgs,
                              SmallPtrSetImpl<Operation *> &opsWithUsers,
                              SmallPtrSetImpl<Operation *> &opsToHoist);
-static bool isOpLoopInvariant(Operation &op, Value indVar,
+static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
                               SmallPtrSetImpl<Operation *> &opsWithUsers,
                               SmallPtrSetImpl<Operation *> &opsToHoist);
 
 static bool
 areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
+                                 ValueRange iterArgs,
                                  SmallPtrSetImpl<Operation *> &opsWithUsers,
                                  SmallPtrSetImpl<Operation *> &opsToHoist);
 
 // Returns true if the individual op is loop invariant.
-bool isOpLoopInvariant(Operation &op, Value indVar,
+bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
                        SmallPtrSetImpl<Operation *> &opsWithUsers,
                        SmallPtrSetImpl<Operation *> &opsToHoist) {
   LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;);
 
   if (isa<AffineIfOp>(op)) {
-    if (!checkInvarianceOfNestedIfOps(&op, indVar, opsWithUsers, opsToHoist)) {
+    if (!checkInvarianceOfNestedIfOps(&op, indVar, iterArgs, opsWithUsers,
+                                      opsToHoist)) {
       return false;
     }
   } else if (auto forOp = dyn_cast<AffineForOp>(op)) {
-    if (!areAllOpsInTheBlockListInvariant(forOp.getLoopBody(), indVar,
+    if (!areAllOpsInTheBlockListInvariant(forOp.getLoopBody(), indVar, iterArgs,
                                           opsWithUsers, opsToHoist)) {
       return false;
     }
@@ -129,6 +131,12 @@
       return false;
     }
 
+    // If one of the iter_args is the operand, this op isn't loop invariant.
+    if (llvm::is_contained(iterArgs, op.getOperand(i))) {
+      LLVM_DEBUG(llvm::dbgs() << "\nOne of the iter_args is the operand\n");
+      return false;
+    }
+
     if (operandSrc != nullptr) {
       LLVM_DEBUG(llvm::dbgs() << *operandSrc << "\nIterating on operand src\n");
 
@@ -148,12 +156,13 @@
 
 // Checks if all ops in a region (i.e. list of blocks) are loop invariant.
 bool areAllOpsInTheBlockListInvariant(
-    Region &blockList, Value indVar, SmallPtrSetImpl<Operation *> &opsWithUsers,
+    Region &blockList, Value indVar, ValueRange iterArgs,
+    SmallPtrSetImpl<Operation *> &opsWithUsers,
     SmallPtrSetImpl<Operation *> &opsToHoist) {
 
   for (auto &b : blockList) {
     for (auto &op : b) {
-      if (!isOpLoopInvariant(op, indVar, opsWithUsers, opsToHoist)) {
+      if (!isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist)) {
         return false;
       }
     }
@@ -164,18 +173,19 @@
 
 // Returns true if the affine.if op can be hoisted.
 bool checkInvarianceOfNestedIfOps(Operation *op, Value indVar,
+                                  ValueRange iterArgs,
                                   SmallPtrSetImpl<Operation *> &opsWithUsers,
                                   SmallPtrSetImpl<Operation *> &opsToHoist) {
   assert(isa<AffineIfOp>(op));
   auto ifOp = cast<AffineIfOp>(op);
 
-  if (!areAllOpsInTheBlockListInvariant(ifOp.thenRegion(), indVar, opsWithUsers,
-                                        opsToHoist)) {
+  if (!areAllOpsInTheBlockListInvariant(ifOp.thenRegion(), indVar, iterArgs,
+                                        opsWithUsers, opsToHoist)) {
     return false;
   }
 
-  if (!areAllOpsInTheBlockListInvariant(ifOp.elseRegion(), indVar, opsWithUsers,
-                                        opsToHoist)) {
+  if (!areAllOpsInTheBlockListInvariant(ifOp.elseRegion(), indVar, iterArgs,
+                                        opsWithUsers, opsToHoist)) {
     return false;
   }
 
@@ -185,6 +195,9 @@
 void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
   auto *loopBody = forOp.getBody();
   auto indVar = forOp.getInductionVar();
+  // Loop-carried values are the region iter_args (body block arguments), not
+  // the iter operands, which are init values defined above the loop.
+  ValueRange iterArgs = forOp.getRegionIterArgs();
 
   // This is the place where hoisted instructions would reside.
   OpBuilder b(forOp.getOperation());
@@ -200,7 +213,7 @@
     if (!op.use_empty())
       opsWithUsers.insert(&op);
     if (!isa<AffineYieldOp>(op)) {
-      if (isOpLoopInvariant(op, indVar, opsWithUsers, opsToHoist)) {
+      if (isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist)) {
         opsToMove.push_back(&op);
       }
     }
diff --git a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
--- a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
+++ b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
@@ -716,3 +716,39 @@
 // CHECK-NEXT:   }
 // CHECK-NEXT:   mulf
 // CHECK-NEXT:   affine.store
+
+// -----
+
+// CHECK-LABEL: func @use_of_iter_args_not_invariant
+func @use_of_iter_args_not_invariant(%m : memref<10xindex>) {
+  %sum_1 = constant 0 : index
+  %v0 = affine.for %arg1 = 0 to 11 iter_args (%prevAccum = %sum_1) -> index {
+    %newAccum = addi %prevAccum, %sum_1 : index
+    affine.yield %newAccum : index
+  }
+  return
+}
+
+// CHECK:       constant
+// CHECK-NEXT:  affine.for
+// CHECK-NEXT:    addi
+// CHECK-NEXT:    affine.yield
+
+// -----
+
+// CHECK-LABEL: func @use_of_iter_operands_invariant
+func @use_of_iter_operands_invariant(%m : memref<10xindex>) {
+  %sum_1 = constant 0 : index
+  %v0 = affine.for %arg1 = 0 to 11 iter_args (%prevAccum = %sum_1) -> index {
+    %prod = muli %sum_1, %sum_1 : index
+    %newAccum = addi %prevAccum, %prod : index
+    affine.yield %newAccum : index
+  }
+  return
+}
+
+// CHECK:       constant
+// CHECK-NEXT:  muli
+// CHECK-NEXT:  affine.for
+// CHECK-NEXT:    addi
+// CHECK-NEXT:    affine.yield