diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3294,7 +3294,7 @@
       return failure();
     if (xferOp.hasOutOfBoundsDim())
       return failure();
-    if (!xferOp.getPermutationMap().isIdentity())
+    if (!xferOp.getPermutationMap().isMinorIdentity())
       return failure();
     if (xferOp.getMask())
       return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1169,6 +1169,26 @@
 
 // -----
 
+// The 1-D read of the 2-D slice is expected to fold into a read of %t, with
+// the slice offsets (5, %s1) added to the read indices (3, 4).
+// CHECK-LABEL: func @transfer_read_of_extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
+//   CHECK-DAG:   %[[c4:.*]] = arith.constant 4 : index
+//   CHECK-DAG:   %[[c8:.*]] = arith.constant 8 : index
+//       CHECK:   %[[add:.*]] = arith.addi %[[s1]], %[[c4]]
+//       CHECK:   %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor<?x?xf32>, vector<6xf32>
+//       CHECK:   return %[[r]]
+func.func @transfer_read_of_extract_slice(%t : tensor<?x?xf32>, %s1 : index, %s2 : index) -> vector<6xf32> {
+  %c3 = arith.constant 3 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.extract_slice %t[5, %s1] [10, %s2] [1, 1] : tensor<?x?xf32> to tensor<10x?xf32>
+  %1 = vector.transfer_read %0[%c3, %c4], %cst {in_bounds = [true]} : tensor<10x?xf32>, vector<6xf32>
+  return %1 : vector<6xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?xf32>, %[[s1:.*]]: index, %[[s2:.*]]: index
 //   CHECK-DAG:   %[[c3:.*]] = arith.constant 3 : index