diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h @@ -38,7 +38,7 @@ bool isCastLikeInsertSliceOp(InsertSliceOp op); /// A tensor.extract_slice is a cast-like operation if it merely rank-reduces -/// the source tensor or extracts the entire source tensor. +/// unit dimensions of the source tensor or extracts the entire source tensor. bool isCastLikeExtractSliceOp(ExtractSliceOp op); } // namespace tensor diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp --- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp @@ -98,8 +98,13 @@ int64_t resultDim = 0; // Source dims and result dims (apart from dropped dims) must have the same // size. - for (int64_t dim = 0; dim < op.getSourceType().getRank(); ++dim) { + RankedTensorType sourceType = op.getSourceType(); + for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) { if (droppedDims.test(dim)) { + // ExtractSlice may drop unit dimensions that result from taking a size-1 + // slice from a non-size-1 source dimension. 
+ if (sourceType.getDimSize(dim) != 1) + return false; continue; } FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual( diff --git a/mlir/test/Dialect/Tensor/tracking-listener.mlir b/mlir/test/Dialect/Tensor/tracking-listener.mlir --- a/mlir/test/Dialect/Tensor/tracking-listener.mlir +++ b/mlir/test/Dialect/Tensor/tracking-listener.mlir @@ -140,3 +140,14 @@ {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<3xf32> return } + +// ----- + +func.func @non_cast_like_extract_slice_drop_non_unit_dim() { + // expected-error @below {{listener could not find replacement op}} + %0 = "test.foo"() {replaced} : () -> (tensor<f32>) + %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>) + %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 1, 1, 1][1, 1, 1, 1] + {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<f32> + return +}