diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -327,6 +327,31 @@
 /// that would change the read within `memOp`.
 template <typename EffectType, typename T>
 bool hasNoInterveningEffect(Operation *start, T memOp);
+
+/// Replace the `affine.for` op `loop` with `newIterOperands` added as new
+/// initialization values. `newYieldValuesFn` is a callback that can be used to
+/// specify the additional values to be yielded by the loop. The number of
+/// values returned by the callback should match the number of new
+/// initialization values. This function
+/// - Moves (i.e. doesn't clone) operations from the `loop` to the newly created
+///   loop
+/// - Replaces the uses of `loop` with the new loop.
+/// - `loop` isn't erased, but is left in a "no-op" state where the body of the
+///   loop just yields the basic block arguments that correspond to the
+///   initialization values of a loop. The loop is dead after this method.
+/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
+///   `newIterOperands` within the generated new loop are replaced
+///   with the corresponding `BlockArgument` in the loop body.
+/// NOTE: `NewYieldValueFn` is also declared in mlir/Dialect/SCF/Utils/Utils.h;
+/// both aliases live in namespace `mlir` and must name the same type.
+using NewYieldValueFn = std::function<SmallVector<Value>(
+    OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
+AffineForOp
+replaceAffineForWithNewYields(OpBuilder &builder, AffineForOp loop,
+                              ValueRange newIterOperands,
+                              const NewYieldValueFn &newYieldValuesFn,
+                              bool replaceIterOperandsUsesInLoop = true);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_UTILS_H
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -33,15 +33,15 @@
 class FuncOp;
 } // namespace func
 
-/// Replace the `loop` with `newIterOperands` added as new initialization
-/// values. `newYieldValuesFn` is a callback that can be used to specify
-/// the additional values to be yielded by the loop. The number of
+/// Replace the `scf.for` op `loop` with `newIterOperands` added as new
+/// initialization values. `newYieldValuesFn` is a callback that can be used to
+/// specify the additional values to be yielded by the loop. The number of
 /// values returned by the callback should match the number of new
 /// initialization values. This function
-/// - Moves (i.e. doesnt clone) operations from the `loop` to the newly created
+/// - Moves (i.e. doesn't clone) operations from the `loop` to the newly created
 ///   loop
 /// - Replaces the uses of `loop` with the new loop.
-/// - `loop` isnt erased, but is left in a "no-op" state where the body of the
+/// - `loop` isn't erased, but is left in a "no-op" state where the body of the
 ///   loop just yields the basic block arguments that correspond to the
 ///   initialization values of a loop. The loop is dead after this method.
 /// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
@@ -49,10 +49,11 @@
 ///   with the corresponding `BlockArgument` in the loop body.
 using NewYieldValueFn = std::function<SmallVector<Value>(
     OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
-scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
-                                    ValueRange newIterOperands,
-                                    const NewYieldValueFn &newYieldValuesFn,
-                                    bool replaceIterOperandsUsesInLoop = true);
+scf::ForOp
+replaceSCFForWithNewYields(OpBuilder &builder, scf::ForOp loop,
+                           ValueRange newIterOperands,
+                           const NewYieldValueFn &newYieldValuesFn,
+                           bool replaceIterOperandsUsesInLoop = true);
 
 /// Update a perfectly nested loop nest to yield new values from the innermost
 /// loop and propagating it up through the loop nest. This function
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1872,3 +1872,73 @@
   results.push_back(residual);
   return results;
 }
+
+AffineForOp
+mlir::replaceAffineForWithNewYields(OpBuilder &builder, AffineForOp loop,
+                                    ValueRange newIterOperands,
+                                    const NewYieldValueFn &newYieldValuesFn,
+                                    bool replaceIterOperandsUsesInLoop) {
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(builder);
+  builder.setInsertionPoint(loop);
+  auto operands = llvm::to_vector(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  AffineForOp newLoop = builder.create<AffineForOp>(
+      loop.getLoc(), loop.getLowerBound().getOperands(),
+      loop.getLowerBound().getMap(), loop.getUpperBound().getOperands(),
+      loop.getUpperBound().getMap(), loop.getStep(), operands,
+      [](OpBuilder &, Location, Value, ValueRange) {});
+
+  Block *loopBody = loop.getBody();
+  Block *newLoopBody = newLoop.getBody();
+
+  // Move the body of the original loop to the new loop.
+  newLoopBody->getOperations().splice(newLoopBody->end(),
+                                      loopBody->getOperations());
+
+  // Generate the additional yield values with the callback and append them to
+  // the affine.yield operation.
+  auto yield = cast<AffineYieldOp>(newLoopBody->getTerminator());
+  ArrayRef<BlockArgument> newBBArgs =
+      newLoopBody->getArguments().take_back(newIterOperands.size());
+  {
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPoint(yield);
+    SmallVector<Value> newYieldedValues =
+        newYieldValuesFn(builder, loop.getLoc(), newBBArgs);
+    assert(newIterOperands.size() == newYieldedValues.size() &&
+           "expected as many new yield values as new iter operands");
+    yield.operandsMutable().append(newYieldedValues);
+  }
+
+  // Remap the BlockArguments from the original loop to the new loop
+  // BlockArguments.
+  ArrayRef<BlockArgument> bbArgs = loopBody->getArguments();
+  for (auto it :
+       llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+
+  if (replaceIterOperandsUsesInLoop) {
+    // Replace all uses of `newIterOperands` with the corresponding basic block
+    // arguments.
+    for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
+      std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
+        Operation *user = use.getOwner();
+        return newLoop->isProperAncestor(user);
+      });
+    }
+  }
+
+  // Replace all uses of the original loop with corresponding values from the
+  // new loop.
+  loop.replaceAllUsesWith(
+      newLoop.getResults().take_front(loop.getNumResults()));
+
+  // Add a fake yield to the original loop body that just returns the
+  // BlockArguments corresponding to the iter_args. This makes it a no-op loop.
+  // The loop is dead. The caller is expected to erase it.
+  builder.setInsertionPointToEnd(loopBody);
+  builder.create<AffineYieldOp>(loop->getLoc(), loop.getRegionIterArgs());
+
+  return newLoop;
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -14,7 +14,9 @@
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -29,6 +31,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 
 using llvm::dbgs;
@@ -321,7 +324,7 @@
ArrayRef newBBArgs) { return SmallVector{write.transferWriteOp.getVector()}; }; - auto newForOp = replaceLoopWithNewYields( + auto newForOp = replaceSCFForWithNewYields( b, forOp, read.transferReadOp.getVector(), yieldFn); // Transfer write has been hoisted, need to update the vector and tensor @@ -425,10 +428,10 @@ LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *transferRead.getOperation() << "\n"); - auto loop = dyn_cast(transferRead->getParentOp()); + auto loop = dyn_cast(transferRead->getParentOp()); LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() << "\n"); - if (!loop) + if (!isa_and_nonnull(loop)) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() @@ -513,18 +516,33 @@ ArrayRef newBBArgs) { return SmallVector{transferWrite.getVector()}; }; - auto newForOp = - replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn); // Transfer write has been hoisted, need to update the written vector by // the value yielded by the newForOp. - transferWrite.getVectorMutable().assign(newForOp.getResults().back()); - - changed = true; - loop.erase(); - // Need to interrupt and restart because erasing the loop messes up the - // walk. - return WalkResult::interrupt(); + return TypeSwitch(loop) + .Case([&](scf::ForOp scfForOp) { + auto newForOp = replaceSCFForWithNewYields( + b, scfForOp, transferRead.getVector(), yieldFn); + transferWrite.getVectorMutable().assign(newForOp.getResults().back()); + changed = true; + loop.erase(); + // Need to interrupt and restart because erasing the loop messes up + // the walk. + return WalkResult::interrupt(); + }) + .Case([&](AffineForOp affineForOp) { + auto newForOp = replaceAffineForWithNewYields( + b, cast(loop), transferRead.getVector(), yieldFn); + transferWrite.getVectorMutable().assign(newForOp.getResults().back()); + changed = true; + loop.erase(); + // Need to interrupt and restart because erasing the loop messes up + // the walk. 
+ return WalkResult::interrupt(); + }) + .Default([](Operation *) { + return WalkResult::interrupt(); + }); }); } } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -38,10 +38,10 @@ } // namespace scf::ForOp -mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, - ValueRange newIterOperands, - const NewYieldValueFn &newYieldValuesFn, - bool replaceIterOperandsUsesInLoop) { +mlir::replaceSCFForWithNewYields(OpBuilder &builder, scf::ForOp loop, + ValueRange newIterOperands, + const NewYieldValueFn &newYieldValuesFn, + bool replaceIterOperandsUsesInLoop) { // Create a new loop before the existing one, with the extra operands. OpBuilder::InsertionGuard g(builder); builder.setInsertionPoint(loop); @@ -145,10 +145,10 @@ // } // ``` // - // The inner most loop is handled using the `replaceLoopWithNewYields` + // The innermost loop is handled using the `replaceSCFForWithNewYields` // that works on a single loop. 
if (loopNest.size() == 1) { - auto innerMostLoop = replaceLoopWithNewYields( + auto innerMostLoop = replaceSCFForWithNewYields( builder, loopNest.back(), newIterOperands, newYieldValueFn, replaceIterOperandsUsesInLoop); return {innerMostLoop}; @@ -168,8 +168,8 @@ [](OpResult r) -> Value { return r; })); }; scf::ForOp outerMostLoop = - replaceLoopWithNewYields(builder, loopNest.front(), newIterOperands, fn, - replaceIterOperandsUsesInLoop); + replaceSCFForWithNewYields(builder, loopNest.front(), newIterOperands, + fn, replaceIterOperandsUsesInLoop); newLoopNest.insert(newLoopNest.begin(), outerMostLoop); return newLoopNest; } diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -469,3 +469,39 @@ return %1 : tensor } +// ----- + +// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops( +// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>, +// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>, +// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) { +// CHECK: %[[C0:.*]] = arith.constant 0 : i32 +// CHECK: affine.for %[[I:.*]] = 0 to 64 { +// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 { +// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32> +// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) { +// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> +// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> +// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32> +// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32> +// CHECK: affine.yield %[[T1]] : vector<16xi32> +// CHECK: } +// CHECK: vector.transfer_write %[[R]], 
%[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32> +// CHECK: } +// CHECK: } +func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) { + %c0_i32 = arith.constant 0 : i32 + affine.for %arg3 = 0 to 64 { + affine.for %arg4 = 0 to 64 step 16 { + affine.for %arg5 = 0 to 64 { + %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32> + %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> + %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> + %3 = arith.muli %0, %1 : vector<16xi32> + %4 = arith.addi %2, %3 : vector<16xi32> + vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32> + } + } + } + return +} diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -62,7 +62,7 @@ return newYieldValues; }; OpBuilder b(forOp); - replaceLoopWithNewYields(b, forOp, newInitValues, fn); + replaceSCFForWithNewYields(b, forOp, newInitValues, fn); }); } }