diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -57,7 +57,7 @@ } // namespace static bool -checkInvarianceOfNestedIfOps(Operation *op, Value indVar, ValueRange iterArgs, +checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs, SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist); static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs, @@ -76,16 +76,14 @@ SmallPtrSetImpl &opsToHoist) { LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;); - if (isa(op)) { - if (!checkInvarianceOfNestedIfOps(&op, indVar, iterArgs, opsWithUsers, - opsToHoist)) { + if (auto ifOp = dyn_cast(op)) { + if (!checkInvarianceOfNestedIfOps(ifOp, indVar, iterArgs, opsWithUsers, + opsToHoist)) return false; - } } else if (auto forOp = dyn_cast(op)) { if (!areAllOpsInTheBlockListInvariant(forOp.getLoopBody(), indVar, iterArgs, - opsWithUsers, opsToHoist)) { + opsWithUsers, opsToHoist)) return false; - } } else if (isa(op)) { // TODO: Support DMA ops. return false; @@ -93,15 +91,14 @@ // Register op in the set of ops that have users. opsWithUsers.insert(&op); if (isa(op)) { - Value memref = isa(op) - ? cast(op).getMemRef() - : cast(op).getMemRef(); + auto read = dyn_cast(op); + Value memref = read ? read.getMemRef() + : cast(op).getMemRef(); for (auto *user : memref.getUsers()) { // If this memref has a user that is a DMA, give up because these // operations write to this memref. - if (isa(op)) { + if (isa(op)) return false; - } // If the memref used by the load/store is used in a store elsewhere in // the loop nest, we do not hoist. Similarly, if the memref used in a // load is also being stored too, we do not hoist the load. @@ -112,16 +109,15 @@ SmallVector userIVs; getLoopIVs(*user, &userIVs); // Check that userIVs don't contain the for loop around the op. - if (llvm::is_contained(userIVs, getForInductionVarOwner(indVar))) { + if (llvm::is_contained(userIVs, getForInductionVarOwner(indVar))) return false; - } } } } } if (op.getNumOperands() == 0 && !isa(op)) { - LLVM_DEBUG(llvm::dbgs() << "\nNon-constant op with 0 operands\n"); + LLVM_DEBUG(llvm::dbgs() << "Non-constant op with 0 operands\n"); return false; } } @@ -131,29 +127,28 @@ auto *operandSrc = op.getOperand(i).getDefiningOp(); LLVM_DEBUG( - op.getOperand(i).print(llvm::dbgs() << "\nIterating on operand\n")); + op.getOperand(i).print(llvm::dbgs() << "Iterating on operand\n")); // If the loop IV is the operand, this op isn't loop invariant. if (indVar == op.getOperand(i)) { - LLVM_DEBUG(llvm::dbgs() << "\nLoop IV is the operand\n"); + LLVM_DEBUG(llvm::dbgs() << "Loop IV is the operand\n"); return false; } // If the one of the iter_args is the operand, this op isn't loop invariant. if (llvm::is_contained(iterArgs, op.getOperand(i))) { - LLVM_DEBUG(llvm::dbgs() << "\nOne of the iter_args is the operand\n"); + LLVM_DEBUG(llvm::dbgs() << "One of the iter_args is the operand\n"); return false; } - if (operandSrc != nullptr) { - LLVM_DEBUG(llvm::dbgs() << *operandSrc << "\nIterating on operand src\n"); + if (operandSrc) { + LLVM_DEBUG(llvm::dbgs() << *operandSrc << "Iterating on operand src\n"); - // If the value was defined in the loop (outside of the - // if/else region), and that operation itself wasn't meant to - // be hoisted, then mark this operation loop dependent. - if (opsWithUsers.count(operandSrc) && opsToHoist.count(operandSrc) == 0) { + // If the value was defined in the loop (outside of the if/else region), + // and that operation itself wasn't meant to be hoisted, then mark this + // operation loop dependent. + if (opsWithUsers.count(operandSrc) && opsToHoist.count(operandSrc) == 0) return false; - } } } @@ -170,9 +165,8 @@ for (auto &b : blockList) { for (auto &op : b) { - if (!isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist)) { + if (!isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist)) return false; - } } } @@ -180,22 +174,17 @@ } // Returns true if the affine.if op can be hoisted. -bool checkInvarianceOfNestedIfOps(Operation *op, Value indVar, +bool checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs, SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist) { - assert(isa(op)); - auto ifOp = cast(op); - if (!areAllOpsInTheBlockListInvariant(ifOp.getThenRegion(), indVar, iterArgs, - opsWithUsers, opsToHoist)) { + opsWithUsers, opsToHoist)) return false; - } if (!areAllOpsInTheBlockListInvariant(ifOp.getElseRegion(), indVar, iterArgs, - opsWithUsers, opsToHoist)) { + opsWithUsers, opsToHoist)) return false; - } return true; }