diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -93,29 +93,34 @@
     currOpOperand = linalgOp.getDpsInitOperand(result.getResultNumber());
   }
 
-  // Fail if `currOpOperand` is not defined by an ExtractSliceOp or EmptyOp.
-  auto sliceOp = currOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
-  auto emptyOp = currOpOperand->get().getDefiningOp<tensor::EmptyOp>();
-
-  llvm::SmallBitVector droppedDims;
   SmallVector<OpFoldResult> mixedSizes;
-  if (sliceOp) {
-    // Compute the dropped dimensions if `sliceOp` is rank-reducing.
-    droppedDims = sliceOp.getDroppedDims();
-    mixedSizes = sliceOp.getMixedSizes();
-  } else if (emptyOp) {
-    mixedSizes = emptyOp.getMixedSizes();
-    droppedDims.resize(mixedSizes.size());
+  // Take the sizes of `currOpOperand` from its defining op when that op can
+  // reify its result shapes; otherwise fall back to the static shape, if any.
+  if (auto reifiableOp =
+          llvm::dyn_cast_or_null<ReifyRankedShapedTypeOpInterface>(
+              currOpOperand->get().getDefiningOp())) {
+    ReifiedRankedShapedTypeDims reifiedReturnShapes;
+    LogicalResult status =
+        reifiableOp.reifyResultShapes(rewriter, reifiedReturnShapes);
+    if (failed(status)) {
+      LLVM_DEBUG(DBGS() << "--failed to reify result shapes\n");
+      return rewriter.notifyMatchFailure(opToPad,
+                                         "failed to reify result shapes");
+    }
+    // Only index `reifiedReturnShapes` once reification succeeded, and select
+    // the entry of the result that actually feeds `currOpOperand`: the
+    // defining op may have several results.
+    auto opResult = llvm::cast<OpResult>(currOpOperand->get());
+    mixedSizes = reifiedReturnShapes[opResult.getResultNumber()];
   } else if (hasStaticShape) {
     mixedSizes = getAsIndexOpFoldResult(rewriter.getContext(), shape);
-    droppedDims.resize(mixedSizes.size());
   } else {
     // TODO: may want to add support for going through loop iter args.
     // This is not strictly necessary as we can pad before hoisting but it would
     // make the system more resilient to minor transformation reordering.
-    LLVM_DEBUG(DBGS() << "--not defined by an extractSlice or emptyOp\n");
+    LLVM_DEBUG(DBGS() << "--not a ReifyRankedShapedTypeOpInterface op\n");
     return rewriter.notifyMatchFailure(
-        opToPad, "not defined by an extractSlice or emptyOp");
+        opToPad, "not a ReifyRankedShapedTypeOpInterface op");
   }
   LLVM_DEBUG(llvm::interleaveComma(mixedSizes, DBGS() << "--mixedSizes: ");
              llvm::dbgs() << "\n");
@@ -125,11 +130,6 @@
   int64_t shapeIdx = 0;
   for (const auto &en : enumerate(mixedSizes)) {
     LLVM_DEBUG(DBGS() << "----mixedSizes: " << en.value() << "\n");
-    // Skip dropped dimensions.
-    if (droppedDims.test(en.index())) {
-      LLVM_DEBUG(DBGS() << "------dim is dropped, SKIP\n");
-      continue;
-    }
     // Skip dimensions that do not require padding.
     if (!shapeDimsToPad.contains(shapeIdx)) {
       shapeIdx++;