diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -147,26 +147,37 @@ return *this; } -/// Helper function that tries to pad `opOperand`. Exit early and return success -/// for scalar operands or if `paddingFunc` returns failure. Otherwise, try to -/// pad the operand even if it already has a static shape. Set `result` to the -/// result of the created PadTensorOp or return failure if the operand cannot be -/// padded to a static shape. +/// Helper function that tries to pad `opOperand`. Exit early for scalar +/// operands, if `paddingFunc` returns failure, or if `opOperand` is not defined +/// by an ExtractSliceOp. Otherwise, try to pad the operand even if it already +/// has a static shape. Set `result` to the result of the created PadTensorOp +/// and return success if the operand either has been padded to a static shape +/// or already had a static shape, and failure otherwise. static LogicalResult padOperandToSmallestStaticBoundingBox( OpBuilder &b, linalg::LinalgOp opToPad, OpOperand *opOperand, const PaddingValueComputationFunction &paddingFunc, const PaddingNoFoldComputationFunction &nofoldFunc, Value &result) { - // Can't pad scalars. - if (opToPad.getShape(opOperand).empty()) + // Get the shape of the operand and check if it has a dynamic shape. Only + // return failure if the operand is not a scalar and has a dynamic shape. + ArrayRef shape = opToPad.getShape(opOperand); + bool hasDynamicShape = llvm::is_contained(shape, ShapedType::kDynamicSize); + + // Cannot pad scalar operands. + if (shape.empty()) return success(); - // Can't pad if no padding value is known. + + // Cannot pad if the padding value is unknown. 
FailureOr paddingValue = paddingFunc(b, *opOperand); if (failed(paddingValue)) - return success(); + return failure(hasDynamicShape); + + // Cannot construct a static bounding box if the operand is not defined by an + // ExtractSliceOp. auto sliceOp = opOperand->get().getDefiningOp(); - // Not a slice op, cannot construct a static bounding box. if (!sliceOp) - return failure(); + return failure(hasDynamicShape); + + // Upper bound the `sliceOp` sizes to obtain a static bounding box. SmallVector staticSizes; staticSizes.reserve(opToPad.getRank(opOperand)); auto shapedOp = cast(sliceOp.getOperation()); @@ -186,6 +197,8 @@ } staticSizes.push_back(upperBound.getValue()); } + + // Pad the operand to the bounding box defined by `staticSizes`. auto staticTensorType = RankedTensorType::get( staticSizes, getElementTypeOrSelf(opOperand->get())); bool nofold = nofoldFunc ? nofoldFunc(*opOperand) : false;