diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -186,26 +186,34 @@ if (!sliceOp) return failure(hasDynamicShape); + // Compute the dropped dimensions if `sliceOp` is rank-reducing. + llvm::SmallDenseSet droppedDims = sliceOp.getDroppedDims(); + // Upper bound the `sliceOp` sizes to obtain a static bounding box. SmallVector staticSizes; - staticSizes.reserve(opToPad.getRank(opOperand)); + staticSizes.reserve(shape.size()); auto shapedOp = cast(sliceOp.getOperation()); - for (auto size : shapedOp.getMixedSizes()) { + for (auto en : enumerate(shapedOp.getMixedSizes())) { + // Skip dropped dimensions. + if (droppedDims.contains(en.index())) + continue; // If the size is an attribute add it directly to `staticSizes`. - if (size.is()) { + if (en.value().is()) { staticSizes.push_back( - size.get().dyn_cast().getInt()); + en.value().get().dyn_cast().getInt()); continue; } // Otherwise, try to compute a constant upper bound for the size value. FailureOr upperBound = - getConstantUpperBoundForIndex(size.get()); + getConstantUpperBoundForIndex(en.value().get()); if (failed(upperBound)) { LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding"); return failure(); } staticSizes.push_back(upperBound.getValue()); } + assert(staticSizes.size() == shape.size() && + "expect the dynamic and static ranks to match"); // Pad the operand to the bounding box defined by `staticSizes`. 
auto staticTensorType = RankedTensorType::get( diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir --- a/mlir/test/Dialect/Linalg/pad.mlir +++ b/mlir/test/Dialect/Linalg/pad.mlir @@ -426,3 +426,24 @@ %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, %0] [1, 1] : tensor<4x?xf32> into tensor<24x25xf32> return %5 : tensor<24x25xf32> } + +// ----- + +#map0 = affine_map<()[s0] -> (64, s0)> + +// FILL: rank_reducing +// FILL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<1x64x1x64xf32> +func @rank_reducing(%arg0: tensor<1x64x1x64xf32>, + %iv0 : index) -> tensor<1x?x?xf32> { + %cst = arith.constant 0.0 : f32 + %size = affine.min #map0()[%iv0] + %0 = tensor.extract_slice %arg0[0, 0, 0, 0] [1, %size, 1, %size] [1, 1, 1, 1] : tensor<1x64x1x64xf32> to tensor<1x?x?xf32> + + // Check the fill is padded despite the rank-reducing slice operation. + // FILL: %[[T0:.*]] = linalg.pad_tensor + // FILL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // FILL-SAME: tensor<1x64x64xf32> + // FILL: = tensor.extract_slice %[[T1]] + %1 = linalg.fill(%cst, %0) : f32, tensor<1x?x?xf32> -> tensor<1x?x?xf32> + return %1 : tensor<1x?x?xf32> +}