diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -123,18 +123,17 @@ /// created PadTensorOp. /// Return failure if the operand cannot be padded to a static shape. static LogicalResult padOperandToSmallestStaticBoundingBox( - PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand &operand, + PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand *opOperand, const LinalgTilingOptions &options, Value &result) { - auto tensorType = operand.get().getType().cast(); // Already static shape, no need to pad. - if (tensorType.hasStaticShape()) + if (llvm::none_of(opToPad.getShape(opOperand), ShapedType::isDynamic)) return success(); - auto subtensor = operand.get().getDefiningOp(); + auto subtensor = opOperand->get().getDefiningOp(); // Not a subtensor, cannot construct a static bounding box. if (!subtensor) return failure(); SmallVector staticSizes; - staticSizes.reserve(tensorType.getRank()); + staticSizes.reserve(opToPad.getRank(opOperand)); auto shapedOp = cast(subtensor.getOperation()); for (auto size : shapedOp.getMixedSizes()) { @@ -148,11 +147,11 @@ opToPad, "No constant bounding box can be found for padding"); staticSizes.push_back(indexAttr.getInt()); } - Value pad = options.paddingValueComputationFunction(rewriter, operand); - auto staticTensorType = - RankedTensorType::get(staticSizes, tensorType.getElementType()); + Value pad = options.paddingValueComputationFunction(rewriter, *opOperand); + auto staticTensorType = RankedTensorType::get( + staticSizes, getElementTypeOrSelf(opOperand->get().getType())); result = linalg::PadTensorOp::createPadHighOp( - staticTensorType, operand.get(), pad, opToPad->getLoc(), rewriter); + staticTensorType, opOperand->get(), pad, opToPad->getLoc(), rewriter); return success(); } @@ -183,9 +182,8 @@ // If padding was requested but the shape cannot be bounded statically then // the pattern fails to apply. if (failed(padOperandToSmallestStaticBoundingBox( - rewriter, opToPad, *opOperand, options, paddedOperand))) { + rewriter, opToPad, opOperand, options, paddedOperand))) return failure(); - } newOperands.push_back(paddedOperand ? paddedOperand : opOperand->get()); }