diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3167,6 +3167,23 @@
                          SideEffects::DefaultResource::get());
 }
 
+/// Returns true if the rank reduction done by `extractOp` drops only leading
+/// dimensions, i.e. the last `trailingRank` dimensions of the inferred
+/// (non-rank-reduced) result type are all kept. Only sizes are compared, so
+/// which of several adjacent unit dimensions was dropped cannot be told apart.
+static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
+                                        unsigned trailingRank) {
+  // Not rank-reducing at all: nothing was dropped, so trivially true.
+  if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
+    return true;
+
+  RankedTensorType inferredType = tensor::ExtractSliceOp::inferResultType(
+      extractOp.getSourceType(), extractOp.getMixedOffsets(),
+      extractOp.getMixedSizes(), extractOp.getMixedStrides());
+  return extractOp.getType().getShape().take_back(trailingRank) ==
+         inferredType.getShape().take_back(trailingRank);
+}
+
 namespace {
 /// Fold transfer_reads of a tensor.extract_slice op. E.g.:
 ///
@@ -3221,18 +3238,11 @@
     // ```
-    // For this, check the trailing `vectorRank` dims of the extract_slice
-    // result tensor match the trailing dims of the inferred result tensor.
+    // For this, check that every rank-reduced dim is a leading dim, i.e. the
+    // result shape equals the trailing dims of the inferred result type.
+    if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
+      return failure();
+
     int64_t rankReduced =
         extractOp.getSourceType().getRank() - extractOp.getType().getRank();
-    int64_t vectorRank = xferOp.getVectorType().getRank();
-    RankedTensorType inferredDestTensorType =
-        tensor::ExtractSliceOp::inferResultType(
-            extractOp.getSourceType(), extractOp.getMixedOffsets(),
-            extractOp.getMixedSizes(), extractOp.getMixedStrides());
-    auto actualDestTensorShape = extractOp.getType().getShape();
-    if (rankReduced > 0 &&
-        actualDestTensorShape.take_back(vectorRank) !=
-            inferredDestTensorType.getShape().take_back(vectorRank))
-      return failure();
 
     SmallVector<Value> newIndices;
     // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1809,3 +1809,18 @@
   %insert = vector.insert %v0, %v1[0] : vector<4x3xi32> into vector<2x4x3xi32>
   return %insert : vector<2x4x3xi32>
 }
+
+// -----
+
+// The extract_slice drops unit dim 2 (not a leading dim), so the transfer_read
+// must not be folded into a read of %src.
+// CHECK-LABEL: func.func @transfer_read_from_rank_reducing_extract_slice
+//       CHECK:   %[[SLICE:.*]] = tensor.extract_slice
+//       CHECK:   vector.transfer_read %[[SLICE]]
+func.func @transfer_read_from_rank_reducing_extract_slice(%src: tensor<1x8x8x8xf32>, %i1: index, %i2: index, %i3: index, %i4: index) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.000000e+00 : f32
+  %0 = tensor.extract_slice %src[0, %i1, %i2, %i3] [1, 4, 1, 4] [1, 1, 1, 1] : tensor<1x8x8x8xf32> to tensor<1x4x4xf32>
+  %1 = vector.transfer_read %0[%c0, %i4, %c0], %f0 {in_bounds = [true]} : tensor<1x4x4xf32>, vector<4xf32>
+  return %1 : vector<4xf32>
+}