diff --git a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/TransformOps/CMakeLists.txt
@@ -14,4 +14,5 @@
   MLIRTensorDialect
   MLIRTensorTransforms
   MLIRTransformDialect
+  MLIRValueBoundsOpInterface
   )
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -23,18 +24,33 @@
 // TrackingListener
 //===----------------------------------------------------------------------===//
 
-/// A tensor.insert_slice is a cast-like operation if it the source tensor and
-/// the destination tensor have the same number of elements. I.e., the result
-/// tensor data equals the source tensor data, maybe rank-extended to a
-/// different shape.
+/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
+/// source tensor with static unit dims or inserts the source tensor into a
+/// destination tensor with the same shape. In both cases, the result tensor
+/// data equals the source tensor data.
 static bool isCastLikeInsertSliceOp(InsertSliceOp op) {
-  // TODO: Support dynamically shaped tensors. Utilize ValueBoundsOpInterface
-  // to check if source and destination have the same shape.
-  if (!op.getSourceType().hasStaticShape() ||
-      !op.getDestType().hasStaticShape())
-    return false;
-  return op.getSourceType().getNumElements() ==
-         op.getDestType().getNumElements();
+  llvm::SmallBitVector droppedDims = op.getDroppedDims();
+  RankedTensorType resultType = op.getDestType();
+  int64_t srcDim = 0;
+  // Source dims and destination dims (apart from dropped dims) must have the
+  // same size.
+  for (int64_t resultDim = 0; resultDim < resultType.getRank(); ++resultDim) {
+    if (droppedDims.test(resultDim)) {
+      // A dropped dim is a unit dim of the inserted slice without a
+      // counterpart in the source. The slice covers the entire destination
+      // only if the destination dim is a static unit dim as well.
+      if (resultType.getDimSize(resultDim) != 1)
+        return false;
+      continue;
+    }
+    FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
+        op.getSource(), op.getResult(), srcDim, resultDim);
+    if (failed(equalDimSize) || !*equalDimSize)
+      return false;
+    ++srcDim;
+  }
+
+  return true;
 }
 
 Operation *
diff --git a/mlir/test/Dialect/Tensor/tracking-listener.mlir b/mlir/test/Dialect/Tensor/tracking-listener.mlir
--- a/mlir/test/Dialect/Tensor/tracking-listener.mlir
+++ b/mlir/test/Dialect/Tensor/tracking-listener.mlir
@@ -82,3 +82,26 @@
       : tensor<5xf32> into tensor<7xf32>
   return
 }
+
+// -----
+
+func.func @cast_like_insert_slice_dynamic(
+    %t: tensor<1x?x1xf32>, %f: f32, %pos: index) {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.insert %f into %t[%c0, %pos, %c0] {replaced} : tensor<1x?x1xf32>
+
+  // Rank reduction
+  %c1 = arith.constant 1 : index
+  %dim1 = tensor.dim %t, %c1 : tensor<1x?x1xf32>
+  %1 = tensor.extract_slice %t[0, 0, 0][1, %dim1, 1][1, 1, 1]
+      : tensor<1x?x1xf32> to tensor<?xf32>
+  // expected-remark @below {{replacement found}}
+  %2 = tensor.insert %f into %1[%c0] : tensor<?xf32>
+  // Rank expansion
+  // Throw in a wrench: Do not use %dim1 directly, but another SSA value that
+  // has the same runtime value.
+  %dim1b = tensor.dim %1, %c0 : tensor<?xf32>
+  %3 = tensor.insert_slice %2 into %t[0, 0, 0][1, %dim1b, 1][1, 1, 1]
+      {replacement_0 = 0} : tensor<?xf32> into tensor<1x?x1xf32>
+  return
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -5972,6 +5972,7 @@
         ":TensorTransformOpsIncGen",
         ":TensorTransforms",
         ":TransformDialect",
+        ":ValueBoundsOpInterface",
         "//llvm:Support",
     ],
 )