diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -86,72 +86,28 @@
   Value paddingValue = rewriter.create<arith::ConstantOp>(
       opToPad.getLoc(), cast<TypedAttr>(paddingAttr));
 
-  // Follow the use-def chain if `currOpOperand` is defined by a LinalgOp.
-  OpOperand *currOpOperand = opOperand;
-  while (auto linalgOp = currOpOperand->get().getDefiningOp<LinalgOp>()) {
-    OpResult result = cast<OpResult>(currOpOperand->get());
-    currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
-  }
-
-  SmallVector<OpFoldResult> mixedSizes;
-  if (auto reifiableOp =
-          llvm::dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
-              currOpOperand->get().getDefiningOp())) {
-    ReifiedRankedShapedTypeDims reifiedReturnShapes;
-    LogicalResult status =
-        reifiableOp.reifyResultShapes(rewriter, reifiedReturnShapes);
-    mixedSizes = reifiedReturnShapes[0];
-    if (failed(status)) {
-      LLVM_DEBUG(DBGS() << "--failed to reify result shapes\n");
-      return rewriter.notifyMatchFailure(opToPad,
-                                         "failed to reify result shapes");
-    }
-  } else if (hasStaticShape) {
-    mixedSizes = getAsIndexOpFoldResult(rewriter.getContext(), shape);
-  } else {
-    // TODO: may want to add support for going through loop iter args.
-    // This is not strictly necessary as we can pad before hoisting but it would
-    // make the system more resilient to minor transformation reordering.
-    LLVM_DEBUG(DBGS() << "--not a ReifyRankedShapedTypeOpInterface op\n");
-    return rewriter.notifyMatchFailure(
-        opToPad, "not a ReifyRankedShapedTypeOpInterface op");
-  }
-  LLVM_DEBUG(llvm::interleaveComma(mixedSizes, DBGS() << "--mixedSizes: ");
-             llvm::dbgs() << "\n");
-
   // Upper bound the sizes to obtain a static bounding box.
   SmallVector<int64_t> paddedShape(shape.begin(), shape.end());
-  int64_t shapeIdx = 0;
-  for (const auto &en : enumerate(mixedSizes)) {
-    LLVM_DEBUG(DBGS() << "----mixedSizes: " << en.value() << "\n");
+  for (int64_t i = 0, e = shape.size(); i < e; ++i) {
+    LLVM_DEBUG(DBGS() << "--compute padded size for dim " << i << "\n");
     // Skip dimensions that do not require padding.
-    if (!shapeDimsToPad.contains(shapeIdx)) {
-      shapeIdx++;
-      LLVM_DEBUG(DBGS() << "------dim does not require padding, SKIP\n");
-      continue;
-    }
-    // If the size is an attribute add it directly to `paddedShape`.
-    if (en.value().is<Attribute>()) {
-      paddedShape[shapeIdx++] =
-          dyn_cast<IntegerAttr>(en.value().get<Attribute>()).getInt();
-      LLVM_DEBUG(
-          DBGS() << "------dim is an attr, add it to padded shape, SKIP\n");
+    if (!shapeDimsToPad.contains(i)) {
+      LLVM_DEBUG(DBGS() << "----dim does not require padding, SKIP\n");
       continue;
     }
     // Otherwise, try to compute a constant upper bound for the size value.
     FailureOr<int64_t> upperBound =
         ValueBoundsConstraintSet::computeConstantBound(
-            presburger::BoundType::UB, en.value().get<Value>(),
-            /*dim=*/std::nullopt, /*stopCondition=*/nullptr, /*closedUB=*/true);
+            presburger::BoundType::UB, opOperand->get(),
+            /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
     if (failed(upperBound)) {
-      LLVM_DEBUG(DBGS() << "--count not compute a bounding box for padding");
-      return rewriter.notifyMatchFailure(
-          opToPad, "count not compute a bounding box for padding");
+      LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
+      return rewriter.notifyMatchFailure(
+          opToPad, "could not compute a bounding box for padding");
     }
-    paddedShape[shapeIdx++] = *upperBound;
+    paddedShape[i] = *upperBound;
+    LLVM_DEBUG(DBGS() << "----new dim size: " << paddedShape[i] << "\n");
   }
-  assert(shapeIdx == static_cast<int64_t>(shape.size()) &&
-         "expect the dynamic and static ranks to match");
 
   // Pad the operand to the bounding box defined by `paddedShape`.
   auto paddedTensorType = RankedTensorType::get(
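Note for reviewers: the manual use-def walk over `getDpsInitOperand` and the `reifyResultShapes` plumbing become unnecessary because `ValueBoundsConstraintSet` traverses the defining ops itself, through their `ValueBoundsOpInterface` implementations. Below is a minimal standalone sketch (not part of this patch) of the same per-dimension bounding-box query; the helper name `getStaticBoundingBox` is made up for illustration:

```cpp
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;

/// Sketch: compute a static bounding box for all dimensions of a ranked
/// tensor value `v`. Fails if any dynamic dimension has no provable
/// constant upper bound.
static FailureOr<SmallVector<int64_t>> getStaticBoundingBox(Value v) {
  auto tensorType = dyn_cast<RankedTensorType>(v.getType());
  if (!tensorType)
    return failure();
  SmallVector<int64_t> boundingBox;
  for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) {
    // A static size is already its own (tight) bound.
    if (!tensorType.isDynamicDim(i)) {
      boundingBox.push_back(tensorType.getDimSize(i));
      continue;
    }
    // Ask for a closed (inclusive) constant upper bound on dim `i` of `v`.
    // The constraint set walks the chain of defining ops on its own.
    FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
        presburger::BoundType::UB, v, /*dim=*/i,
        /*stopCondition=*/nullptr, /*closedUB=*/true);
    if (failed(ub))
      return failure();
    boundingBox.push_back(*ub);
  }
  return boundingBox;
}
```

`/*closedUB=*/true` matters here: the bound is inclusive, so a dimension whose size is capped by something like `affine.min(%n, 64)` should come back as 64 rather than 65, which is exactly what the padded static shape needs.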