diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -14,7 +14,9 @@ #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" +#include "mlir/Dialect/Affine/Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -29,6 +31,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" using llvm::dbgs; @@ -425,10 +428,10 @@ LLVM_DEBUG(DBGS() << "Candidate for hoisting: " << *transferRead.getOperation() << "\n"); - auto loop = dyn_cast(transferRead->getParentOp()); + auto loop = dyn_cast(transferRead->getParentOp()); LLVM_DEBUG(DBGS() << "Parent op: " << *transferRead->getParentOp() << "\n"); - if (!loop) + if (!isa_and_nonnull(loop)) return WalkResult::advance(); LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation() @@ -513,18 +516,43 @@ ArrayRef newBBArgs) { return SmallVector{transferWrite.getVector()}; }; - auto newForOp = - replaceLoopWithNewYields(b, loop, transferRead.getVector(), yieldFn); // Transfer write has been hoisted, need to update the written vector by // the value yielded by the newForOp. - transferWrite.getVectorMutable().assign(newForOp.getResults().back()); - - changed = true; - loop.erase(); - // Need to interrupt and restart because erasing the loop messes up the - // walk. - return WalkResult::interrupt(); + return TypeSwitch(loop) + .Case([&](scf::ForOp scfForOp) { + auto newForOp = replaceLoopWithNewYields( + b, scfForOp, transferRead.getVector(), yieldFn); + transferWrite.getVectorMutable().assign( + newForOp.getResults().back()); + changed = true; + loop.erase(); + // Need to interrupt and restart because erasing the loop messes up + // the walk. + return WalkResult::interrupt(); + }) + .Case([&](AffineForOp affineForOp) { + auto newForOp = replaceForOpWithNewYields( + b, affineForOp, transferRead.getVector(), + SmallVector{transferWrite.getVector()}, + transferWrite.getVector()); + // Replace all uses of the `transferRead` with the corresponding + // basic block argument. + transferRead.getVector().replaceUsesWithIf( + newForOp.getLoopBody().getArguments().back(), + [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newForOp->isProperAncestor(user); + }); + transferWrite.getVectorMutable().assign( + newForOp.getResults().back()); + changed = true; + loop.erase(); + // Need to interrupt and restart because erasing the loop messes up + // the walk. + return WalkResult::interrupt(); + }) + .Default([](Operation *) { return WalkResult::interrupt(); }); }); } } diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -469,3 +469,39 @@ return %1 : tensor } +// ----- + +// CHECK-LABEL: func @hoist_vector_transfer_pairs_in_affine_loops( +// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]+]]: memref<64x64xi32>, +// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]+]]: memref<64x64xi32>, +// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]+]]: memref<64x64xi32>) { +// CHECK: %[[C0:.*]] = arith.constant 0 : i32 +// CHECK: affine.for %[[I:.*]] = 0 to 64 { +// CHECK: affine.for %[[J:.*]] = 0 to 64 step 16 { +// CHECK: %[[R0:.*]] = vector.transfer_read %[[MEMREF2]][%[[I]], %[[J]]], %[[C0]] : memref<64x64xi32>, vector<16xi32> +// CHECK: %[[R:.*]] = affine.for %[[K:.*]] = 0 to 64 iter_args(%[[ACC:.*]] = %[[R0]]) -> (vector<16xi32>) { +// CHECK: %[[AV:.*]] = vector.transfer_read %[[MEMREF0]][%[[I]], %[[K]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> +// CHECK: %[[BV:.*]] = vector.transfer_read %[[MEMREF1]][%[[K]], %[[J]]], %[[C0]] {{.*}}: memref<64x64xi32>, vector<16xi32> +// CHECK: %[[T0:.*]] = arith.muli %[[AV]], %[[BV]] : vector<16xi32> +// CHECK: %[[T1:.*]] = arith.addi %[[ACC]], %[[T0]] : vector<16xi32> +// CHECK: affine.yield %[[T1]] : vector<16xi32> +// CHECK: } +// CHECK: vector.transfer_write %[[R]], %[[MEMREF2]][%[[I]], %[[J]]] : vector<16xi32>, memref<64x64xi32> +// CHECK: } +// CHECK: } +func.func @hoist_vector_transfer_pairs_in_affine_loops(%memref0: memref<64x64xi32>, %memref1: memref<64x64xi32>, %memref2: memref<64x64xi32>) { + %c0_i32 = arith.constant 0 : i32 + affine.for %arg3 = 0 to 64 { + affine.for %arg4 = 0 to 64 step 16 { + affine.for %arg5 = 0 to 64 { + %0 = vector.transfer_read %memref0[%arg3, %arg5], %c0_i32 {permutation_map = affine_map<(d0, d1) -> (0)>} : memref<64x64xi32>, vector<16xi32> + %1 = vector.transfer_read %memref1[%arg5, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> + %2 = vector.transfer_read %memref2[%arg3, %arg4], %c0_i32 : memref<64x64xi32>, vector<16xi32> + %3 = arith.muli %0, %1 : vector<16xi32> + %4 = arith.addi %2, %3 : vector<16xi32> + vector.transfer_write %4, %memref2[%arg3, %arg4] : vector<16xi32>, memref<64x64xi32> + } + } + } + return +}