diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -50,25 +50,25 @@ static bool checkInvarianceOfNestedIfOps(Operation *op, Value indVar, - SmallPtrSetImpl &definedOps, + SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist); static bool isOpLoopInvariant(Operation &op, Value indVar, - SmallPtrSetImpl &definedOps, + SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist); static bool areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar, - SmallPtrSetImpl &definedOps, + SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist); // Returns true if the individual op is loop invariant. bool isOpLoopInvariant(Operation &op, Value indVar, - SmallPtrSetImpl &definedOps, + SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist) { LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;); if (isa(op)) { - if (!checkInvarianceOfNestedIfOps(&op, indVar, definedOps, opsToHoist)) { + if (!checkInvarianceOfNestedIfOps(&op, indVar, opsWithUsers, opsToHoist)) { return false; } } else if (isa(op)) { @@ -82,7 +82,7 @@ // Register op in the set of ops defined inside the loop. This set is used // to prevent hoisting ops that depend on other ops defined inside the loop // which are themselves not being hoisted. - definedOps.insert(&op); + opsWithUsers.insert(&op); if (isa(op)) { Value memref = isa(op) @@ -135,7 +135,8 @@ // If the value was defined in the loop (outside of the // if/else region), and that operation itself wasn't meant to // be hoisted, then mark this operation loop dependent. 
- if (definedOps.count(operandSrc) && opsToHoist.count(operandSrc) == 0) { + if (opsWithUsers.count(operandSrc) && + opsToHoist.count(operandSrc) == 0) { return false; } } @@ -149,12 +150,12 @@ // Checks if all ops in a region (i.e. list of blocks) are loop invariant. bool areAllOpsInTheBlockListInvariant( - Region &blockList, Value indVar, SmallPtrSetImpl &definedOps, + Region &blockList, Value indVar, SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist) { for (auto &b : blockList) { for (auto &op : b) { - if (!isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) { + if (!isOpLoopInvariant(op, indVar, opsWithUsers, opsToHoist)) { return false; } } @@ -165,17 +166,17 @@ // Returns true if the affine.if op can be hoisted. bool checkInvarianceOfNestedIfOps(Operation *op, Value indVar, - SmallPtrSetImpl &definedOps, + SmallPtrSetImpl &opsWithUsers, SmallPtrSetImpl &opsToHoist) { assert(isa(op)); auto ifOp = cast(op); - if (!areAllOpsInTheBlockListInvariant(ifOp.thenRegion(), indVar, definedOps, + if (!areAllOpsInTheBlockListInvariant(ifOp.thenRegion(), indVar, opsWithUsers, opsToHoist)) { return false; } - if (!areAllOpsInTheBlockListInvariant(ifOp.elseRegion(), indVar, definedOps, + if (!areAllOpsInTheBlockListInvariant(ifOp.elseRegion(), indVar, opsWithUsers, opsToHoist)) { return false; } @@ -187,7 +188,7 @@ auto *loopBody = forOp.getBody(); auto indVar = forOp.getInductionVar(); - SmallPtrSet definedOps; + SmallPtrSet opsWithUsers; // This is the place where hoisted instructions would reside. OpBuilder b(forOp.getOperation()); @@ -195,10 +196,14 @@ SmallVector opsToMove; for (auto &op : *loopBody) { + // Prevent hoisting of the users of affine loop results. + if (isa(op) && (op.getNumResults() > 0)) { + opsWithUsers.insert(&op); + } // We don't hoist for loops. 
if (!isa(op)) { if (!isa(op)) { - if (isOpLoopInvariant(op, indVar, definedOps, opsToHoist)) { + if (isOpLoopInvariant(op, indVar, opsWithUsers, opsToHoist)) { opsToMove.push_back(&op); } } diff --git a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir --- a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir +++ b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir @@ -613,3 +613,32 @@ // CHECK-NEXT: addf // CHECK-NEXT: affine.vector_store // CHECK-NEXT: affine.for + +// ----- + +// Users of an inner affine.for result (the vector.reduction of its iter_args accumulator) must not be hoisted out of the enclosing loop. +// CHECK-LABEL: func @reduction_loop_no_invariant( +func @reduction_loop_no_invariant(%arg0 : memref<30x512xf32, 1>, + %arg1 : memref<30xf32, 1>) { + %accum = memref.alloca() : memref<64xf32> + %zero = constant dense<0.0> : vector<64xf32> + affine.for %dim1 = 0 to 30 { + %vecAccum = affine.for %dim0 = 0 to 512 step 64 iter_args (%prevAccum = %zero) -> vector<64xf32> { + %arg0Vector = affine.vector_load %arg0[%dim1, %dim0] : memref<30x512xf32, 1>, vector<64xf32> + %newAccum = addf %prevAccum, %arg0Vector : vector<64xf32> + affine.yield %newAccum : vector<64xf32> + } + %scalarAccum = vector.reduction "add", %vecAccum : vector<64xf32> into f32 + affine.store %scalarAccum, %arg1[%dim1] : memref<30xf32, 1> + } + return +} + +// CHECK: affine.for +// CHECK: %[[sum:.*]] = affine.for +// CHECK: affine.vector_load +// CHECK: addf +// CHECK: affine.yield +// CHECK: } +// CHECK: vector.reduction "add", %[[sum]] +// CHECK: affine.store