diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -2593,8 +2593,70 @@ [&op](Twine t) { return op.emitOpError(t); }); } -LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>, - SmallVectorImpl<OpFoldResult> &) { +/// Fold: +/// ``` +/// %t1 = ... +/// %v = vector.transfer_read %t0[%c0...], {masked = [false...]} : +/// tensor<static_sizesxf32>, vector<static_sizesxf32> +/// %t2 = vector.transfer_write %v, %t1[%c0...] {masked = [false...]} : +/// vector<static_sizesxf32>, tensor<static_sizesxf32> +/// ``` +/// +/// into: +/// +/// ``` +/// %t0 +/// ``` +/// +/// The producer of t1 may or may not be DCE'd depending on whether it is a +/// block argument or has side effects. +static LogicalResult foldReadInitWrite(TransferWriteOp write, + ArrayRef<Attribute>, + SmallVectorImpl<OpFoldResult> &results) { + auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>(); + // If not operating on tensors, bail. + if (!rankedTensorType) + return failure(); + // If no read, bail. + auto read = write.vector().getDefiningOp<vector::TransferReadOp>(); + if (!read) + return failure(); + // For now, only accept minor identity. Future: composition is minor identity. + if (!read.permutation_map().isMinorIdentity() || + !write.permutation_map().isMinorIdentity()) + return failure(); + // Bail on mismatching ranks. + if (read.getTransferRank() != write.getTransferRank()) + return failure(); + // Bail on masked. + if (read.hasMaskedDim() || write.hasMaskedDim()) + return failure(); + // Tensor types must be the same. + if (read.source().getType() != rankedTensorType) + return failure(); + // Vector types must be the same. + if (read.getVectorType() != write.getVectorType()) + return failure(); + // Vector and Tensor shapes must match. + if (read.getVectorType().getShape() != rankedTensorType.getShape()) + return failure(); + // If any index is nonzero, bail.
+ auto isNotConstantZero = [](Value v) { + auto cstOp = v.getDefiningOp<ConstantIndexOp>(); + return !cstOp || cstOp.getValue() != 0; + }; + if (llvm::any_of(read.indices(), isNotConstantZero) || + llvm::any_of(write.indices(), isNotConstantZero)) + return failure(); + // Success. + results.push_back(read.source()); + return success(); +} + +LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands, + SmallVectorImpl<OpFoldResult> &results) { + if (succeeded(foldReadInitWrite(*this, operands, results))) + return success(); if (succeeded(foldTransferMaskAttribute(*this))) return success(); return foldMemRefCast(*this); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s +// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file -allow-unregistered-dialect | FileCheck %s // ----- @@ -770,3 +770,32 @@ return %1, %i8_1: vector<2x4xf32>, vector<2x4xi8> } +// ----- + +// CHECK-LABEL: func @transfer_folding_1 +// CHECK-SAME: %[[T0:[0-9a-zA-Z]+]]: tensor<2x3x4xf32> +// CHECK-SAME: %[[T1:[0-9a-zA-Z]+]]: tensor<2x3x4xf32> +func @transfer_folding_1(%t0: tensor<2x3x4xf32>, %t1: tensor<2x3x4xf32>) + -> (tensor<2x3x4xf32>, tensor<2x3x4xf32>, tensor<2x3x4xf32>) +{ + %c0 = constant 0 : index + %pad = constant 0.0 : f32 + %v = vector.transfer_read %t0[%c0, %c0, %c0], %pad {masked = [false, false, false]} : + tensor<2x3x4xf32>, vector<2x3x4xf32> + + %r0 = vector.transfer_write %v, %t1[%c0, %c0, %c0] {masked = [false, false, false]} : + vector<2x3x4xf32>, tensor<2x3x4xf32> + + %t2 = "test.constant"() { value = dense<6.0> : tensor<2x3x4xf32>} : () -> (tensor<2x3x4xf32>) + %r1 = vector.transfer_write %v, %t2[%c0, %c0, %c0] {masked = [false, false, false]} : + vector<2x3x4xf32>, tensor<2x3x4xf32> + + + // CHECK-NEXT: some_op_that_may_have_side_effects + %t3
= "some_op_that_may_have_side_effects"() : () -> (tensor<2x3x4xf32>) + %r2 = vector.transfer_write %v, %t3[%c0, %c0, %c0] {masked = [false, false, false]} : + vector<2x3x4xf32>, tensor<2x3x4xf32> + + // CHECK-NEXT: return %[[T0]], %[[T0]], %[[T0]] + return %r0, %r1, %r2: tensor<2x3x4xf32>, tensor<2x3x4xf32>, tensor<2x3x4xf32> +}