diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -187,13 +187,13 @@
     return failure(hasDynamicShape);

   // Upper bound the `sliceOp` sizes to obtain a static bounding box.
-  SmallVector<int64_t> staticSizes;
-  staticSizes.reserve(opToPad.getRank(opOperand));
   auto shapedOp = cast<OffsetSizeAndStrideOpInterface>(sliceOp.getOperation());
+  SmallVector<int64_t> upperBounds;
+  upperBounds.reserve(shapedOp.getMixedSizes().size());
   for (auto size : shapedOp.getMixedSizes()) {
-    // If the size is an attribute add it directly to `staticSizes`.
+    // If the size is an attribute add it directly to `upperBounds`.
     if (size.is<Attribute>()) {
-      staticSizes.push_back(
+      upperBounds.push_back(
           size.get<Attribute>().dyn_cast<IntegerAttr>().getInt());
       continue;
     }
@@ -204,8 +204,22 @@
       LLVM_DEBUG(DBGS() << "No constant bounding box can be found for padding");
       return failure();
     }
-    staticSizes.push_back(upperBound.getValue());
+    upperBounds.push_back(upperBound.getValue());
+  }
+
+  // Drop size one dimensions found in `upperBounds` if `sliceOp` is
+  // rank-reducing. All size one dimensions not found in the `shape` of the
+  // operand to pad need to be dropped.
+  SmallVector<int64_t> staticSizes;
+  staticSizes.reserve(shape.size());
+  for (auto en : llvm::enumerate(shapedOp.getMixedSizes())) {
+    if (getConstantIntValue(en.value()) == static_cast<int64_t>(1) &&
+        staticSizes.size() < shape.size() && shape[staticSizes.size()] != 1)
+      continue;
+    staticSizes.push_back(upperBounds[en.index()]);
   }
+  assert(staticSizes.size() == shape.size() &&
+         "expect the dynamic and static ranks to match");

   // Pad the operand to the bounding box defined by `staticSizes`.
   auto staticTensorType = RankedTensorType::get(
diff --git a/mlir/test/Dialect/Linalg/pad.mlir b/mlir/test/Dialect/Linalg/pad.mlir
--- a/mlir/test/Dialect/Linalg/pad.mlir
+++ b/mlir/test/Dialect/Linalg/pad.mlir
@@ -426,3 +426,24 @@
   %5 = tensor.insert_slice %4 into %arg2[%iv0, %iv1] [4, %0] [1, 1] : tensor<4x?xf32> into tensor<24x25xf32>
   return %5 : tensor<24x25xf32>
 }
+
+// -----
+
+#map0 = affine_map<()[s0] -> (64, s0)>
+
+// FILL: rank_reducing
+// FILL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<1x64x1x64xf32>
+func @rank_reducing(%arg0: tensor<1x64x1x64xf32>,
+                    %iv0 : index) -> tensor<1x?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+  %size = affine.min #map0()[%iv0]
+  %0 = tensor.extract_slice %arg0[0, 0, 0, 0] [1, %size, 1, %size] [1, 1, 1, 1] : tensor<1x64x1x64xf32> to tensor<1x?x?xf32>
+
+  // Check the fill is padded despite the rank-reducing slice operation.
+  // FILL: %[[T0:.*]] = linalg.pad_tensor
+  // FILL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  // FILL-SAME: tensor<1x64x64xf32>
+  // FILL: = tensor.extract_slice %[[T1]]
+  %1 = linalg.fill(%cst, %0) : f32, tensor<1x?x?xf32> -> tensor<1x?x?xf32>
+  return %1 : tensor<1x?x?xf32>
+}
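
Note: the new unit-dimension dropping logic is easiest to see in isolation. Below is a minimal standalone C++ sketch of the same idea, not part of the patch; all names are illustrative. A std::optional<int64_t> per slice size stands in for getConstantIntValue on an OpFoldResult, and negative entries in `shape` stand in for dynamic dimensions.

// Standalone illustration (assumptions as stated above): drop the unit
// dimensions that a rank-reducing slice removed, so the static bounding-box
// sizes line up with the rank-reduced operand shape.
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// `sliceSizes[i]` holds the constant value of the i-th slice size if it is
// statically known, `upperBounds[i]` its static upper bound, and `shape` is
// the shape of the operand to pad (dynamic dimensions encoded as -1).
std::vector<int64_t>
dropRankReducedUnitDims(const std::vector<std::optional<int64_t>> &sliceSizes,
                        const std::vector<int64_t> &upperBounds,
                        const std::vector<int64_t> &shape) {
  std::vector<int64_t> staticSizes;
  staticSizes.reserve(shape.size());
  for (size_t i = 0; i < sliceSizes.size(); ++i) {
    // Drop a static size-one slice dimension unless the operand shape still
    // expects a unit dimension at the current position.
    if (sliceSizes[i] == int64_t(1) && staticSizes.size() < shape.size() &&
        shape[staticSizes.size()] != 1)
      continue;
    staticSizes.push_back(upperBounds[i]);
  }
  assert(staticSizes.size() == shape.size() &&
         "expect the dynamic and static ranks to match");
  return staticSizes;
}

// For the test case above: sizes [1, %size, 1, %size] slicing a
// tensor<1x64x1x64xf32> down to tensor<1x?x?xf32> correspond to
//   dropRankReducedUnitDims({1, {}, 1, {}}, {1, 64, 1, 64}, {1, -1, -1})
// which keeps the leading unit dimension, drops the third, and returns
// {1, 64, 64}, matching the padded tensor<1x64x64xf32> the FILL checks expect.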