diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h @@ -16,11 +16,28 @@ /// Hoist alloc/dealloc pairs and alloca op out of immediately enclosing /// scf::ForOp if both conditions are true: -/// 1. all operands are defined outside the loop. -/// 2. all uses are ViewLikeOp or DeallocOp. +/// 1. All operands are defined outside the loop. +/// 2. All uses are ViewLikeOp or DeallocOp. // TODO: generalize on a per-need basis. void hoistViewAllocOps(FuncOp func);
+/// Hoist vector.transfer_read / vector.transfer_write pairs out of immediately
+/// enclosing scf::ForOp iteratively, if the following conditions are true:
+///   1. The two ops access the same memref with the same indices and both
+///   live directly in the body block of the enclosing scf::ForOp.
+///   2. All operands of the transfer_read are defined outside the loop.
+///   3. The two ops are the only uses of the memref nested in the loop.
+///   4. No use of the memref either dominates the transfer_read or is
+///   dominated by the transfer_write (i.e. no aliasing between the write and
+///   the read across the loop).
+/// Only direct uses of the memref value are inspected: aliasing through other
+/// values that refer to the same buffer (e.g. views) is not detected.
+/// To improve hoisting opportunities, `moveLoopInvariantCode` is first applied
+/// to each candidate loop. Hoisting the transfers results in scf::ForOp
+/// yielding the value that originally transited through memory.
+// TODO: generalize on a per-need basis.
+void hoistRedundantVectorTransfers(FuncOp func);
+
} // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Transforms/LoopUtils.h b/mlir/include/mlir/Transforms/LoopUtils.h --- a/mlir/include/mlir/Transforms/LoopUtils.h +++ b/mlir/include/mlir/Transforms/LoopUtils.h @@ -22,10 +22,11 @@ namespace mlir { class AffineForOp; class FuncOp; +class LoopLikeOpInterface; +struct MemRefRegion; class OpBuilder; class Value; class ValueRange; -struct MemRefRegion; namespace scf { class ForOp; @@ -294,6 +295,9 @@ separateFullTiles(MutableArrayRef nest, SmallVectorImpl *fullTileNest = nullptr); +/// Move loop invariant code out of `looplike`. +LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike); + } // end namespace mlir #endif // MLIR_TRANSFORMS_LOOP_UTILS_H diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -41,20 +41,24 @@ } if (auto forOp = dyn_cast(op)) { - for (auto *ownerInst : forOp.getInductionVar().getUsers()) - if (forwardSlice->count(ownerInst) == 0) - getForwardSliceImpl(ownerInst, forwardSlice, filter); + for (auto *ownerOp : forOp.getInductionVar().getUsers()) + if (forwardSlice->count(ownerOp) == 0) + getForwardSliceImpl(ownerOp, forwardSlice, filter); } else if (auto forOp = dyn_cast(op)) { - for (auto *ownerInst : forOp.getInductionVar().getUsers()) - if (forwardSlice->count(ownerInst) == 0) - getForwardSliceImpl(ownerInst, forwardSlice, filter); + for (auto *ownerOp : forOp.getInductionVar().getUsers()) + if (forwardSlice->count(ownerOp) == 0) + getForwardSliceImpl(ownerOp, forwardSlice, filter); + for (auto result : forOp.getResults()) + for (auto *ownerOp : result.getUsers()) + if (forwardSlice->count(ownerOp) == 0) + getForwardSliceImpl(ownerOp, forwardSlice, filter); } else { assert(op->getNumRegions() == 0 && "unexpected generic op with regions"); assert(op->getNumResults() <= 1 && "unexpected multiple results"); if (op->getNumResults() > 0) { - for (auto *ownerInst : op->getResult(0).getUsers()) - if (forwardSlice->count(ownerInst) == 0) - getForwardSliceImpl(ownerInst, forwardSlice, filter); + for (auto *ownerOp : op->getResult(0).getUsers()) + if (forwardSlice->count(ownerOp) == 0) + getForwardSliceImpl(ownerOp, forwardSlice, filter); } } @@ -139,15 +143,15 @@ SetVector backwardSlice; SetVector forwardSlice; while (currentIndex != slice.size()) { - auto *currentInst = (slice)[currentIndex]; - // Compute and insert the backwardSlice starting from currentInst. + auto *currentOp = (slice)[currentIndex]; + // Compute and insert the backwardSlice starting from currentOp. backwardSlice.clear(); - getBackwardSlice(currentInst, &backwardSlice, backwardFilter); + getBackwardSlice(currentOp, &backwardSlice, backwardFilter); slice.insert(backwardSlice.begin(), backwardSlice.end()); - // Compute and insert the forwardSlice starting from currentInst. + // Compute and insert the forwardSlice starting from currentOp.
forwardSlice.clear(); - getForwardSlice(currentInst, &forwardSlice, forwardFilter); + getForwardSlice(currentOp, &forwardSlice, forwardFilter); slice.insert(forwardSlice.begin(), forwardSlice.end()); ++currentIndex; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -12,10 +12,15 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Transforms/Hoisting.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/Utils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/Function.h" +#include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" @@ -73,3 +78,116 @@ }); } }
+
+void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) {
+  bool changed = true;
+  while (changed) {
+    changed = false;
+
+    func.walk([&](vector::TransferReadOp transferRead) {
+      auto loop = transferRead.getParentOfType<scf::ForOp>();
+      if (!loop)
+        return WalkResult::advance();
+
+      if (failed(moveLoopInvariantCode(
+              cast<LoopLikeOpInterface>(loop.getOperation()))))
+        llvm_unreachable(
+            "Unexpected failure to move invariant code out of loop");
+
+      LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
+                        << "\n");
+
+      // Only consider reads that sit directly in the loop body. Hoisting a
+      // read, or its matching write, out of a nested region (e.g. an scf.if)
+      // would turn a conditional memory access into an unconditional one.
+      if (transferRead.getOperation()->getParentOp() != loop.getOperation())
+        return WalkResult::advance();
+
+      // All operands of the TransferRead must be defined outside of the loop.
+      // Checked before computing the forward slice, the costlier query.
+      for (auto operand : transferRead.getOperands())
+        if (!loop.isDefinedOutsideOfLoop(operand))
+          return WalkResult::advance();
+
+      llvm::SetVector<Operation *> forwardSlice;
+      getForwardSlice(transferRead, &forwardSlice);
+
+      // Look for the first TransferWriteOp in the forwardSlice of
+      // `transferRead`, in the same block and to the same memref. A write in
+      // a nested region would become unconditional once moved after the loop.
+      vector::TransferWriteOp transferWrite;
+      for (auto *sliceOp : forwardSlice) {
+        auto candidateWrite = dyn_cast<vector::TransferWriteOp>(sliceOp);
+        if (!candidateWrite ||
+            candidateWrite.memref() != transferRead.memref() ||
+            candidateWrite.getOperation()->getBlock() !=
+                transferRead.getOperation()->getBlock())
+          continue;
+        transferWrite = candidateWrite;
+        break;
+      }
+
+      // Only hoist transfer_read / transfer_write pairs for now.
+      if (!transferWrite)
+        return WalkResult::advance();
+
+      LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
+                        << "\n");
+
+      // Approximate aliasing by checking that:
+      //   1. indices are the same,
+      //   2. the pair is the only use of the memref nested in the loop: any
+      //   other use, even one between the read and the write, would observe
+      //   memory that is no longer updated once the value is carried through
+      //   the loop's iter_args,
+      //   3. no other use either dominates the transfer_read or is dominated
+      //   by the transfer_write (i.e. aliasing between the write and the read
+      //   across the loop).
+      // TODO: only direct uses of the memref value are inspected; accesses
+      // through other values aliasing the same buffer are not detected.
+      if (transferRead.indices() != transferWrite.indices())
+        return WalkResult::advance();
+
+      // TODO: may want to memoize this information for performance but it
+      // likely gets invalidated often.
+      DominanceInfo dom(loop);
+      Operation *readOp = transferRead.getOperation();
+      Operation *writeOp = transferWrite.getOperation();
+      if (!dom.dominates(readOp, writeOp))
+        return WalkResult::advance();
+      for (auto &use : transferRead.memref().getUses()) {
+        Operation *user = use.getOwner();
+        bool isOtherUseInLoop =
+            user != readOp && user != writeOp &&
+            loop.getLoopBody().isAncestor(user->getParentRegion());
+        if (isOtherUseInLoop || dom.properlyDominates(user, readOp) ||
+            dom.properlyDominates(writeOp, user))
+          return WalkResult::advance();
+      }
+
+      // Hoist read before.
+      if (failed(loop.moveOutOfLoop({transferRead})))
+        llvm_unreachable(
+            "Unexpected failure to move transfer read out of loop");
+
+      // Hoist write after.
+      transferWrite.getOperation()->moveAfter(loop);
+
+      // Rewrite the loop with new yields and mark the old one for erasure.
+      OpBuilder b(transferRead);
+      auto newForOp = cloneWithNewYields(b, loop, transferRead.vector(),
+                                         transferWrite.vector());
+
+      // Transfer write has been hoisted, need to update the written value to
+      // the value yielded by the newForOp.
+      transferWrite.vector().replaceAllUsesWith(
+          newForOp.getResults().take_back()[0]);
+
+      changed = true;
+      loop.erase();
+      // Need to interrupt and restart because erasing the loop messes up the
+      // walk.
+      return WalkResult::interrupt();
+    });
+  }
+}
diff --git a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp --- a/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp +++ b/mlir/lib/Transforms/LoopInvariantCodeMotion.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Function.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/LoopUtils.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -73,7 +74,7 @@ return true; } -static LogicalResult moveLoopInvariantCode(LoopLikeOpInterface looplike) { +LogicalResult mlir::moveLoopInvariantCode(LoopLikeOpInterface looplike) { auto &loopBody = looplike.getLoopBody(); // We use two collections here as we need to preserve the order for insertion diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -1,12 +1,13 @@ -// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs | FileCheck %s +// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-view-allocs -allow-unregistered-dialect | FileCheck %s +// RUN: mlir-opt %s -test-linalg-hoisting=test-hoist-redundant-transfers -allow-unregistered-dialect | FileCheck %s --check-prefix=VECTOR_TRANSFERS -// CHECK-LABEL: func @hoist( +// CHECK-LABEL: func @hoist_allocs( // CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index, // CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index, // CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index, // CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index, // CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1 -func @hoist(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) { +func @hoist_allocs(%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) { // CHECK-DAG: alloca(%[[VAL]]) : memref // CHECK-DAG: %[[A0:.*]] = alloc(%[[VAL]]) : memref scf.for %i = %lb to %ub step %step { @@ -83,3 +84,98 @@ // CHECK: dealloc %[[A0]] : memref return } + +// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs( +// VECTOR_TRANSFERS-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[VAL:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[LB:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[UB:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[STEP:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1 +func @hoist_vector_transfer_pairs( + %memref0: memref, %memref1: memref, %memref2: memref, + %memref3: memref, %memref4: memref, %memref5: memref, + %val: index, %lb : index, %ub : index, %step: index, %cmp: i1) { + %c0 = constant 0 : index + %cst = constant 0.0 : f32 + +// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref, vector<1xf32> +// VECTOR_TRANSFERS: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) { +// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref, vector<2xf32> +// VECTOR_TRANSFERS: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) { +// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref, vector<3xf32> +// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref, vector<4xf32> +// VECTOR_TRANSFERS: "some_crippling_use"(%[[MEMREF4]]) : (memref) -> () +// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref, vector<5xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> +// VECTOR_TRANSFERS: "some_use"(%[[MEMREF2]]) : (memref) -> vector<3xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32> +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<3xf32>, memref +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<4xf32>, memref +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<5xf32>, memref +// VECTOR_TRANSFERS: "some_crippling_use"(%[[MEMREF3]]) : (memref) -> () +// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<1xf32>, vector<2xf32> +// VECTOR_TRANSFERS: } +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<2xf32>, memref +// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<1xf32> +// VECTOR_TRANSFERS: } +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<1xf32>, memref + scf.for %i = %lb to %ub step %step { + scf.for %j = %lb to %ub step %step { + %r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref, vector<1xf32> + %r1 = vector.transfer_read %memref0[%i, %i], %cst: memref, vector<2xf32> + %r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref, vector<3xf32> + %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref, vector<4xf32> + "some_crippling_use"(%memref4) : (memref) -> () + %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref, vector<5xf32> + %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32> + %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32> + %u2 = "some_use"(%memref2) : (memref) -> vector<3xf32> + %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32> + %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32> + vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref + vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref + vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref + vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref + vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref + "some_crippling_use"(%memref3) : (memref) -> () + } + } + return +}
+
+// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disallowed(
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//  VECTOR_TRANSFERS-SAME:   %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
+//       VECTOR_TRANSFERS:   scf.for
+//       VECTOR_TRANSFERS:     vector.transfer_read %[[MEMREF0]]
+//       VECTOR_TRANSFERS:     "some_use"(%[[MEMREF0]])
+//       VECTOR_TRANSFERS:     vector.transfer_write %{{.*}}, %[[MEMREF0]]
+//       VECTOR_TRANSFERS:     vector.transfer_read %[[MEMREF1]]
+//       VECTOR_TRANSFERS:     scf.if
+//       VECTOR_TRANSFERS:       vector.transfer_write %{{.*}}, %[[MEMREF1]]
+func @hoist_vector_transfer_pairs_disallowed(
+    %memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>,
+    %lb : index, %ub : index, %step: index, %cmp: i1) {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  scf.for %i = %lb to %ub step %step {
+    %r0 = vector.transfer_read %memref0[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
+    "some_use"(%memref0) : (memref<?x?xf32>) -> ()
+    %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
+    vector.transfer_write %u0, %memref0[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
+    %r1 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<2xf32>
+    %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
+    scf.if %cmp {
+      vector.transfer_write %u1, %memref1[%c0, %c0] : vector<2xf32>, memref<?x?xf32>
+    }
+  }
+  return
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp --- a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp +++ b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp @@ -29,6 +29,10 @@ *this, "test-hoist-view-allocs", llvm::cl::desc("Test hoisting alloc used by view"), llvm::cl::init(false)}; + Option testHoistRedundantTransfers{ + *this, "test-hoist-redundant-transfers", + llvm::cl::desc("Test hoisting transfer_read/transfer_write pairs"), + llvm::cl::init(false)}; }; } // end anonymous namespace @@ -37,6 +41,10 @@ hoistViewAllocOps(getFunction()); return; } + if (testHoistRedundantTransfers) { + hoistRedundantVectorTransfers(getFunction()); + return; + } } namespace mlir {