diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -327,6 +327,7 @@ /// that would change the read within `memOp`. template bool hasNoInterveningEffect(Operation *start, T memOp); + } // namespace mlir #endif // MLIR_DIALECT_AFFINE_UTILS_H diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -33,26 +33,105 @@ class FuncOp; } // namespace func -/// Replace the `loop` with `newIterOperands` added as new initialization -/// values. `newYieldValuesFn` is a callback that can be used to specify -/// the additional values to be yielded by the loop. The number of -/// values returned by the callback should match the number of new -/// initialization values. This function -/// - Moves (i.e. doesnt clone) operations from the `loop` to the newly created +/// Replace the `for loop` with `newIterOperands` added as new initialization +/// values. `newYieldValuesFn` is a callback that can be used to specify the +/// additional values to be yielded by the loop. The number of values returned +/// by the callback should match the number of new initialization values. This +/// function: +/// - Moves (i.e. doesn't clone) operations from the `loop` to the newly created /// loop /// - Replaces the uses of `loop` with the new loop. -/// - `loop` isnt erased, but is left in a "no-op" state where the body of the +/// - `loop` isn't erased, but is left in a "no-op" state where the body of the /// loop just yields the basic block arguments that correspond to the /// initialization values of a loop. The loop is dead after this method. 
/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the /// `newIterOperands` within the generated new loop are replaced /// with the corresponding `BlockArgument` in the loop body. +/// +/// As it is, this function template works for `scf::ForOp` and `AffineForOp` +/// loops. Adapting it to other `for` operations may require changes. using NewYieldValueFn = std::function( OpBuilder &b, Location loc, ArrayRef newBBArgs)>; -scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, - ValueRange newIterOperands, - const NewYieldValueFn &newYieldValuesFn, - bool replaceIterOperandsUsesInLoop = true); + +template +ForOpTy replaceForOpWithNewYields(OpBuilder &builder, ForOpTy loop, + ValueRange newIterOperands, + const NewYieldValueFn &newYieldValuesFn, + bool replaceIterOperandsUsesInLoop = true) { + // Create a new loop before the existing one, with the extra operands. + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(loop); + auto operands = llvm::to_vector(loop.getIterOperands()); + operands.append(newIterOperands.begin(), newIterOperands.end()); + ForOpTy newLoop; + if constexpr (std::is_same_v) + newLoop = builder.create( + loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), + loop.getStep(), operands, + [](OpBuilder &, Location, Value, ValueRange) {}); + else + newLoop = builder.create( + loop.getLoc(), loop.getLowerBound().getOperands(), + loop.getLowerBound().getMap(), loop.getUpperBound().getOperands(), + loop.getUpperBound().getMap(), loop.getStep(), operands, + [](OpBuilder &, Location, Value, ValueRange) {}); + + Block *loopBody = loop.getBody(); + Block *newLoopBody = newLoop.getBody(); + + // Move the body of the original loop to the new loop. + newLoopBody->getOperations().splice(newLoopBody->end(), + loopBody->getOperations()); + + // Generate the new yield values to use by using the callback and append the + // yield values to the yield operation. 
+ auto yield = cast(newLoopBody->getTerminator()); + ArrayRef newBBArgs = + newLoopBody->getArguments().take_back(newIterOperands.size()); + { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(yield); + SmallVector newYieldedValues = + newYieldValuesFn(builder, loop.getLoc(), newBBArgs); + assert(newIterOperands.size() == newYieldedValues.size() && + "expected as many new yield values as new iter operands"); + if constexpr (std::is_same_v) + yield.getResultsMutable().append(newYieldedValues); + else + yield.operandsMutable().append(newYieldedValues); + } + + // Remap the BlockArguments from the original loop to the new loop + // BlockArguments. + ArrayRef bbArgs = loopBody->getArguments(); + for (auto it : + llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size()))) + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + + if (replaceIterOperandsUsesInLoop) { + // Replace all uses of `newIterOperands` with the corresponding basic block + // arguments. + for (auto it : llvm::zip(newIterOperands, newBBArgs)) { + std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newLoop->isProperAncestor(user); + }); + } + } + + // Replace all uses of the original loop with corresponding values from the + // new loop. + loop.replaceAllUsesWith( + newLoop.getResults().take_front(loop.getNumResults())); + + // Add a fake yield to the original loop body that just returns the + // BlockArguments corresponding to the iter_args. This makes it a no-op loop. + // The loop is dead. The caller is expected to erase it. + builder.setInsertionPointToEnd(loopBody); + builder.create(loop->getLoc(), loop.getRegionIterArgs()); + + return newLoop; +} /// Update a perfectly nested loop nest to yield new values from the innermost /// loop and propagating it up through the loop nest. 
This function diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -14,7 +14,9 @@ #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -29,6 +31,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" using llvm::dbgs; @@ -321,7 +324,7 @@ ArrayRef newBBArgs) { return SmallVector{write.transferWriteOp.getVector()}; }; - auto newForOp = replaceLoopWithNewYields( + auto newForOp = replaceForOpWithNewYields( b, forOp, read.transferReadOp.getVector(), yieldFn); // Transfer write has been hoisted, need to update the vector and tensor @@ -425,10 +428,10 @@ LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *transferRead.getOperation() << "\n"); - auto loop = dyn_cast(transferRead->getParentOp()); + auto loop = dyn_cast(transferRead->getParentOp()); LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() << "\n"); - if (!loop) + if (!isa_and_nonnull(loop)) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() @@ -513,18 +516,34 @@ ArrayRef newBBArgs) { return SmallVector{transferWrite.getVector()}; }; - auto newForOp = - replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn); // Transfer write has been hoisted, need to update the written vector by // the value yielded by the newForOp. 
- transferWrite.getVectorMutable().assign(newForOp.getResults().back()); - - changed = true; - loop.erase(); - // Need to interrupt and restart because erasing the loop messes up the - // walk. - return WalkResult::interrupt(); + return TypeSwitch(loop) + .Case([&](scf::ForOp scfForOp) { + auto newForOp = replaceForOpWithNewYields( + b, scfForOp, transferRead.getVector(), yieldFn); + transferWrite.getVectorMutable().assign( + newForOp.getResults().back()); + changed = true; + loop.erase(); + // Need to interrupt and restart because erasing the loop messes up + // the walk. + return WalkResult::interrupt(); + }) + .Case([&](AffineForOp affineForOp) { + auto newForOp = + replaceForOpWithNewYields( + b, affineForOp, transferRead.getVector(), yieldFn); + transferWrite.getVectorMutable().assign( + newForOp.getResults().back()); + changed = true; + loop.erase(); + // Need to interrupt and restart because erasing the loop messes up + // the walk. + return WalkResult::interrupt(); + }) + .Default([](Operation *) { return WalkResult::interrupt(); }); }); } } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -37,74 +37,6 @@ }; } // namespace -scf::ForOp -mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, - ValueRange newIterOperands, - const NewYieldValueFn &newYieldValuesFn, - bool replaceIterOperandsUsesInLoop) { - // Create a new loop before the existing one, with the extra operands. 
- OpBuilder::InsertionGuard g(builder); - builder.setInsertionPoint(loop); - auto operands = llvm::to_vector(loop.getIterOperands()); - operands.append(newIterOperands.begin(), newIterOperands.end()); - scf::ForOp newLoop = builder.create( - loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), - operands, [](OpBuilder &, Location, Value, ValueRange) {}); - - Block *loopBody = loop.getBody(); - Block *newLoopBody = newLoop.getBody(); - - // Move the body of the original loop to the new loop. - newLoopBody->getOperations().splice(newLoopBody->end(), - loopBody->getOperations()); - - // Generate the new yield values to use by using the callback and append the - // yield values to the scf.yield operation. - auto yield = cast(newLoopBody->getTerminator()); - ArrayRef newBBArgs = - newLoopBody->getArguments().take_back(newIterOperands.size()); - { - OpBuilder::InsertionGuard g(builder); - builder.setInsertionPoint(yield); - SmallVector newYieldedValues = - newYieldValuesFn(builder, loop.getLoc(), newBBArgs); - assert(newIterOperands.size() == newYieldedValues.size() && - "expected as many new yield values as new iter operands"); - yield.getResultsMutable().append(newYieldedValues); - } - - // Remap the BlockArguments from the original loop to the new loop - // BlockArguments. - ArrayRef bbArgs = loopBody->getArguments(); - for (auto it : - llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size()))) - std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); - - if (replaceIterOperandsUsesInLoop) { - // Replace all uses of `newIterOperands` with the corresponding basic block - // arguments. - for (auto it : llvm::zip(newIterOperands, newBBArgs)) { - std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) { - Operation *user = use.getOwner(); - return newLoop->isProperAncestor(user); - }); - } - } - - // Replace all uses of the original loop with corresponding values from the - // new loop. 
- loop.replaceAllUsesWith( - newLoop.getResults().take_front(loop.getNumResults())); - - // Add a fake yield to the original loop body that just returns the - // BlockArguments corresponding to the iter_args. This makes it a no-op loop. - // The loop is dead. The caller is expected to erase it. - builder.setInsertionPointToEnd(loopBody); - builder.create(loop->getLoc(), loop.getRegionIterArgs()); - - return newLoop; -} - SmallVector mlir::replaceLoopNestWithNewYields( OpBuilder &builder, ArrayRef loopNest, ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn, @@ -145,10 +77,10 @@ // } // ``` // - // The inner most loop is handled using the `replaceLoopWithNewYields` + // The innermost loop is handled using the `replaceForOpWithNewYields` // that works on a single loop. if (loopNest.size() == 1) { - auto innerMostLoop = replaceLoopWithNewYields( + auto innerMostLoop = replaceForOpWithNewYields( builder, loopNest.back(), newIterOperands, newYieldValueFn, replaceIterOperandsUsesInLoop); return {innerMostLoop}; @@ -168,8 +100,9 @@ [](OpResult r) -> Value { return r; })); }; scf::ForOp outerMostLoop = - replaceLoopWithNewYields(builder, loopNest.front(), newIterOperands, fn, - replaceIterOperandsUsesInLoop); + replaceForOpWithNewYields( + builder, loopNest.front(), newIterOperands, fn, + replaceIterOperandsUsesInLoop); newLoopNest.insert(newLoopNest.begin(), outerMostLoop); return newLoopNest; } diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -469,3 +469,39 @@ return %1 : tensor } +// ----- + +// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops( +// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>, +// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>, +// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) { +// CHECK: %[[C0:.*]] = arith.constant 0 : i32 +// CHECK: affine.for %[[I:.*]] = 
0 to 64 { +// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 { +// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32> +// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) { +// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> +// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> +// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32> +// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32> +// CHECK: affine.yield %[[T1]] : vector<16xi32> +// CHECK: } +// CHECK: vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32> +// CHECK: } +// CHECK: } +func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) { + %c0_i32 = arith.constant 0 : i32 + affine.for %arg3 = 0 to 64 { + affine.for %arg4 = 0 to 64 step 16 { + affine.for %arg5 = 0 to 64 { + %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32> + %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> + %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> + %3 = arith.muli %0, %1 : vector<16xi32> + %4 = arith.addi %2, %3 : vector<16xi32> + vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32> + } + } + } + return +} diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -62,7 +62,8 @@ return newYieldValues; }; OpBuilder b(forOp); - replaceLoopWithNewYields(b, forOp, newInitValues, 
fn); + replaceForOpWithNewYields(b, forOp, + newInitValues, fn); }); } }