diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -329,4 +329,29 @@
 bool hasNoInterveningEffect(Operation *start, T memOp);
+
+/// Callback used by `replaceLoopWithNewYields`. It is invoked with the builder
+/// positioned right before the loop terminator and with the region iter_args
+/// created for the new iter operands, and must return the values to yield for
+/// them (one per new iter operand).
+/// NOTE(review): an alias with this name appears to also be declared in
+/// mlir/Dialect/SCF/Utils/Utils.h; redeclaring it in namespace `mlir` is only
+/// valid while both spell exactly the same type - confirm they stay in sync
+/// (or share a single declaration).
+using NewYieldValueFn = std::function<SmallVector<Value>(
+    OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
+
+/// Replace `loop` with a new affine.for that has `newIterOperands` appended to
+/// its iter_args. The body of `loop` is moved into the new loop and
+/// `newYieldValuesFn` is called to produce the values yielded for the new
+/// iter_args. If `replaceIterOperandsUsesInLoop` is true, uses of
+/// `newIterOperands` nested inside the new loop are replaced by the matching
+/// region iter_args. Uses of the results of `loop` are replaced by the leading
+/// results of the new loop, which is returned. `loop` is left as a dead no-op
+/// loop that the caller is expected to erase.
+AffineForOp replaceLoopWithNewYields(OpBuilder &builder, AffineForOp loop,
+                                     ValueRange newIterOperands,
+                                     const NewYieldValueFn &newYieldValuesFn,
+                                     bool replaceIterOperandsUsesInLoop = true);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_UTILS_H
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -1872,3 +1872,78 @@
   results.push_back(residual);
   return results;
 }
+
+AffineForOp
+mlir::replaceLoopWithNewYields(OpBuilder &builder, AffineForOp loop,
+                               ValueRange newIterOperands,
+                               const NewYieldValueFn &newYieldValuesFn,
+                               bool replaceIterOperandsUsesInLoop) {
+  // Create a new loop before the existing one, with the extra operands. Its
+  // body is populated below by splicing in the body of `loop` (terminator
+  // included), so the body builder passed here does nothing.
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPoint(loop);
+  auto operands = llvm::to_vector(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  AffineForOp newLoop = builder.create<AffineForOp>(
+      loop.getLoc(), loop.getLowerBound().getOperands(),
+      loop.getLowerBound().getMap(), loop.getUpperBound().getOperands(),
+      loop.getUpperBound().getMap(), loop.getStep(), operands,
+      [](OpBuilder &, Location, Value, ValueRange) {});
+
+  Block *loopBody = loop.getBody();
+  Block *newLoopBody = newLoop.getBody();
+
+  // Move the body of the original loop to the new loop.
+  newLoopBody->getOperations().splice(newLoopBody->end(),
+                                      loopBody->getOperations());
+
+  // Generate the new yield values by using the callback and append them to
+  // the affine.yield operation.
+  auto yield = cast<AffineYieldOp>(newLoopBody->getTerminator());
+  ArrayRef<BlockArgument> newBBArgs =
+      newLoopBody->getArguments().take_back(newIterOperands.size());
+  {
+    OpBuilder::InsertionGuard yieldGuard(builder);
+    builder.setInsertionPoint(yield);
+    SmallVector<Value> newYieldedValues =
+        newYieldValuesFn(builder, loop.getLoc(), newBBArgs);
+    assert(newIterOperands.size() == newYieldedValues.size() &&
+           "expected as many new yield values as new iter operands");
+    yield.operandsMutable().append(newYieldedValues);
+  }
+
+  // Remap the BlockArguments from the original loop to the new loop
+  // BlockArguments. The new body starts with the same arguments as the old
+  // one (induction variable, then the original iter_args); the extra
+  // iter_args were appended at the end.
+  ArrayRef<BlockArgument> bbArgs = loopBody->getArguments();
+  for (auto it :
+       llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+
+  if (replaceIterOperandsUsesInLoop) {
+    // Replace all uses of `newIterOperands` nested inside the new loop with
+    // the corresponding basic block arguments. Uses outside the new loop,
+    // including the new loop's own iter operands, are left untouched.
+    for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
+      std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
+        Operation *user = use.getOwner();
+        return newLoop->isProperAncestor(user);
+      });
+    }
+  }
+
+  // Replace all uses of the original loop with corresponding values from the
+  // new loop.
+  loop.replaceAllUsesWith(
+      newLoop.getResults().take_front(loop.getNumResults()));
+
+  // Add a fake yield to the original loop body that just returns the
+  // BlockArguments corresponding to the iter_args. This makes it a no-op loop.
+  // The loop is dead. The caller is expected to erase it.
+  builder.setInsertionPointToEnd(loopBody);
+  builder.create<AffineYieldOp>(loop->getLoc(), loop.getRegionIterArgs());
+
+  return newLoop;
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -15,6 +15,8 @@
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -426,7 +428,11 @@
       LLVM_DEBUG(DBGS() << "Candidate for hoisting: "
                         << *transferRead.getOperation() << "\n");
-      auto loop = dyn_cast<scf::ForOp>(transferRead->getParentOp());
+      auto loop = dyn_cast<LoopLikeOpInterface>(transferRead->getParentOp());
       LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp()
                         << "\n");
-      if (!loop)
+      // Only scf.for and affine.for can be rewritten with an extra iter_arg
+      // further down. Any other LoopLikeOpInterface parent must be skipped
+      // here: the transfers would otherwise be moved out of it and the loop
+      // erased without a replacement.
+      if (!isa_and_nonnull<scf::ForOp, AffineForOp>(loop.getOperation()))
         return WalkResult::advance();
@@ -514,11 +520,23 @@
         return SmallVector<Value>{transferWrite.getVector()};
       };
-      auto newForOp =
-          replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn);
 
-      // Transfer write has been hoisted, need to update the written vector by
-      // the value yielded by the newForOp.
-      transferWrite.getVectorMutable().assign(newForOp.getResults().back());
+      // `loop` is known to be an scf.for or an affine.for (see the filter
+      // above); rewrite it with the hoisted vector as an extra iter_arg.
+      Value newYieldedValue;
+      if (auto scfForOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
+        scf::ForOp newForOp = replaceLoopWithNewYields(
+            b, scfForOp, transferRead.getVector(), yieldFn);
+        newYieldedValue = newForOp.getResults().back();
+      } else {
+        AffineForOp newForOp = replaceLoopWithNewYields(
+            b, cast<AffineForOp>(loop.getOperation()),
+            transferRead.getVector(), yieldFn);
+        newYieldedValue = newForOp.getResults().back();
+      }
+
+      // Transfer write has been hoisted, need to update the written vector by
+      // the value yielded by the new loop.
+      transferWrite.getVectorMutable().assign(newYieldedValue);
 
       changed = true;
       loop.erase();
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -470,2 +470,57 @@
 }
 
+// -----
+
+// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops(
+// CHECK-SAME:    %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>,
+// CHECK-SAME:    %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>,
+// CHECK-SAME:    %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) {
+// CHECK:         %[[C0:.*]] = arith.constant 0 : i32
+// CHECK:         affine.for %[[I:.*]] = 0 to 64 {
+// CHECK:           affine.for %[[J:.*]] = 0 to 64 step 16 {
+// CHECK:             %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32>
+// CHECK:             %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) {
+// CHECK:               %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
+// CHECK:               %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32>
+// CHECK:               %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32>
+// CHECK:               %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32>
+// CHECK:               affine.yield %[[T1]] : vector<16xi32>
+// CHECK:             }
+// CHECK:             vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32>
+// CHECK:           }
+// CHECK:         }
+func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  affine.for %arg3 = 0 to 64 {
+    affine.for %arg4 = 0 to 64 step 16 {
+      affine.for %arg5 = 0 to 64 {
+        %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32>
+        %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
+        %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32>
+        %3 = arith.muli %0, %1 : vector<16xi32>
+        %4 = arith.addi %2, %3 : vector<16xi32>
+        vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32>
+      }
+    }
+  }
+  return
+}
+
+// -----
+
+// Loop-like parents other than scf.for / affine.for must be left alone.
+// CHECK-LABEL: func @no_hoisting_in_scf_parallel(
+//       CHECK:   scf.parallel
+//       CHECK:     vector.transfer_read
+//       CHECK:     vector.transfer_write
+func.func @no_hoisting_in_scf_parallel(%memref: memref<64xi32>, %lb: index,
+                                       %ub: index, %step: index) {
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  scf.parallel (%i) = (%lb) to (%ub) step (%step) {
+    %0 = vector.transfer_read %memref[%c0], %c0_i32 : memref<64xi32>, vector<16xi32>
+    %1 = arith.addi %0, %0 : vector<16xi32>
+    vector.transfer_write %1, %memref[%c0] : vector<16xi32>, memref<64xi32>
+  }
+  return
+}