diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -47,6 +47,10 @@
 /// the same shape.
 bool isCastLikeInsertSliceOp(InsertSliceOp op);
 
+/// A tensor.extract_slice is a cast-like operation if it extracts the entire
+/// source tensor, optionally rank-reducing it by dropping unit dimensions.
+bool isCastLikeExtractSliceOp(ExtractSliceOp op);
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -38,8 +38,6 @@
       return nullptr;
 
     // Skip cast-like operations.
-    // TODO: CastOpInterface could be used if CollapseShapeOp and ExpandShapeOp
-    // implement that interface
     values.clear();
     llvm::TypeSwitch<Operation *>(defOp)
         .Case<CastOp>([&](CastOp op) { values.push_back(op.getSource()); })
@@ -53,6 +51,10 @@
           if (isCastLikeInsertSliceOp(op))
             values.push_back(op.getSource());
         })
+        .Case<ExtractSliceOp>([&](ExtractSliceOp op) {
+          if (isCastLikeExtractSliceOp(op))
+            values.push_back(op.getSource());
+        })
         .Default([](Operation *op) {});
   } while (!values.empty());
 
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -123,3 +123,27 @@
 
   return true;
 }
+
+bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
+  llvm::SmallBitVector droppedDims = op.getDroppedDims();
+  RankedTensorType sourceType = op.getSourceType();
+  int64_t resultDim = 0;
+  // Source dims and result dims (apart from dropped dims) must have the same
+  // size.
+  for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) {
+    if (droppedDims.test(dim)) {
+      // A dropped unit result dim may come from a size-1 slice of a non-unit
+      // source dim; that extracts only part of the source. Dynamic source
+      // dims are rejected conservatively.
+      if (sourceType.getDimSize(dim) != 1)
+        return false;
+      continue;
+    }
+    FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
+        op.getSource(), op.getResult(), dim, resultDim);
+    if (failed(equalDimSize) || !*equalDimSize)
+      return false;
+    ++resultDim;
+  }
+
+  return true;
+}
diff --git a/mlir/test/Dialect/Tensor/tracking-listener.mlir b/mlir/test/Dialect/Tensor/tracking-listener.mlir
--- a/mlir/test/Dialect/Tensor/tracking-listener.mlir
+++ b/mlir/test/Dialect/Tensor/tracking-listener.mlir
@@ -105,3 +105,49 @@
       {replacement_0 = 0} : tensor<?xf32> into tensor<1x?x1xf32>
   return
 }
+
+// -----
+
+func.func @cast_like_extract_slice() {
+  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+  // expected-remark @below {{replacement found}}
+  %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
+  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 5, 1, 1][1, 1, 1, 1]
+      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<5xf32>
+  return
+}
+
+// -----
+
+func.func @cast_like_extract_slice_dynamic() {
+  %0 = "test.foo"() {replaced} : () -> (tensor<?xf32>)
+  // expected-remark @below {{replacement found}}
+  %1 = "test.foo"() : () -> (tensor<1x?x1x1xf32>)
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %1, %c1 : tensor<1x?x1x1xf32>
+  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, %dim, 1, 1][1, 1, 1, 1]
+      {replacement_0 = 0} : tensor<1x?x1x1xf32> to tensor<?xf32>
+  return
+}
+
+// -----
+
+func.func @non_cast_like_extract_slice() {
+  // expected-error @below {{listener could not find replacement op}}
+  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+  %1 = "test.foo"() : () -> (tensor<1x5x1x1xf32>)
+  %2 = tensor.extract_slice %1[0, 0, 0, 0][1, 3, 1, 1][1, 1, 1, 1]
+      {replacement_0 = 0} : tensor<1x5x1x1xf32> to tensor<3xf32>
+  return
+}
+
+// -----
+
+func.func @non_cast_like_extract_slice_dropped_non_unit_dim() {
+  // expected-error @below {{listener could not find replacement op}}
+  %0 = "test.foo"() {replaced} : () -> (tensor<5xf32>)
+  %1 = "test.foo"() : () -> (tensor<2x5xf32>)
+  %2 = tensor.extract_slice %1[1, 0][1, 5][1, 1]
+      {replacement_0 = 0} : tensor<2x5xf32> to tensor<5xf32>
+  return
+}