diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/AffineExpr.h"
@@ -2400,6 +2401,21 @@
   return success(folded);
 }
 
+/// Folds `tensor.cast` producers into the operands of `op`: every operand
+/// defined by a tensor::CastOp that `tensor::canFoldIntoConsumerOp` accepts is
+/// rewired to the cast's source value. Returns success if any operand changed.
+static LogicalResult foldTensorCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
+    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
+      operand.set(castOp.getOperand());
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
 template <typename TransferOp>
 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
   // TODO: support more aggressive createOrFold on:
@@ -2452,6 +2468,8 @@
     return getResult();
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
+  if (succeeded(foldTensorCast(*this)))
+    return getResult();
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -267,6 +267,20 @@
 
 // -----
 
+// CHECK-LABEL: func @cast_transfers
+func @cast_transfers(%A: tensor<4x8xf32>) -> (vector<4x8xf32>) {
+  %c0 = constant 0 : index
+  %f0 = constant 0.0 : f32
+  %0 = tensor.cast %A : tensor<4x8xf32> to tensor<?x?xf32>
+
+  // CHECK: vector.transfer_read %{{.*}} {masked = [false, false]} : tensor<4x8xf32>, vector<4x8xf32>
+  %1 = vector.transfer_read %0[%c0, %c0], %f0 : tensor<?x?xf32>, vector<4x8xf32>
+
+  return %1 : vector<4x8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_extract_transpose_2d(
 // CHECK-SAME:      %[[V:[a-zA-Z0-9]*]]: vector<2x3xf32>,
 // CHECK-SAME:      %[[F0:[a-zA-Z0-9]*]]: f32,