diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -14,6 +14,13 @@
 #include "llvm/ADT/SmallBitVector.h"
 
 namespace mlir {
+namespace vector {
+
+class TransferReadOp;
+class TransferWriteOp;
+
+} // namespace vector
+
 namespace linalg {
 
 struct LinalgTilingOptions;
@@ -437,6 +444,50 @@
   LinalgLoweringType loweringType;
 };
 
+//===----------------------------------------------------------------------===//
+// Op-specific patterns.
+//===----------------------------------------------------------------------===//
+/// Match and rewrite for the pattern:
+/// ```
+///    %alloc = ...
+///    [optional] %view = std.view %alloc ...
+///    %subView = subview %allocOrView ...
+///    [optional] linalg.fill(%allocOrView, %cst) ...
+///    linalg.copy(%in, %subView) ...
+///    vector.transfer_read %allocOrView[...], %cst ...
+/// ```
+/// where there is no interleaved use between linalg.copy and transfer_read,
+/// and no interleaved use between linalg.fill and linalg.copy (if linalg.fill
+/// is specified).
+/// This is a custom rewrite to forward partial reads (with optional fills) to
+/// vector.transfer_read.
+struct LinalgCopyVTRForwardingPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Match and rewrite for the pattern:
+/// ```
+///    %alloc = ...
+///    [optional] %view = std.view %alloc ...
+///    %subView = subview %allocOrView ...
+///    ...
+///    vector.transfer_write %..., %allocOrView[...]
+///    linalg.copy(%subView, %out)
+/// ```
+/// where there is no interleaved use between transfer_write and linalg.copy.
+/// This is a custom rewrite to forward partial writes to vector.transfer_write.
+struct LinalgCopyVTWForwardingPattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 //===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//
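Note on the read-forwarding pattern above: its effect can be sketched on IR in
the shape of the `testAllocFillRead` case from the new test file further down.
The before/after below is hand-written for exposition (SSA names illustrative,
not FileCheck output):

  // Before: fill + copy into a subview of a local buffer, then a padded read
  // of the whole buffer.
  %alloc = alloc() : memref<32 x f32>
  linalg.fill(%alloc, %f0) : memref<32 x f32>, f32
  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
  linalg.copy(%in, %subview) : memref<?xf32>, memref<16 x f32>
  %0 = vector.transfer_read %alloc[%c0], %f0 : memref<32 x f32>, vector<32 x f32>

  // After: fill and copy are erased and the read is forwarded to %in. The
  // fill value must equal the transfer padding, which is what makes dropping
  // the fill safe; the now-dead alloc is left for later cleanup passes.
  %alloc = alloc() : memref<32 x f32>
  %0 = vector.transfer_read %in[%c0], %f0 : memref<?xf32>, vector<32 x f32>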
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -103,12 +103,13 @@
     llvm_unreachable("Unexpected conv with padding");
   }
 
+  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
+  (void)dbgPref;
   edsc::ScopedContext scope(builder, op->getLoc());
   if (auto fillOp = dyn_cast<FillOp>(op)) {
     // Vectorize fill as a vector.broadcast.
-    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                         "]: Rewrite linalg.fill as vector.broadcast: "
-                      << *op << ":\n");
+    LLVM_DEBUG(dbgs() << dbgPref
+                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
     Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
     Value dst = std_load(memref);
     Value res = vector_broadcast(dst.getType(), fillOp.value());
@@ -117,9 +118,8 @@
   }
 
   // Vectorize other ops as vector contraction (currently only matmul).
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                       "]: Rewrite linalg op as vector.contract: "
-                    << *op << ":\n");
+  LLVM_DEBUG(dbgs() << dbgPref
+                    << "Rewrite linalg op as vector.contract: " << *op);
   auto linalgOp = cast<LinalgOp>(op);
   Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
   Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
@@ -129,3 +129,170 @@
                               linalgOp.iterator_types());
   std_store(res, memref);
 }
+
+/// Check whether there is any interleaved use of any `values` between
+/// `firstOp` and `secondOp`. Conservatively return `true` if any op or value
+/// is in a different block.
+static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
+                                    ValueRange values) {
+  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
+  (void)dbgPref;
+  if (firstOp->getBlock() != secondOp->getBlock() ||
+      !firstOp->isBeforeInBlock(secondOp)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << dbgPref << "interleavedUses precondition failed, firstOp: "
+               << *firstOp << ", second op: " << *secondOp);
+    return true;
+  }
+  for (auto v : values) {
+    for (auto &u : v.getUses()) {
+      Operation *owner = u.getOwner();
+      if (owner == firstOp || owner == secondOp)
+        continue;
+      // TODO: this is too conservative, use dominance info in the future.
+      if (owner->getBlock() == firstOp->getBlock() &&
+          (owner->isBeforeInBlock(firstOp) ||
+           secondOp->isBeforeInBlock(owner)))
+        continue;
+      LLVM_DEBUG(llvm::dbgs()
+                 << dbgPref << " found interleaved op " << *owner
+                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
+      return true;
+    }
+  }
+  return false;
+}
+
+/// In the future this will evolve to use interfaces, side-effect modeling and
+/// aliasing.
+LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
+    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
+
+  // Transfer into `viewOrAlloc`.
+  Value viewOrAlloc = xferOp.memref();
+  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
+      !viewOrAlloc.getDefiningOp<AllocOp>())
+    return failure();
+
+  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
+  (void)dbgPref;
+  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);
+
+  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
+  SubViewOp subViewOp;
+  for (auto &u : viewOrAlloc.getUses()) {
+    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
+      if (subViewOp)
+        return failure();
+      subViewOp = newSubViewOp;
+    }
+  }
+  if (!subViewOp)
+    return failure();
+  Value subView = subViewOp.getResult();
+  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);
+
+  // Find the copy into `subView` without interleaved uses.
+  CopyOp copyOp;
+  for (auto &u : subView.getUses()) {
+    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
+      if (newCopyOp.getOutputBuffer(0) != subView)
+        continue;
+      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
+      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
+        continue;
+      copyOp = newCopyOp;
+      break;
+    }
+  }
+  if (!copyOp)
+    return failure();
+  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);
+
+  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
+  FillOp maybeFillOp;
+  for (auto &u : viewOrAlloc.getUses()) {
+    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
+      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
+        continue;
+      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
+      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
+        continue;
+      maybeFillOp = newFillOp;
+      break;
+    }
+  }
+  // Ensure padding matches.
+  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
+    return failure();
+  if (maybeFillOp)
+    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);
+
+  // `in` is the buffer that linalg.copy reads from; forward the
+  // vector.transfer_read to it.
+  Value in = copyOp.getInput(0);
+
+  Value res = rewriter.create<vector::TransferReadOp>(
+      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
+      xferOp.permutation_map(), xferOp.padding(),
+      xferOp.masked() ? *xferOp.masked() : ArrayAttr());
+
+  if (maybeFillOp)
+    rewriter.eraseOp(maybeFillOp);
+  rewriter.eraseOp(copyOp);
+  rewriter.replaceOp(xferOp, res);
+
+  return success();
+}
+
+/// In the future this will evolve to use interfaces, side-effect modeling and
+/// aliasing.
+LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
+    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
+  // Transfer into `viewOrAlloc`.
+  Value viewOrAlloc = xferOp.memref();
+  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
+      !viewOrAlloc.getDefiningOp<AllocOp>())
+    return failure();
+
+  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
+  SubViewOp subViewOp;
+  for (auto &u : viewOrAlloc.getUses()) {
+    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
+      if (subViewOp)
+        return failure();
+      subViewOp = newSubViewOp;
+    }
+  }
+  if (!subViewOp)
+    return failure();
+  Value subView = subViewOp.getResult();
+
+  // Find the copy from `subView` without interleaved uses.
+  CopyOp copyOp;
+  for (auto &u : subView.getUses()) {
+    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
+      if (newCopyOp.getInput(0) != subView)
+        continue;
+      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
+        continue;
+      copyOp = newCopyOp;
+      break;
+    }
+  }
+  if (!copyOp)
+    return failure();
+
+  // `out` is the buffer that linalg.copy writes into; forward the
+  // vector.transfer_write to it.
+  Value out = copyOp.getOutputBuffer(0);
+
+  // Forward vector.transfer into copy.
+  rewriter.create<vector::TransferWriteOp>(
+      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
+      xferOp.permutation_map(),
+      xferOp.masked() ? *xferOp.masked() : ArrayAttr());
+
+  rewriter.eraseOp(copyOp);
+  rewriter.eraseOp(xferOp);
+
+  return success();
+}
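Note on the write-forwarding pattern just implemented: the symmetric sketch,
on IR in the shape of the `testAllocWrite` case below (again hand-written for
exposition, not FileCheck output):

  // Before: a full write into the local buffer, then a copy of the subview
  // out to the destination.
  vector.transfer_write %vec, %alloc[%c0] : vector<32 x f32>, memref<32 x f32>
  linalg.copy(%subview, %out) : memref<16 x f32>, memref<?xf32>

  // After: the write is forwarded directly to %out and the copy is erased.
  vector.transfer_write %vec, %out[%c0] : vector<32 x f32>, memref<?xf32>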
diff --git a/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir b/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir
@@ -0,0 +1,153 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -test-linalg-transform-patterns=test-vector-transfer-forwarding-patterns | FileCheck %s
+
+// CHECK-LABEL: testAllocRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testAllocRead(%in: memref<?xf32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<?xf32>, memref<16 x f32>
+  %0 = vector.transfer_read %alloc[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<32 x f32>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testAllocFillRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testAllocFillRead(%in: memref<?xf32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  linalg.fill(%alloc, %f0): memref<32 x f32>, f32
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<?xf32>, memref<16 x f32>
+  %0 = vector.transfer_read %alloc[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<32 x f32>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testViewRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testViewRead(%in: memref<?xf32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<128 x i8>
+  %view = view %alloc[%c0][] : memref<128 x i8> to memref<32 x f32>
+  %subview = subview %view[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<?xf32>, memref<16 x f32>
+  %0 = vector.transfer_read %view[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<128 x i8>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testViewFillRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testViewFillRead(%in: memref<?xf32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<128 x i8>
+  %view = view %alloc[%c0][] : memref<128 x i8> to memref<32 x f32>
+  %subview = subview %view[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.fill(%view, %f0): memref<32 x f32>, f32
+  linalg.copy(%in, %subview): memref<?xf32>, memref<16 x f32>
+  %0 = vector.transfer_read %view[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<128 x i8>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testAllocWrite
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: vector
+//  CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_write %[[ARG0]], %[[ARG1]]
+func @testAllocWrite(%vec: vector<32 x f32>, %out: memref<?xf32>) {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  vector.transfer_write %vec, %alloc[%c0] : vector<32 x f32>, memref<32 x f32>
+  linalg.copy(%subview, %out): memref<16 x f32>, memref<?xf32>
+  dealloc %alloc : memref<32 x f32>
+  return
+}
+
+// CHECK-LABEL: testViewWrite
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: vector
+//  CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_write %[[ARG0]], %[[ARG1]]
+func @testViewWrite(%vec: vector<32 x f32>, %out: memref<?xf32>) {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<128 x i8>
+  %view = view %alloc[%c0][] : memref<128 x i8> to memref<32 x f32>
+  %subview = subview %view[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  vector.transfer_write %vec, %view[%c0] : vector<32 x f32>, memref<32 x f32>
+  linalg.copy(%subview, %out): memref<16 x f32>, memref<?xf32>
+  dealloc %alloc : memref<128 x i8>
+  return
+}
+
+///===--------------------------------------------------------------------===///
+// Negative tests
+///===--------------------------------------------------------------------===///
+
+// This should fail the rewrite due to the mismatch between the fill value and
+// the transfer_read padding value.
+// CHECK-LABEL: failAllocFillRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: vector.transfer_read %[[ARG0]]
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: linalg.copy
+//       CHECK: vector.transfer_read %[[ALLOC]]
+func @failAllocFillRead(%in: memref<?xf32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %f1 = constant 1.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  linalg.fill(%alloc, %f0): memref<32 x f32>, f32
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<?xf32>, memref<16 x f32>
+  "some_interleaved_use"(%subview) : (memref<16 x f32>) -> ()
+  %0 = vector.transfer_read %alloc[%c0], %f1: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<32 x f32>
+  return %0: vector<32 x f32>
+}
+
+// This should fail the rewrite due to an interleaved use between the
+// transfer_write and the copy.
+// CHECK-LABEL: failAllocWrite
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: vector
+//  CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: vector.transfer_write %[[ARG0]], %[[ARG1]]
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_write %[[ARG0]], %[[ALLOC]]
+//       CHECK: linalg.copy
+func @failAllocWrite(%vec: vector<32 x f32>, %out: memref<?xf32>) {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  vector.transfer_write %vec, %alloc[%c0] : vector<32 x f32>, memref<32 x f32>
+  "some_interleaved_use"(%subview) : (memref<16 x f32>) -> ()
+  linalg.copy(%subview, %out): memref<16 x f32>, memref<?xf32>
+  dealloc %alloc : memref<32 x f32>
+  return
+}
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 
@@ -48,6 +49,11 @@
   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
                                     llvm::cl::desc("Test promotion options"),
                                     llvm::cl::init(false)};
+  Option<bool> testVectorTransferForwardingPatterns{
+      *this, "test-vector-transfer-forwarding-patterns",
+      llvm::cl::desc(
+          "Test a fused pass that forwards linalg.copy to vector.transfer"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -167,19 +173,6 @@
   });
 }
 
-static OwningRewritePatternList
-getMatmulToVectorCanonicalizationPatterns(MLIRContext *context) {
-  OwningRewritePatternList patterns;
-  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
-  AffineMinOp::getCanonicalizationPatterns(patterns, context);
-  AffineMaxOp::getCanonicalizationPatterns(patterns, context);
-  AllocOp::getCanonicalizationPatterns(patterns, context);
-  SubViewOp::getCanonicalizationPatterns(patterns, context);
-  ViewOp::getCanonicalizationPatterns(patterns, context);
-  MatmulOp::getCanonicalizationPatterns(patterns, context);
-  return patterns;
-}
-
 static void fillL1TilingAndMatmulToVectorPatterns(
     FuncOp funcOp, StringRef startMarker,
     SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
@@ -261,40 +254,58 @@
                        LinalgMarker({"PROMOTE"}));
 }
 
+static void
+applyMatmulToVectorPatterns(FuncOp funcOp,
+                            bool testMatmulToVectorPatterns1dTiling,
+                            bool testMatmulToVectorPatterns2dTiling) {
+  MLIRContext *ctx = funcOp.getContext();
+  SmallVector<OwningRewritePatternList, 4> stage1Patterns;
+  if (testMatmulToVectorPatterns1dTiling) {
+    fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
+  } else if (testMatmulToVectorPatterns2dTiling) {
+    stage1Patterns.emplace_back(
+        LinalgTilingPattern<MatmulOp>(ctx,
+                                      LinalgTilingOptions()
+                                          .setTileSizes({768, 264, 768})
+                                          .setInterchange({1, 2, 0}),
+                                      LinalgMarker({"START"}, "L2")));
+    fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
+  }
+  OwningRewritePatternList stage2Patterns =
+      getLinalgTilingCanonicalizationPatterns(ctx);
+  applyStagedPatterns(funcOp, stage1Patterns, stage2Patterns);
+}
+
+static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
+  OwningRewritePatternList forwardPattern;
+  forwardPattern.insert<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
+  forwardPattern.insert<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
+  applyPatternsAndFoldGreedily(funcOp, forwardPattern);
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
-  if (testPatterns) {
-    applyPatterns(getFunction());
-    return;
-  }
+  auto lambda = [&](void *) {
+    getFunction().walk([](LinalgOp op) {
+      op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
+    });
+  };
+  // Drop the marker on all remaining linalg ops whenever this function exits.
+  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
+
   if (testPromotionOptions) {
     OwningRewritePatternList patterns;
     fillPromotionCallBackPatterns(&getContext(), patterns);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
-  } else {
-    SmallVector<OwningRewritePatternList, 4> stage1Patterns;
-    if (testMatmulToVectorPatterns1dTiling) {
-      fillL1TilingAndMatmulToVectorPatterns(getFunction(), "START",
-                                            stage1Patterns);
-    } else if (testMatmulToVectorPatterns2dTiling) {
-      stage1Patterns.emplace_back(
-          LinalgTilingPattern<MatmulOp>(&getContext(),
-                                        LinalgTilingOptions()
-                                            .setTileSizes({768, 264, 768})
-                                            .setInterchange({1, 2, 0}),
-                                        LinalgMarker({"START"}, "L2")));
-      fillL1TilingAndMatmulToVectorPatterns(getFunction(), "L2",
-                                            stage1Patterns);
-    }
-    OwningRewritePatternList stage2Patterns =
-        getMatmulToVectorCanonicalizationPatterns(&getContext());
-    applyStagedPatterns(getFunction(), stage1Patterns, stage2Patterns);
+    return;
   }
-
-  // Drop the marker.
-  getFunction().walk([](LinalgOp op) {
-    op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
-  });
+  if (testPatterns)
+    return applyPatterns(getFunction());
+  if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
+    return applyMatmulToVectorPatterns(getFunction(),
+                                       testMatmulToVectorPatterns1dTiling,
+                                       testMatmulToVectorPatterns2dTiling);
+  if (testVectorTransferForwardingPatterns)
+    return applyVectorTransferForwardingPatterns(getFunction());
 }
 
 namespace mlir {
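Note: end-to-end coverage comes from the test-pass flag wired up above,
invoked as in the new test's RUN line:

  mlir-opt forward-vector-transfers.mlir -allow-unregistered-dialect \
    -test-linalg-transform-patterns=test-vector-transfer-forwarding-patterns

Out-of-tree users would instead populate an OwningRewritePatternList with
LinalgCopyVTRForwardingPattern and LinalgCopyVTWForwardingPattern and apply it
greedily, mirroring applyVectorTransferForwardingPatterns above.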