diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -61,6 +61,10 @@
 /// ```
 bool canFoldIntoConsumerOp(CastOp castOp);
 
+/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
+/// that can be folded.
+LogicalResult foldTensorCast(Operation *op);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3790,6 +3790,8 @@
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
+  if (succeeded(tensor::foldTensorCast(*this)))
+    return this->source();
   return OpFoldResult();
 }
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -73,6 +73,20 @@
   return true;
 }
 
+/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
+/// that can be folded.
+LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
+  bool folded = false;
+  for (OpOperand &operand : op->getOpOperands()) {
+    // NOTE(review): the typed getDefiningOp<tensor::CastOp>() is required here;
+    // the untyped form returns Operation* and would not compile against
+    // canFoldIntoConsumerOp(CastOp) below.
+    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
+    if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
+      operand.set(castOp.getOperand());
+      folded = true;
+    }
+  }
+  return success(folded);
+}
+
 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
   if (inputs.size() != 1 || outputs.size() != 1)
     return false;
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -237,3 +237,18 @@
   %1 = subtensor %0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x?x16x32xi8> to tensor<16x32xi8>
   return %1 : tensor<16x32xi8>
 }
+
+// -----
+
+// CHECK-LABEL: func @rank_reducing_subtensor_insert_of_cast
+//  CHECK-SAME: %[[A:.[a-z0-9A-Z_]+]]: tensor<16x32xi8>
+//  CHECK-SAME: %[[B:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+//       CHECK: %[[S:.+]] = subtensor_insert %[[A]] into %[[B]][0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<16x32xi8> into tensor<4x6x16x32xi8>
+// Tensor cast is folded away.
+//   CHECK-NOT: tensor.cast
+//       CHECK: return %[[S]] : tensor<4x6x16x32xi8>
+func @rank_reducing_subtensor_insert_of_cast(%a : tensor<16x32xi8>, %b : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %cast = tensor.cast %a : tensor<16x32xi8> to tensor<?x32xi8>
+  %res = subtensor_insert %cast into %b[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<?x32xi8> into tensor<4x6x16x32xi8>
+  return %res : tensor<4x6x16x32xi8>
+}