diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -15,6 +15,7 @@ namespace mlir { class TilingInterface; +class RewriterBase; namespace linalg { class GenericOp; class LinalgOp; @@ -33,6 +34,17 @@ namespace mlir { class DialectRegistry; +namespace transform { + +/// Implementation of tiling operations using `scf.foreach_thread`. +DiagnosedSilenceableFailure tileToForeachThreadOpImpl( + RewriterBase &rewriter, transform::TransformState &state, + TransformOpInterface transformOp, ArrayRef targets, + ArrayRef mixedNumThreads, + ArrayRef mixedTileSizes, Optional threadDimMapping, + SmallVector &tileOps, SmallVector &tiledOps); +} // namespace transform + namespace linalg { void registerTransformDialectExtension(DialectRegistry &registry); } // namespace linalg diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1339,22 +1339,15 @@ // TileToForeachThreadOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply( - transform::TransformResults &transformResults, - transform::TransformState &state) { - IRRewriter rewriter(getContext()); - ArrayRef targets = state.getPayloadOps(getTarget()); - - // If there the target payload ops are empty, there is nothing to do. 
- if (targets.empty()) { - transformResults.set(getForeachThreadOp().cast(), {}); - transformResults.set(getTiledOp().cast(), {}); +DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( + RewriterBase &rewriter, transform::TransformState &state, + TransformOpInterface transformOp, ArrayRef targets, + ArrayRef mixedNumThreads, + ArrayRef mixedTileSizes, Optional threadDimMapping, + SmallVector &tileOps, SmallVector &tiledOps) { + + if (targets.empty()) return DiagnosedSilenceableFailure(success()); - } - - // Result payload ops. - SmallVector tileOps; - SmallVector tiledOps; // Given a list of OpFoldResults that are either index attrs or op handles, // return a list of OpFoldResults where all op handles are replaced with the @@ -1372,7 +1365,7 @@ state.getPayloadOps(ofr.get()); if (dynamicNumThreads.size() != 1) { DiagnosedSilenceableFailure diag = - emitSilenceableError() + transformOp.emitSilenceableError() << "handle must be mapped to exactly 1 payload op"; diag.attachNote(ofr.get().getLoc()) << "mapped to " << dynamicNumThreads.size() << " ops"; @@ -1382,7 +1375,7 @@ if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { DiagnosedSilenceableFailure diag = - emitSilenceableError() + transformOp.emitSilenceableError() << "payload op must have exactly 1 index result"; diag.attachNote(op->getLoc()) << "has " << op->getNumResults() << " results"; @@ -1398,14 +1391,14 @@ // Convert to OpFoldResults[index attributes or payload op]. SmallVector numThreads; DiagnosedSilenceableFailure status = - getOpResultsOrIndexAttrs(numThreads, getMixedNumThreads()); + getOpResultsOrIndexAttrs(numThreads, mixedNumThreads); if (!status.succeeded()) return status; // getMixedTileSizes are OpFoldResults[index attributes or PDL operation]. // Convert to OpFoldResults[index attributes or payload op]. 
SmallVector tileSizes; - status = getOpResultsOrIndexAttrs(tileSizes, getMixedTileSizes()); + status = getOpResultsOrIndexAttrs(tileSizes, mixedTileSizes); if (!status.succeeded()) return status; @@ -1414,19 +1407,20 @@ auto tilableOp = dyn_cast(target); if (!tilableOp) { DiagnosedSilenceableFailure diag = - emitSilenceableError() << "only TilingInterface ops are supported"; + transformOp.emitSilenceableError() + << "only TilingInterface ops are supported"; diag.attachNote(target->getLoc()) << "target op"; return diag; } rewriter.setInsertionPoint(tilableOp); - auto maybeThreadDimMappingAttr = getThreadDimMapping(); + auto maybeThreadDimMappingAttr = threadDimMapping; auto dimMapping = llvm::to_vector( maybeThreadDimMappingAttr ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr) : ArrayRef{}); - FailureOr tilingResult = failure(); - if (!getMixedNumThreads().empty()) { + FailureOr tilingResult = failure(); + if (!mixedNumThreads.empty()) { tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp, numThreads, dimMapping); } else { @@ -1435,12 +1429,32 @@ } if (failed(tilingResult)) - return emitDefaultSilenceableFailure(tilableOp); + return transformOp.emitDefaultSilenceableFailure(tilableOp); rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults()); tileOps.push_back(tilingResult->tileOp); tiledOps.push_back(tilingResult->tiledOp); } + return DiagnosedSilenceableFailure(success()); +} + +DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply( + transform::TransformResults &transformResults, + transform::TransformState &state) { + IRRewriter rewriter(getContext()); + ArrayRef targets = state.getPayloadOps(getTarget()); + + // Result payload ops. 
+ SmallVector tileOps; + SmallVector tiledOps; + + DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl( + rewriter, state, cast(getOperation()), targets, + getMixedNumThreads(), getMixedTileSizes(), getThreadDimMapping(), tileOps, + tiledOps); + + if (!diag.succeeded()) + return diag; transformResults.set(getForeachThreadOp().cast(), tileOps); transformResults.set(getTiledOp().cast(), tiledOps);