diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2593,8 +2593,73 @@
                               [&op](Twine t) { return op.emitOpError(t); });
 }
 
-LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
-                                    SmallVectorImpl<OpFoldResult> &) {
+/// Fold:
+/// ```
+///    %v = vector.transfer_read %t0[%c0...], {masked = [false...]} :
+///      tensor<static_sizes x f32>, vector<static_sizes x f32>
+///    %t1 = side_effect_free_tensor_producing_op : tensor<static_sizes x f32>
+///    %t2 = vector.transfer_write %v, %t1[%c0...] {masked = [false...]} :
+///      vector<static_sizes x f32>, tensor<static_sizes x f32>
+/// ```
+/// into:
+/// ```
+///    %t0
+/// ```
+static LogicalResult foldReadInitWrite(TransferWriteOp write,
+                                       ArrayRef<Attribute>,
+                                       SmallVectorImpl<OpFoldResult> &results) {
+  auto rankedTensorType = write.source().getType().dyn_cast<RankedTensorType>();
+  // If not operating on tensors, bail.
+  if (!rankedTensorType)
+    return failure();
+  // Approximate linalg.init_tensor: single result, only index operands.
+  // TODO: OpInterface for a tensor creation op.
+  Operation *init = write.source().getDefiningOp();
+  if (!init || init->getNumResults() != 1 ||
+      llvm::any_of(init->getOperands(),
+                   [](Value v) { return !v.getType().isa<IndexType>(); }))
+    return failure();
+  // If no read, bail.
+  auto read = write.vector().getDefiningOp<TransferReadOp>();
+  if (!read)
+    return failure();
+  // For now, only accept minor identity. Future: composition is minor identity.
+  if (!read.permutation_map().isMinorIdentity() ||
+      !write.permutation_map().isMinorIdentity())
+    return failure();
+  // Bail on mismatching ranks.
+  if (read.getTransferRank() != write.getTransferRank())
+    return failure();
+  // Bail on masked.
+  if (read.hasMaskedDim() || write.hasMaskedDim())
+    return failure();
+  // Tensor types must be the same.
+  if (read.source().getType() != rankedTensorType ||
+      init->getResult(0).getType() != rankedTensorType)
+    return failure();
+  // Vector types must be the same.
+  if (read.getVectorType() != write.getVectorType())
+    return failure();
+  // Vector and Tensor shapes must match.
+  if (read.getVectorType().getShape() != rankedTensorType.getShape())
+    return failure();
+  // If any index is nonzero.
+  auto isNotConstantZero = [](Value v) {
+    auto cstOp = v.getDefiningOp<ConstantIndexOp>();
+    return !cstOp || cstOp.getValue() != 0;
+  };
+  if (llvm::any_of(read.indices(), isNotConstantZero) ||
+      llvm::any_of(write.indices(), isNotConstantZero))
+    return failure();
+  // Success.
+  results.push_back(read.source());
+  return success();
+}
+
+LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<OpFoldResult> &results) {
+  if (succeeded(foldReadInitWrite(*this, operands, results)))
+    return success();
   if (succeeded(foldTransferMaskAttribute(*this)))
     return success();
   return foldMemRefCast(*this);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -pass-pipeline='func(canonicalize)' -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // -----
 
@@ -770,3 +770,19 @@
   return %1, %i8_1: vector<2x4xf32>, vector<2x4xi8>
 }
 
+// -----
+
+// CHECK-LABEL: func @transfer_folding
+//  CHECK-SAME: %[[T:[0-9a-zA-Z]+]]: tensor<2x3x4xf32>
+func @transfer_folding(%t0: tensor<2x3x4xf32>) -> (tensor<2x3x4xf32>) {
+  %c0 = constant 0 : index
+  %pad = constant 0.0 : f32
+  %v = vector.transfer_read %t0[%c0, %c0, %c0], %pad {masked = [false, false, false]} :
+    tensor<2x3x4xf32>, vector<2x3x4xf32>
+  %t1 = "side_effect_free_tensor_producing_op"() : () -> (tensor<2x3x4xf32>)
+  %t2 = vector.transfer_write %v, %t1[%c0, %c0, %c0] {masked = [false, false, false]} :
+    vector<2x3x4xf32>, tensor<2x3x4xf32>
+
+  // CHECK: return %[[T]]
+  return %t2: tensor<2x3x4xf32>
+}