Index: mlir/include/mlir/Dialect/Vector/VectorOps.td
===================================================================
--- mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1421,6 +1421,7 @@
   ];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_LoadOp : Vector_Op<"load"> {
Index: mlir/include/mlir/Dialect/Vector/VectorUtils.h
===================================================================
--- mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -28,6 +28,11 @@
 class VectorType;
 class VectorTransferOpInterface;
 
+namespace vector {
+class TransferWriteOp;
+class TransferReadOp;
+} // namespace vector
+
 /// Return the number of elements of basis, `0` if empty.
 int64_t computeMaxLinearIndex(ArrayRef<int64_t> basis);
 
@@ -177,6 +182,16 @@
 bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
                                VectorTransferOpInterface transferB);
 
+/// Return true if the transfer_write fully writes the data accessed by the
+/// transfer_read.
+bool transferEncompasses(vector::TransferWriteOp defWrite,
+                         vector::TransferReadOp read);
+
+/// Return true if the write op fully over-writes the priorWrite
+/// transfer_write op.
+bool transferEncompasses(vector::TransferWriteOp write,
+                         vector::TransferWriteOp priorWrite);
+
 namespace matcher {
 
 /// Matches vector.transfer_read, vector.transfer_write and ops that return a
Index: mlir/lib/Dialect/Vector/VectorOps.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorOps.cpp
+++ mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2512,7 +2512,25 @@
   return success();
 }
 
+static Value foldTransferWriteIntoTransferRead(TransferReadOp readOp) {
+  if (!readOp.getShapedType().isa<RankedTensorType>())
+    return {};
+  auto defWrite = readOp.source().getDefiningOp<vector::TransferWriteOp>();
+  while (defWrite) {
+    if (transferEncompasses(defWrite, readOp))
+      return defWrite.vector();
+    if (!isDisjointTransferIndices(
+            cast<VectorTransferOpInterface>(defWrite.getOperation()),
+            cast<VectorTransferOpInterface>(readOp.getOperation())))
+      break;
+    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+  }
+  return {};
+}
+
 OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
+  if (Value vec = foldTransferWriteIntoTransferRead(*this))
+    return vec;
   /// transfer_read(memrefcast) -> transfer_read
   if (succeeded(foldTransferInBoundsAttribute(*this)))
     return getResult();
@@ -2745,6 +2763,44 @@
                        SideEffects::DefaultResource::get());
 }
 
+namespace {
+class DeadTransferWrite final : public OpRewritePattern<TransferWriteOp> {
+public:
+  using OpRewritePattern<TransferWriteOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
+                                PatternRewriter &rewriter) const override {
+    if (!writeOp.getShapedType().isa<RankedTensorType>())
+      return failure();
+    vector::TransferWriteOp writeToModify = writeOp;
+    auto defWrite = writeOp.source().getDefiningOp<vector::TransferWriteOp>();
+    while (defWrite) {
+      if (transferEncompasses(writeOp, defWrite)) {
+        rewriter.updateRootInPlace(writeToModify, [&] {
+          writeToModify.sourceMutable().assign(defWrite.source());
+        });
+        return success();
+      }
+      if (!isDisjointTransferIndices(
+              cast<VectorTransferOpInterface>(defWrite.getOperation()),
+              cast<VectorTransferOpInterface>(writeOp.getOperation())))
+        break;
+      // If the previous write op doesn't have any other use we can safely
+      // look at the previous store to see if it can be removed.
+      if (!defWrite->hasOneUse())
+        break;
+      writeToModify = defWrite;
+      defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
+    }
+    return failure();
+  }
+};
+} // namespace
+
+void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                  MLIRContext *context) {
+  results.add<DeadTransferWrite>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
Index: mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
+++ mlir/lib/Dialect/Vector/VectorTransferOpTransforms.cpp
@@ -34,34 +34,13 @@
   return op;
 }
 
-/// Return true if the transfer_write fully writes the data accessed by the
-/// transfer_read.
-static bool transferEncompasses(vector::TransferWriteOp defWrite,
-                                vector::TransferReadOp read) {
-  return !defWrite.hasOutOfBoundsDim() &&
-         defWrite.indices() == read.indices() &&
-         defWrite.getVectorType() == read.getVectorType() &&
-         defWrite.permutation_map() == read.permutation_map();
-}
-
-/// Return true if the write op fully over-write the priorWrite transfer_write
-/// op.
-static bool transferEncompasses(vector::TransferWriteOp write,
-                                vector::TransferWriteOp priorWrite) {
-  return priorWrite.indices() == write.indices() &&
-         priorWrite.getVectorType() == write.getVectorType() &&
-         priorWrite.permutation_map() == write.permutation_map();
-}
-
 namespace {
 
 class TransferOptimization {
 public:
   TransferOptimization(FuncOp func) : dominators(func), postDominators(func) {}
   void deadStoreOp(vector::TransferWriteOp);
-  void deadStoreOpTensor(vector::TransferWriteOp);
   void storeToLoadForwarding(vector::TransferReadOp);
-  void storeToLoadForwardingTensor(vector::TransferReadOp);
   void removeDeadOp() {
     for (Operation *op : opToErase)
       op->erase();
@@ -231,44 +210,6 @@
   opToErase.push_back(read.getOperation());
 }
 
-/// Walk up the SSA links, if any write gets fully overwritten we can skip it.
-/// If it has no more uses it becomes dead.
-void TransferOptimization::deadStoreOpTensor(vector::TransferWriteOp write) {
-  auto defWrite = write.source().getDefiningOp<vector::TransferWriteOp>();
-  while (defWrite) {
-    if (transferEncompasses(write, defWrite)) {
-      write.sourceMutable().assign(defWrite.source());
-      if (defWrite->use_empty())
-        opToErase.push_back(defWrite.getOperation());
-      return;
-    }
-    if (!isDisjointTransferIndices(
-            cast<VectorTransferOpInterface>(defWrite.getOperation()),
-            cast<VectorTransferOpInterface>(write.getOperation())))
-      break;
-    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
-  }
-}
-
-/// Walk up the SSA links, if any write fully match the written vector we can
-/// replace the read by the vector. The read becomes dead and can be removed.
-void TransferOptimization::storeToLoadForwardingTensor(
-    vector::TransferReadOp read) {
-  auto defWrite = read.source().getDefiningOp<vector::TransferWriteOp>();
-  while (defWrite) {
-    if (transferEncompasses(defWrite, read)) {
-      read.replaceAllUsesWith(defWrite.vector());
-      opToErase.push_back(read.getOperation());
-      return;
-    }
-    if (!isDisjointTransferIndices(
-            cast<VectorTransferOpInterface>(defWrite.getOperation()),
-            cast<VectorTransferOpInterface>(read.getOperation())))
-      break;
-    defWrite = defWrite.source().getDefiningOp<vector::TransferWriteOp>();
-  }
-}
-
 } // namespace
 
 void mlir::vector::transferOpflowOpt(FuncOp func) {
@@ -278,15 +219,11 @@
   func.walk([&](vector::TransferReadOp read) {
     if (read.getShapedType().isa<MemRefType>())
       opt.storeToLoadForwarding(read);
-    else
-      opt.storeToLoadForwardingTensor(read);
   });
   opt.removeDeadOp();
   func.walk([&](vector::TransferWriteOp write) {
     if (write.getShapedType().isa<MemRefType>())
      opt.deadStoreOp(write);
-    else
-      opt.deadStoreOpTensor(write);
   });
   opt.removeDeadOp();
 }
Index: mlir/lib/Dialect/Vector/VectorUtils.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -356,3 +356,19 @@
     return false;
   return isDisjointTransferIndices(transferA, transferB);
 }
+
+bool mlir::transferEncompasses(vector::TransferWriteOp defWrite,
+                               vector::TransferReadOp read) {
+  return !defWrite.hasOutOfBoundsDim() && !defWrite.mask() && !read.mask() &&
+         defWrite.indices() == read.indices() &&
+         defWrite.getVectorType() == read.getVectorType() &&
+         defWrite.permutation_map() == read.permutation_map();
+}
+
+bool mlir::transferEncompasses(vector::TransferWriteOp write,
+                               vector::TransferWriteOp priorWrite) {
+  return priorWrite.indices() == write.indices() &&
+         priorWrite.mask() == write.mask() &&
+         priorWrite.getVectorType() == write.getVectorType() &&
+         priorWrite.permutation_map() == write.permutation_map();
+}
Index: mlir/test/Dialect/Vector/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Vector/canonicalize.mlir
+++ mlir/test/Dialect/Vector/canonicalize.mlir
@@ -799,3 +799,99 @@
 // CHECK-NEXT:   return %[[T0]], %[[T0]], %[[T0]]
   return %r0, %r1, %r2: tensor<2x3x4xf32>, tensor<2x3x4xf32>, tensor<2x3x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_tensor
+//  CHECK-SAME: (%[[ARG:.*]]: tensor<4x4xf32>, %[[V0:.*]]: vector<1x4xf32>, %[[V1:.*]]: vector<1x4xf32>)
+//       CHECK:   return %[[V0]] : vector<1x4xf32>
+func @store_to_load_tensor(%arg0 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>) -> vector<1x4xf32> {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %w1 = vector.transfer_write %v1, %w0[%c2, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w1[%c1, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_negative_tensor
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_write
+//       CHECK:   %[[V:.*]] = vector.transfer_read
+//       CHECK:   return %[[V]] : vector<1x4xf32>
+func @store_to_load_negative_tensor(%arg0 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> vector<1x4xf32> {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %w1 = vector.transfer_write %v0, %w0[%i, %i] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w1[%c1, %c0], %cf0 {in_bounds = [true, true]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_store_tensor
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = constant 1 : index
+//   CHECK-DAG:   %[[C2:.*]] = constant 2 : index
+//   CHECK-NOT:   vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]
+//       CHECK:   vector.transfer_write {{.*}}, {{.*}}[%[[C2]], %[[C0]]
+//       CHECK:   %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]
+//       CHECK:   return %[[VTW]] : tensor<4x4xf32>
+func @dead_store_tensor(%arg0 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %w2 : tensor<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dead_store_tensor_negative
+//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.*]] = constant 1 : index
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_write
+//       CHECK:   vector.transfer_read
+//       CHECK:   %[[VTW:.*]] = vector.transfer_write {{.*}}, {{.*}}[%[[C1]], %[[C0]]]
+//       CHECK:   return %[[VTW]] : tensor<4x4xf32>
+func @dead_store_tensor_negative(%arg0 : tensor<4x4xf32>,
+  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w1[%i, %i], %cf0 {in_bounds = [true, true]} :
+    tensor<4x4xf32>, vector<1x4xf32>
+  %x = addf %0, %0 : vector<1x4xf32>
+  %w2 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
+    vector<1x4xf32>, tensor<4x4xf32>
+  return %w2 : tensor<4x4xf32>
+}
Index: mlir/test/Dialect/Vector/vector-transferop-opt.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-transferop-opt.mlir
+++ mlir/test/Dialect/Vector/vector-transferop-opt.mlir
@@ -184,56 +184,3 @@
   return
 }
 
-// CHECK-LABEL: func @forward_dead_store_tensor
-//   CHECK-NOT:   vector.transfer_write
-//   CHECK-NOT:   vector.transfer_read
-//       CHECK:   scf.for
-//       CHECK:   }
-//       CHECK:   %[[VTW:.*]] = vector.transfer_write
-//       CHECK:   return %[[VTW]] : tensor<4x4xf32>
-func @forward_dead_store_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
-  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
-  %c1 = constant 1 : index
-  %c4 = constant 4 : index
-  %c0 = constant 0 : index
-  %cf0 = constant 0.0 : f32
-  %w0 = vector.transfer_write %v0, %arg1[%c1, %c0] {in_bounds = [true, true]} :
-    vector<1x4xf32>, tensor<4x4xf32>
-  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} :
-    tensor<4x4xf32>, vector<1x4xf32>
-  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
-    -> (vector<1x4xf32>) {
-    %1 = addf %acc, %acc : vector<1x4xf32>
-    scf.yield %1 : vector<1x4xf32>
-  }
-  %w1 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
-    vector<1x4xf32>, tensor<4x4xf32>
-  return %w1 : tensor<4x4xf32>
-}
-
-// CHECK-LABEL: func @forward_dead_store_negative_tensor
-//       CHECK:   vector.transfer_write
-//       CHECK:   vector.transfer_read
-//       CHECK:   scf.for
-//       CHECK:   }
-//       CHECK:   %[[VTW:.*]] = vector.transfer_write
-//       CHECK:   return %[[VTW]] : tensor<4x4xf32>
-func @forward_dead_store_negative_tensor(%arg0: i1, %arg1 : tensor<4x4xf32>,
-  %v0 : vector<1x4xf32>, %v1 : vector<1x4xf32>, %i : index) -> tensor<4x4xf32> {
-  %c1 = constant 1 : index
-  %c4 = constant 4 : index
-  %c0 = constant 0 : index
-  %cf0 = constant 0.0 : f32
-  %w0 = vector.transfer_write %v0, %arg1[%c1, %i] {in_bounds = [true, true]} :
-    vector<1x4xf32>, tensor<4x4xf32>
-  %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} :
-    tensor<4x4xf32>, vector<1x4xf32>
-  %x = scf.for %i0 = %c0 to %c4 step %c1 iter_args(%acc = %0)
-    -> (vector<1x4xf32>) {
-    %1 = addf %acc, %acc : vector<1x4xf32>
-    scf.yield %1 : vector<1x4xf32>
-  }
-  %w1 = vector.transfer_write %x, %w0[%c1, %c0] {in_bounds = [true, true]} :
-    vector<1x4xf32>, tensor<4x4xf32>
-  return %w1 : tensor<4x4xf32>
-}
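
Note (illustrative, not part of the patch): the new `transfer_read` fold and the `DeadTransferWrite` pattern compose on a tensor SSA chain. The sketch below uses the same standard-dialect syntax as the tests above; the function and value names are made up for the example.

// Input: the read is fed by a transfer_write at the same indices with the
// same vector type and permutation map, and %w0 is fully overwritten by %w1.
func @example(%t : tensor<4x4xf32>, %v0 : vector<1x4xf32>,
  %v1 : vector<1x4xf32>) -> (vector<1x4xf32>, tensor<4x4xf32>) {
  %c0 = constant 0 : index
  %cf0 = constant 0.0 : f32
  %w0 = vector.transfer_write %v0, %t[%c0, %c0] {in_bounds = [true, true]} :
    vector<1x4xf32>, tensor<4x4xf32>
  // Folds to %v0: transferEncompasses(%w0, %r) holds (in-bounds, no masks,
  // matching indices, vector type, and permutation map).
  %r = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true]} :
    tensor<4x4xf32>, vector<1x4xf32>
  // After the fold, %w1 is %w0's only user and overwrites it completely, so
  // DeadTransferWrite rebases %w1 onto %t and %w0 becomes dead.
  %w1 = vector.transfer_write %v1, %w0[%c0, %c0] {in_bounds = [true, true]} :
    vector<1x4xf32>, tensor<4x4xf32>
  return %r, %w1 : vector<1x4xf32>, tensor<4x4xf32>
}

// After -canonicalize, the body reduces to:
//   %w1 = vector.transfer_write %v1, %t[%c0, %c0] {in_bounds = [true, true]} :
//     vector<1x4xf32>, tensor<4x4xf32>
//   return %v0, %w1 : vector<1x4xf32>, tensor<4x4xf32>

Because transfer ops on tensors are side-effect free, the canonicalizer can do this purely on SSA use-def chains; the dominance-based analysis in VectorTransferOpTransforms.cpp remains only for the memref case.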