Index: mlir/include/mlir/Dialect/Vector/VectorTransforms.h =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -267,6 +267,12 @@ FilterConstraintType filter; }; +/// Implements transfer op write to read forwarding and dead transfer write +/// optimization. This analyzes transfer ops to detect cases where we read a +/// value written by a previous transfer write, allowing us to remove the +/// transfer read. +void transferOpflowOpt(FuncOp func); + } // namespace vector //===----------------------------------------------------------------------===// Index: mlir/include/mlir/Dialect/Vector/VectorUtils.h =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorUtils.h +++ mlir/include/mlir/Dialect/Vector/VectorUtils.h @@ -25,6 +25,7 @@ class Operation; class Value; class VectorType; +class VectorTransferOpInterface; /// Return the number of elements of basis, `0` if empty. int64_t computeMaxLinearIndex(ArrayRef basis); @@ -159,6 +160,11 @@ AffineMap getTransferMinorIdentityMap(MemRefType memRefType, VectorType vectorType); +/// Return true if we can prove that the transfer operations access disjoint +/// memory.
+bool areTransferOpsDisjoint(VectorTransferOpInterface transferA, + VectorTransferOpInterface transferB); + namespace matcher { /// Matches vector.transfer_read, vector.transfer_write and ops that return a Index: mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/SCF/Utils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorUtils.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Function.h" #include "mlir/Transforms/LoopUtils.h" @@ -80,42 +81,6 @@ } } -/// Return true if we can prove that the transfer operations access disjoint -/// memory. -static bool isDisjoint(VectorTransferOpInterface transferA, - VectorTransferOpInterface transferB) { - if (transferA.memref() != transferB.memref()) - return false; - // For simplicity only look at transfer of same type. - if (transferA.getVectorType() != transferB.getVectorType()) - return false; - unsigned rankOffset = transferA.getLeadingMemRefRank(); - for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) { - auto indexA = transferA.indices()[i].getDefiningOp(); - auto indexB = transferB.indices()[i].getDefiningOp(); - // If any of the indices are dynamic we cannot prove anything. - if (!indexA || !indexB) - continue; - - if (i < rankOffset) { - // For dimension used as index if we can prove that index are different we - // know we are accessing disjoint slices. - if (indexA.getValue().cast().getInt() != - indexB.getValue().cast().getInt()) - return true; - } else { - // For this dimension, we slice a part of the memref we need to make sure - // the intervals accessed don't overlap. 
- int64_t distance = - std::abs(indexA.getValue().cast().getInt() - - indexB.getValue().cast().getInt()); - if (distance >= transferA.getVectorType().getDimSize(i - rankOffset)) - return true; - } - } - return false; -} - void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { bool changed = true; while (changed) { @@ -185,14 +150,14 @@ continue; if (auto transferWriteUse = dyn_cast(use.getOwner())) { - if (!isDisjoint( + if (!areTransferOpsDisjoint( cast(transferWrite.getOperation()), cast( transferWriteUse.getOperation()))) return WalkResult::advance(); } else if (auto transferReadUse = dyn_cast(use.getOwner())) { - if (!isDisjoint( + if (!areTransferOpsDisjoint( cast(transferWrite.getOperation()), cast( transferReadUse.getOperation()))) Index: mlir/lib/Dialect/Vector/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/Vector/CMakeLists.txt +++ mlir/lib/Dialect/Vector/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRVector VectorOps.cpp + VectorTransferOpTransforms.cpp VectorTransforms.cpp VectorUtils.cpp EDSC/Builders.cpp Index: mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp @@ -0,0 +1,227 @@ +//===- VectorTransferOpTransforms.cpp - transfer op transforms ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements functions concerned with optimizing transfer_read and +// transfer_write ops. 
+// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorTransforms.h" +#include "mlir/Dialect/Vector/VectorUtils.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/Function.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "vector-transfer-opt" + +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") + +using namespace mlir; + +/// Return the ancestor op in the region or nullptr if the region is not +/// an ancestor of the op. +static Operation *findAncestorOpInRegion(Region *region, Operation *op) { + while (op != nullptr && op->getParentRegion() != region) + op = op->getParentOp(); + return op; +} + +namespace { +class TransferOpOptimization { +public: + TransferOpOptimization(FuncOp func) + : dominators(func), postDominators(func) {} + void deadStoreOp(vector::TransferWriteOp); + void storeToLoadForwarding(vector::TransferReadOp); + void removeDeadOp() { + for (Operation *op : opToErase) + op->erase(); + opToErase.clear(); + } + +private: + bool isReachable(Operation *start, Operation *dest); + DominanceInfo dominators; + PostDominanceInfo postDominators; + std::vector opToErase; +}; + +/// Return true if there is a path from start operation to dest operation +/// otherwise return false. The operations have to be in the same region. +bool TransferOpOptimization::isReachable(Operation *start, Operation *dest) { + assert(start->getParentRegion() == dest->getParentRegion() && + "This function only works for ops i the same region"); + // Simple case where the start op dominates the destination.
+ if (dominators.dominates(start, dest)) + return true; + Block *startBlock = start->getBlock(); + Block *destBlock = dest->getBlock(); + SmallVector worklist(startBlock->succ_begin(), + startBlock->succ_end()); + SmallPtrSet visited; + while (!worklist.empty()) { + Block *bb = worklist.pop_back_val(); + if (!visited.insert(bb).second) + continue; + if (dominators.dominates(bb, destBlock)) + return true; + worklist.append(bb->succ_begin(), bb->succ_end()); + } + return false; +} + +/// For a transfer_write to fully overwrite another transfer_write it must: +/// 1. Access the same memref with the same indices and vector type. +/// 2. Post-dominate the other transfer_write operation. +/// If several candidates are available, one must be post-dominated by all the +/// others since they are all post-dominating the same transfer_write. We only +/// consider the transfer_write post-dominated by all the other candidates as +/// this will be the first transfer_write executed after the potentially dead +/// transfer_write. +/// If we found such a candidate we know that the original transfer_write is +/// dead if all reads that can be reached from the potentially dead +/// transfer_write are dominated by the overwriting transfer_write. +void TransferOpOptimization::deadStoreOp(vector::TransferWriteOp write) { + LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation() + << "\n"); + llvm::SmallVector reads; + Operation *firstOverwriteCandidate = nullptr; + for (auto *user : write.memref().getUsers()) { + if (user == write.getOperation()) + continue; + if (auto nextWrite = dyn_cast(user)) { + // Check candidate that can overwrite the store.
+ if (write.indices() == nextWrite.indices() && + write.getVectorType() == nextWrite.getVectorType() && + write.permutation_map() == nextWrite.permutation_map() && + postDominators.postDominates(nextWrite, write)) { + if (firstOverwriteCandidate == nullptr || + postDominators.postDominates(firstOverwriteCandidate, nextWrite)) + firstOverwriteCandidate = nextWrite; + else + assert( + postDominators.postDominates(nextWrite, firstOverwriteCandidate)); + } + continue; + } + if (auto read = dyn_cast(user)) { + // Don't need to consider disjoint reads. + if (areTransferOpsDisjoint( + cast(write.getOperation()), + cast(read.getOperation()))) + continue; + } + reads.push_back(user); + } + if (firstOverwriteCandidate == nullptr) + return; + Region *topRegion = firstOverwriteCandidate->getParentRegion(); + Operation *writeAncestor = findAncestorOpInRegion(topRegion, write); + assert(writeAncestor && + "write op should be recursively part of the top region"); + + for (Operation *read : reads) { + Operation *readAncestor = findAncestorOpInRegion(topRegion, read); + // TODO: if the read and write have the same ancestor we could recurse in + // the region to know if the read is reachable with more precision. + if (readAncestor == nullptr || !isReachable(writeAncestor, readAncestor)) + continue; + if (!dominators.dominates(firstOverwriteCandidate, read)) { + LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " << *read + << "\n"); + return; + } + } + LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation() + << " overwritten by: " << *firstOverwriteCandidate << "\n"); + opToErase.push_back(write.getOperation()); +} + +/// A transfer_write candidate to storeToLoad forwarding must: +/// 1. Access the same memref with the same indices and vector type as the +/// transfer_read. +/// 2. Dominate the transfer_read operation. +/// If several candidates are available, one must be dominated by all the others +/// since they are all dominating the same transfer_read.
We only consider the +/// transfer_write dominated by all the other candidates as this will be the +/// last transfer_write executed before the transfer_read. +/// If we found such a candidate we can do the forwarding if all the other +/// potentially aliasing ops that may reach the transfer_read are post-dominated +/// by the transfer_write. +void TransferOpOptimization::storeToLoadForwarding( + vector::TransferReadOp read) { + if (read.hasMaskedDim()) + return; + LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation() + << "\n"); + SmallVector blockingWrites; + vector::TransferWriteOp lastwrite = nullptr; + for (Operation *user : read.memref().getUsers()) { + if (isa(user)) + continue; + if (auto write = dyn_cast(user)) { + // If there is a write but we can prove that it is disjoint we can ignore + // it. + if (areTransferOpsDisjoint( + cast(write.getOperation()), + cast(read.getOperation()))) + continue; + if (dominators.dominates(write, read) && !write.hasMaskedDim() && + write.indices() == read.indices() && + write.getVectorType() == read.getVectorType() && + write.permutation_map() == read.permutation_map()) { + if (lastwrite == nullptr || dominators.dominates(lastwrite, write)) + lastwrite = write; + else + assert(dominators.dominates(write, lastwrite)); + continue; + } + } + blockingWrites.push_back(user); + } + + if (lastwrite == nullptr) + return; + + Region *topRegion = lastwrite.getParentRegion(); + Operation *readAncestor = findAncestorOpInRegion(topRegion, read); + assert(readAncestor && + "read op should be recursively part of the top region"); + + for (Operation *write : blockingWrites) { + Operation *writeAncestor = findAncestorOpInRegion(topRegion, write); + // TODO: if the store and read have the same ancestor we could recurse in + // the region to know if the read is reachable with more precision.
+ if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor)) + continue; + if (!postDominators.postDominates(lastwrite, write)) { + LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: " + << *write << "\n"); + return; + } + } + + LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation() + << " to: " << *read.getOperation() << "\n"); + read.replaceAllUsesWith(lastwrite.vector()); + opToErase.push_back(read.getOperation()); +} +} // namespace + +void mlir::vector::transferOpflowOpt(FuncOp func) { + TransferOpOptimization opt(func); + // Run store to load forwarding first as it can expose more dead store + // opportunities. + func.walk( + [&](vector::TransferReadOp read) { opt.storeToLoadForwarding(read); }); + opt.removeDeadOp(); + func.walk([&](vector::TransferWriteOp write) { opt.deadStoreOp(write); }); + opt.removeDeadOp(); +} Index: mlir/lib/Dialect/Vector/VectorUtils.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorUtils.cpp +++ mlir/lib/Dialect/Vector/VectorUtils.cpp @@ -312,3 +312,36 @@ return true; } +bool mlir::areTransferOpsDisjoint(VectorTransferOpInterface transferA, + VectorTransferOpInterface transferB) { + if (transferA.memref() != transferB.memref()) + return false; + // For simplicity only look at transfer of same type. + if (transferA.getVectorType() != transferB.getVectorType()) + return false; + unsigned rankOffset = transferA.getLeadingMemRefRank(); + for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) { + auto indexA = transferA.indices()[i].getDefiningOp(); + auto indexB = transferB.indices()[i].getDefiningOp(); + // If any of the indices are dynamic we cannot prove anything. + if (!indexA || !indexB) + continue; + + if (i < rankOffset) { + // For dimension used as index if we can prove that index are different we + // know we are accessing disjoint slices.
+ if (indexA.getValue().cast().getInt() != + indexB.getValue().cast().getInt()) + return true; + } else { + // For this dimension, we slice a part of the memref we need to make sure + // the intervals accessed don't overlap. + int64_t distance = + std::abs(indexA.getValue().cast().getInt() - + indexB.getValue().cast().getInt()); + if (distance >= transferA.getVectorType().getDimSize(i - rankOffset)) + return true; + } + } + return false; +} Index: mlir/test/Dialect/Vector/vector-transferop-opt.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Vector/vector-transferop-opt.mlir @@ -0,0 +1,161 @@ +// RUN: mlir-opt %s -test-vector-transferop-opt | FileCheck %s + +// CHECK-LABEL: func @forward_dead_store +// CHECK-NOT: vector.transfer_write +// CHECK-NOT: vector.transfer_read +// CHECK: scf.for +// CHECK: } +// CHECK: vector.transfer_write +// CHECK: return +func @forward_dead_store(%arg0: i1, %arg1 : memref<4x4xf32>, %v0 : vector<1x4xf32>, + %v1 : vector<1x4xf32>, %i : index) { + %c1 = constant 1 : index + %c4 = constant 4 : index + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} : memref<4x4xf32>, vector<1x4xf32> + %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0) -> (vector<1x4xf32>) { + %1 = addf %acc, %acc : vector<1x4xf32> + scf.yield %1 : vector<1x4xf32> + } + vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + return +} + +// CHECK-LABEL: func @forward_nested +// CHECK: vector.transfer_write +// CHECK: vector.transfer_write +// CHECK: scf.if +// CHECK-NOT: vector.transfer_read +// CHECK: } +// CHECK: vector.transfer_write +// CHECK: return +func @forward_nested(%arg0: i1, %arg1 : memref<4x4xf32>, %v0 : vector<1x4xf32>, + %v1 : vector<1x4xf32>, %i : index) { + 
%c0 = constant 0 : index + %c1 = constant 1 : index + %cf0 = constant 0.0 : f32 + vector.transfer_write %v1, %arg1[%i, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + %x = scf.if %arg0 -> (vector<1x4xf32>) { + %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} : memref<4x4xf32>, vector<1x4xf32> + scf.yield %0 : vector<1x4xf32> + } else { + scf.yield %v1 : vector<1x4xf32> + } + vector.transfer_write %x, %arg1[%c0, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + return +} + +// Negative test, the transfer_write in the scf.if region block the store to +// load forwarding because we don't recursively look into the region to realize +// that the transfer_write cannot reach the transfer_read. +// CHECK-LABEL: func @forward_nested_negative +// CHECK: vector.transfer_write +// CHECK: scf.if +// CHECK: vector.transfer_read +// CHECK: } else { +// CHECK: vector.transfer_write +// CHECK: } +// CHECK: vector.transfer_write +// CHECK: return +func @forward_nested_negative(%arg0: i1, %arg1 : memref<4x4xf32>, %v0 : vector<1x4xf32>, + %v1 : vector<1x4xf32>, %i : index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %cf0 = constant 0.0 : f32 + vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + %x = scf.if %arg0 -> (vector<1x4xf32>) { + %0 = vector.transfer_read %arg1[%c1, %c0], %cf0 {masked = [false, false]} : memref<4x4xf32>, vector<1x4xf32> + scf.yield %0 : vector<1x4xf32> + } else { + vector.transfer_write %v1, %arg1[%i, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + scf.yield %v1 : vector<1x4xf32> + } + vector.transfer_write %x, %arg1[%c0, %i] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + return +} + +// CHECK-LABEL: func @dead_store_region +// CHECK: vector.transfer_write +// CHECK: scf.if +// CHECK: } else { 
+// CHECK: vector.transfer_read +// CHECK: } +// CHECK: scf.if +// CHECK-NOT: vector.transfer_write +// CHECK: } +// CHECK: vector.transfer_write +// CHECK-NOT: vector.transfer_write +// CHECK: vector.transfer_read +// CHECK: return +func @dead_store_region(%arg0: i1, %arg1 : memref<4x4xf32>, %v0 : vector<1x4xf32>, + %v1 : vector<1x4xf32>, %i : index) -> (vector<1x4xf32>) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %cf0 = constant 0.0 : f32 + vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + %x = scf.if %arg0 -> (vector<1x4xf32>) { + scf.yield %v1 : vector<1x4xf32> + } else { + %0 = vector.transfer_read %arg1[%i, %c0], %cf0 {masked = [false, false]} : memref<4x4xf32>, vector<1x4xf32> + scf.yield %0 : vector<1x4xf32> + } + scf.if %arg0 { + vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + } + vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + %1 = vector.transfer_read %arg1[%i, %c0], %cf0 {masked = [false, false]} : memref<4x4xf32>, vector<1x4xf32> + return %1 : vector<1x4xf32> +} + +// CHECK-LABEL: func @dead_store_negative +// CHECK: scf.if +// CHECK: vector.transfer_write +// CHECK: vector.transfer_read +// CHECK: } else { +// CHECK: } +// CHECK: vector.transfer_write +// CHECK: return +func @dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>, %v0 : vector<1x4xf32>, + %v1 : vector<1x4xf32>, %i : index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %cf0 = constant 0.0 : f32 + %x = scf.if %arg0 -> (vector<1x4xf32>) { + vector.transfer_write %v0, %arg1[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + %0 = vector.transfer_read %arg1[%i, %c0], %cf0 {masked = [false, false]} : memref<4x4xf32>, vector<1x4xf32> + scf.yield %0 : vector<1x4xf32> + } else 
{ + scf.yield %v1 : vector<1x4xf32> + } + vector.transfer_write %x, %arg1[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + return +} + +// CHECK-LABEL: func @dead_store_nested_region +// CHECK: scf.if +// CHECK: vector.transfer_read +// CHECK: scf.if +// CHECK-NOT: vector.transfer_write +// CHECK: } +// CHECK: vector.transfer_write +// CHECK: } +// CHECK: return +func @dead_store_nested_region(%arg0: i1, %arg1: i1, %arg2 : memref<4x4xf32>, %v0 : vector<1x4xf32>, + %v1 : vector<1x4xf32>, %i : index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %cf0 = constant 0.0 : f32 + scf.if %arg0 { + %0 = vector.transfer_read %arg2[%i, %c0], %cf0 {masked = [false, false]} : memref<4x4xf32>, vector<1x4xf32> + scf.if %arg1 { + vector.transfer_write %v1, %arg2[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + } + vector.transfer_write %v0, %arg2[%c1, %c0] {masked = [false, false]} : vector<1x4xf32>, memref<4x4xf32> + } + return +} + Index: mlir/test/lib/Transforms/TestVectorTransforms.cpp =================================================================== --- mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -291,6 +291,11 @@ } }; +struct TestVectorTransferOpt + : public PassWrapper { + void runOnFunction() override { transferOpflowOpt(getFunction()); } +}; + } // end anonymous namespace namespace mlir { @@ -327,6 +332,9 @@ PassRegistration vectorToForLoop( "test-vector-to-forloop", "Test conversion patterns to break up a vector op into a for loop"); + PassRegistration transferOpOpt( + "test-vector-transferop-opt", + "Test optimization transformations for transfer ops"); } } // namespace test } // namespace mlir