diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -158,8 +158,10 @@ //===----------------------------------------------------------------------===// def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides< - Tensor_Dialect, "extract_slice", [NoSideEffect, AttrSizedOperandSegments, - OffsetSizeAndStrideOpInterface]> { + Tensor_Dialect, "extract_slice", + [NoSideEffect, AttrSizedOperandSegments, + DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>, + OffsetSizeAndStrideOpInterface]> { let summary = "extract slice operation"; let description = [{ The "extract_slice" operation extract a tensor from another tensor as @@ -284,6 +286,11 @@ /// Return the number of leading operands before the `offsets`, `sizes` and /// and `strides` operands. static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; } + + /// Return the dimensions of the source that are dropped in the + /// result when the result is rank-reduced. + llvm::SmallDenseSet<unsigned> getDroppedDims(); + }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -277,9 +277,12 @@ unsigned unsignedIndex = index.getValue().getZExtValue(); if (auto sliceOp = dyn_cast_or_null<ExtractSliceOp>(definingOp)) { - assert(sliceOp.isDynamicSize(unsignedIndex) && - "Expected dynamic slice size"); - return sliceOp.getDynamicSize(unsignedIndex); + // Fold only for non-rank reduced ops. For the rank-reduced version, rely on + // `resolve-shaped-type-result-dims` pass. 
+ if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() && + sliceOp.isDynamicSize(unsignedIndex)) { + return {sliceOp.getDynamicSize(unsignedIndex)}; + } } // dim(cast) -> dim @@ -895,6 +898,46 @@ return resultType; } +llvm::SmallDenseSet<unsigned> ExtractSliceOp::getDroppedDims() { + llvm::SmallDenseSet<unsigned> droppedDims; + ArrayRef<int64_t> resultShape = getType().getShape(); + SmallVector<OpFoldResult> mixedSizes = getMixedSizes(); + unsigned shapePos = 0; + for (auto size : enumerate(mixedSizes)) { + Optional<int64_t> sizeVal = getConstantIntValue(size.value()); + // If the size is not 1, or if the current matched dimension of the result + // is the same static shape as the size value (which is 1), then the + // dimension is preserved. + if (!sizeVal || sizeVal.getValue() != 1 || + (shapePos < resultShape.size() && resultShape[shapePos] == 1)) { + shapePos++; + continue; + } + droppedDims.insert(size.index()); + } + return droppedDims; +} + +LogicalResult ExtractSliceOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { + reifiedReturnShapes.resize(1); + reifiedReturnShapes[0].reserve(getType().getRank()); + SmallVector<OpFoldResult> mixedSizes = getMixedSizes(); + llvm::SmallDenseSet<unsigned> droppedDims = getDroppedDims(); + Location loc = getLoc(); + for (auto size : enumerate(mixedSizes)) { + if (droppedDims.count(size.index())) + continue; + if (auto attr = size.value().dyn_cast<Attribute>()) { + reifiedReturnShapes[0].push_back(builder.create<ConstantIndexOp>( + loc, attr.cast<IntegerAttr>().getInt())); + continue; + } + reifiedReturnShapes[0].push_back(size.value().get<Value>()); + } + return success(); +} + namespace { /// Pattern to rewrite an extract_slice op with tensor::Cast arguments. 
/// This essentially pushes memref_cast past its consuming slice when diff --git a/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir --- a/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir +++ b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir @@ -25,3 +25,120 @@ // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG1]], %[[C0]] // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]] // CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]] // CHECK: return %[[D0]], %[[D1]], %[[D2]] + +// ----- + +func @extract_slice(%arg0 : tensor<?x?x?xf32>, %arg1 : index, %arg2 : index, + %arg3 : index) -> (index, index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, %arg2, %arg3] [1, 1, 1] : + tensor<?x?x?xf32> to tensor<?x?x?xf32> + %1 = tensor.dim %0, %c0 : tensor<?x?x?xf32> + %2 = tensor.dim %0, %c1 : tensor<?x?x?xf32> + %3 = tensor.dim %0, %c2 : tensor<?x?x?xf32> + return %1, %2, %3 : index, index, index +} +// CHECK-LABEL: func @extract_slice( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// CHECK: return %[[ARG1]], %[[ARG2]], %[[ARG3]] + +// ----- + +func @extract_slice_rank_reduced_1(%arg0 : tensor<?x?x?xf32>, + %arg1 : index) -> index { + %c0 = constant 0 : index + %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] : + tensor<?x?x?xf32> to tensor<?xf32> + %1 = tensor.dim %0, %c0 : tensor<?xf32> + return %1 : index +} +// CHECK-LABEL: func @extract_slice_rank_reduced_1( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: return %[[ARG1]] + +// ----- + +func @extract_slice_rank_reduced_2(%arg0 : tensor<?x?x?xf32>, + %arg1 : index) -> index { + %c0 = constant 0 : index + %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] : + tensor<?x?x?xf32> to tensor<?x1xf32> + %1 = tensor.dim %0, %c0 : tensor<?x1xf32> + return %1 : index +} +// CHECK-LABEL: func @extract_slice_rank_reduced_2( +// 
CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: return %[[ARG1]] + +// ----- + +func @extract_slice_rank_reduced_3(%arg0 : tensor<?x?x?xf32>, + %arg1 : index) -> index { + %c1 = constant 1 : index + %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] : + tensor<?x?x?xf32> to tensor<1x?xf32> + %1 = tensor.dim %0, %c1 : tensor<1x?xf32> + return %1 : index +} +// CHECK-LABEL: func @extract_slice_rank_reduced_3( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: return %[[ARG1]] + +// ----- + +func @extract_slice_rank_reduced_4(%arg0 : tensor<?x?x?xf32>, + %arg1 : index) -> index { + %c1 = constant 1 : index + %0 = tensor.extract_slice %arg0[0, 0, 0] [1, %arg1, 1] [1, 1, 1] : + tensor<?x?x?xf32> to tensor<1x?x1xf32> + %1 = tensor.dim %0, %c1 : tensor<1x?x1xf32> + return %1 : index +} +// CHECK-LABEL: func @extract_slice_rank_reduced_4( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK: return %[[ARG1]] + +// ----- + +func @extract_slice_rank_reduced_5(%arg0 : tensor<?x?x?xf32>, %arg1 : index, + %arg2 : index) -> (index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] : + tensor<?x?x?xf32> to tensor<?x?xf32> + %1 = tensor.dim %0, %c0 : tensor<?x?xf32> + %2 = tensor.dim %0, %c1 : tensor<?x?xf32> + return %1, %2 : index, index +} +// CHECK-LABEL: func @extract_slice_rank_reduced_5( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK: return %[[ARG1]], %[[ARG2]] + +// ----- + +func @extract_slice_rank_reduced_6(%arg0 : tensor<?x?x?xf32>, %arg1 : index, + %arg2 : index) -> (index, index) { + %c0 = constant 0 : index + %c2 = constant 2 : index + %0 = tensor.extract_slice %arg0[0, 0, 0] [%arg1, 1, %arg2] [1, 1, 1] : + tensor<?x?x?xf32> to tensor<?x1x?xf32> + %1 = tensor.dim %0, %c0 : tensor<?x1x?xf32> + %2 = tensor.dim %0, %c2 : tensor<?x1x?xf32> + return %1, %2 : index, index +} +// CHECK-LABEL: func 
@extract_slice_rank_reduced_6( +// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK: return %[[ARG1]], %[[ARG2]]