diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -24,6 +24,7 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" @@ -2783,8 +2784,35 @@ if (!extractOp.hasUnitStride()) return failure(); + // Bail on illegal rank-reduction: we need to check that the rank-reduced + // dims are exactly the leading dims. I.e. the following is illegal: + // ``` + // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] : + // tensor<2x1x4xf32> to tensor<2x4xf32> + // %1 = vector.transfer_read %0[0,0], %cst : + // tensor<2x4xf32>, vector<2x4xf32> + // ``` + // + // Cannot fold into: + // ``` + // %0 = vector.transfer_read %t[0,0,0], %cst : + // tensor<2x1x4xf32>, vector<2x4xf32> + // ``` + // For this, check the trailing `vectorRank` dims of the extract_slice + // result tensor match the trailing dims of the inferred result tensor. int64_t rankReduced = extractOp.getSourceType().getRank() - extractOp.getType().getRank(); + int64_t vectorRank = xferOp.getVectorType().getRank(); + RankedTensorType inferredDestTensorType = + tensor::ExtractSliceOp::inferResultType( + extractOp.getSourceType(), extractOp.getMixedOffsets(), + extractOp.getMixedSizes(), extractOp.getMixedStrides()); + auto actualDestTensorShape = extractOp.getType().getShape(); + if (rankReduced > 0 && + actualDestTensorShape.take_back(vectorRank) != + inferredDestTensorType.getShape().take_back(vectorRank)) + return failure(); + SmallVector newIndices; // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced // indices first. @@ -3168,7 +3196,7 @@ if (xferOp.mask()) return failure(); // Fold only if the TransferWriteOp completely overwrites the `source` with - // a vector. 
I.e., the result of the TransferWriteOp is a new tensor who's + // a vector. I.e., the result of the TransferWriteOp is a new tensor whose // content is the data of the vector. if (!llvm::equal(xferOp.getVectorType().getShape(), xferOp.getShapedType().getShape())) @@ -3176,6 +3204,35 @@ if (!xferOp.permutation_map().isIdentity()) return failure(); + // Bail on illegal rank-reduction: we need to check that the rank-reduced + // dims are exactly the leading dims. I.e. the following is illegal: + // ``` + // %0 = vector.transfer_write %v, %t[0,0] : + // vector<2x4xf32>, tensor<2x4xf32> + // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] : + // tensor<2x4xf32> into tensor<2x1x4xf32> + // ``` + // + // Cannot fold into: + // ``` + // %0 = vector.transfer_write %v, %tt[0,0,0] : + // vector<2x4xf32>, tensor<2x1x4xf32> + // ``` + // For this, check the trailing `vectorRank` dims of the insert_slice source + // tensor match the trailing dims of the inferred source tensor. + int64_t rankReduced = + insertOp.getType().getRank() - insertOp.getSourceType().getRank(); + int64_t vectorRank = xferOp.getVectorType().getRank(); + RankedTensorType inferredSourceTensorType = + tensor::ExtractSliceOp::inferResultType( + insertOp.getType(), insertOp.getMixedOffsets(), + insertOp.getMixedSizes(), insertOp.getMixedStrides()); + auto actualSourceTensorShape = insertOp.getSourceType().getShape(); + if (rankReduced > 0 && + actualSourceTensorShape.take_back(vectorRank) != + inferredSourceTensorType.getShape().take_back(vectorRank)) + return failure(); + SmallVector indices = getValueOrCreateConstantIndexOp( rewriter, insertOp.getLoc(), insertOp.getMixedOffsets()); SmallVector inBounds(xferOp.getTransferRank(), true); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -995,6 +995,20 @@ // ----- +// CHECK-LABEL: func 
@transfer_read_of_extract_slice_illegal_rank_reducing( +// CHECK: extract_slice +// CHECK: vector.transfer_read +func @transfer_read_of_extract_slice_illegal_rank_reducing(%t : tensor, %s1 : index, %s2 : index) -> vector<5x6xf32> { + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %cst = arith.constant 0.0 : f32 + %0 = tensor.extract_slice %t[5, %s1, 6] [%s2, 1, 12] [1, 1, 1] : tensor to tensor + %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true, true]} : tensor, vector<5x6xf32> + return %1 : vector<5x6xf32> +} + +// ----- + // CHECK-LABEL: func @insert_slice_of_transfer_write( // CHECK-SAME: %[[t1:.*]]: tensor, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index // CHECK: %[[c3:.*]] = arith.constant 3 : index @@ -1009,6 +1023,18 @@ // ----- +// CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending( +// CHECK: vector.transfer_write +// CHECK: insert_slice +func @insert_slice_of_transfer_write_illegal_rank_extending(%t1 : tensor, %v : vector<5x6xf32>, %s : index, %t2 : tensor<5x6xf32>) -> tensor { + %c0 = arith.constant 0 : index + %0 = vector.transfer_write %v, %t2[%c0, %c0] {in_bounds = [true, true]} : vector<5x6xf32>, tensor<5x6xf32> + %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [5, 1, 6] [1, 1, 1] : tensor<5x6xf32> into tensor + return %1 : tensor +} + +// ----- + // CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending( // CHECK-SAME: %[[t1:.*]]: tensor, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index // CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index