diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2153,20 +2153,21 @@ SmallVector tiled; SmallVector, 4> loops; loops.resize(getLoops().size()); - for (auto &en : llvm::enumerate(targets)) { - auto linalgOp = dyn_cast(en.value()); - if (!linalgOp) { - DiagnosedSilenceableFailure diag = emitSilenceableError() - << "only linalg ops are supported"; - diag.attachNote(en.value()->getLoc()) << "target op"; + for (auto &[i, op] : llvm::enumerate(targets)) { + auto tilingInterface = dyn_cast(op); + auto dpsInterface = dyn_cast(op); + if (!tilingInterface || !dpsInterface) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "only ops implementing TilingInterface and " + "DestinationStyleOpInterface are supported"; + diag.attachNote(op->getLoc()) << "target op"; return diag; } scf::SCFTilingOptions tilingOptions; - unsigned index = en.index(); if (!tileSizes.empty()) { - tilingOptions.setTileSizeComputationFunction([&, index](OpBuilder &b, - Operation *) { + tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b, + Operation *) { SmallVector sizes; sizes.reserve(tileSizes.size()); unsigned dynamicIdx = 0; @@ -2193,18 +2194,16 @@ } tilingOptions.setInterchange(getInterchange()); - TrivialPatternRewriter rewriter(linalgOp.getContext()); - FailureOr maybeTilingResult = tileUsingSCFForOp( - rewriter, cast(linalgOp.getOperation()), - tilingOptions); + TrivialPatternRewriter rewriter(op->getContext()); + FailureOr maybeTilingResult = + tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions); if (failed(maybeTilingResult)) return DiagnosedSilenceableFailure::definiteFailure(); - if (linalgOp.hasBufferSemantics()) - rewriter.eraseOp(linalgOp); + if (dpsInterface.hasBufferSemantics()) + rewriter.eraseOp(op); else - rewriter.replaceOp(linalgOp, - maybeTilingResult->loops.front()->getResults()); + rewriter.replaceOp(op, maybeTilingResult->loops.front()->getResults()); tiled.append(maybeTilingResult->tiledOps); for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))