diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -1061,6 +1061,8 @@
       return impl::getTransferMinorIdentityMap(memRefType, vectorType);
     }
   }];
+
+  let hasFolder = 1;
 }
 
 def Vector_TransferWriteOp :
@@ -1150,6 +1152,8 @@
       return impl::getTransferMinorIdentityMap(memRefType, vectorType);
     }
   }];
+
+  let hasFolder = 1;
 }
 
 def Vector_ShapeCastOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1498,6 +1498,33 @@
                               [&op](Twine t) { return op.emitOpError(t); });
 }
 
+/// Shared helper for folding patterns of the form
+/// ```
+///  someop(memrefcast) -> someop
+/// ```
+/// Every operand of `op` that is defined by a memref_cast accepted by
+/// canFoldIntoConsumerOp is rewired, in place, to the source of that cast.
+/// Returns success if at least one operand was updated, failure otherwise.
+static LogicalResult foldMemRefCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
+    if (castOp && canFoldIntoConsumerOp(castOp)) {
+      operand.set(castOp.getOperand());
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
+OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
+  // transfer_read(memrefcast) -> transfer_read. The op is updated in place, so
+  // its own result is returned to signal that folding happened.
+  if (succeeded(foldMemRefCast(*this)))
+    return getResult();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // TransferWriteOp
 //===----------------------------------------------------------------------===//
@@ -1583,6 +1610,12 @@
                               [&op](Twine t) { return op.emitOpError(t); });
 }
 
+LogicalResult TransferWriteOp::fold(ArrayRef<Attribute>,
+                                    SmallVectorImpl<OpFoldResult> &) {
+  // transfer_write(memrefcast) -> transfer_write, updated in place.
+  return foldMemRefCast(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // ShapeCastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -159,3 +159,19 @@
   // CHECK-NEXT: return [[ADD]]
   return %7 : vector<4x3x2xf32>
 }
+
+// -----
+
+// CHECK-LABEL: cast_transfers
+func @cast_transfers(%A: memref<4x8xf32>) -> (vector<4x8xf32>) {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  %0 = memref_cast %A : memref<4x8xf32> to memref<?x?xf32>
+
+  // CHECK: vector.transfer_read %{{.*}} : memref<4x8xf32>, vector<4x8xf32>
+  %1 = vector.transfer_read %0[%c0, %c0], %f0 : memref<?x?xf32>, vector<4x8xf32>
+
+  // CHECK: vector.transfer_write %{{.*}} : vector<4x8xf32>, memref<4x8xf32>
+  vector.transfer_write %1, %0[%c0, %c0] : vector<4x8xf32>, memref<?x?xf32>
+  return %1 : vector<4x8xf32>
+}