diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -34,7 +34,6 @@ /// failures as their diagnostics have been already reported to the user. class [[nodiscard]] DiagnosedSilenceableFailure { public: - explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {} DiagnosedSilenceableFailure(const DiagnosedSilenceableFailure &) = delete; DiagnosedSilenceableFailure & operator=(const DiagnosedSilenceableFailure &) = delete; @@ -156,6 +155,7 @@ } private: + explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {} explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic) : result(failure()) { diagnostics.emplace_back(std::move(diagnostic)); diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -51,23 +51,12 @@ ]; let extraSharedClassDeclaration = [{ - /// Emits a generic transform error for the current transform operation - /// targeting the given Payload IR operation and returns failure. Should - /// be only used as a last resort when the transformation itself provides - /// no further indication as to the reason of the failure. - ::mlir::LogicalResult reportUnknownTransformError( - ::mlir::Operation *target) { - ::mlir::InFlightDiagnostic diag = $_op->emitError() << "failed to apply"; - diag.attachNote(target->getLoc()) << "attempted to apply to this op"; - return diag; - } - /// Creates the silenceable failure object with a diagnostic located at the /// current operation. Silenceable failure must be suppressed or reported /// explicitly at some later time. DiagnosedSilenceableFailure emitSilenceableError(const ::llvm::Twine &message = {}) { - return ::mlir::emitSilenceableFailure($_op); + return ::mlir::emitSilenceableFailure($_op, message); } /// Creates the definite failure object with a diagnostic located at the @@ -78,6 +67,17 @@ return ::mlir::emitDefiniteFailure($_op, message); } + /// Emits a generic definite failure for the current transform operation + /// targeting the given Payload IR operation and returns failure. Should + /// be only used as a last resort when the transformation itself provides + /// no further indication as to the reason of the failure. + DiagnosedDefiniteFailure emitDefaultDefiniteFailure( + ::mlir::Operation *target) { + auto diag = ::mlir::emitDefiniteFailure($_op, "failed to apply"); + diag.attachNote(target->getLoc()) << "attempted to apply to this op"; + return diag; + } + /// Creates the default silenceable failure for a transform op that failed /// to properly apply to a target. DiagnosedSilenceableFailure emitDefaultSilenceableFailure( diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -119,7 +119,7 @@ blkSizeX, blkSizeY, blkSizeZ); rewriter.setInsertionPointToEnd(&launchOp.getBody().front()); rewriter.create(loc); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } /// Alter kernel configuration of the given kernel. diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -79,20 +79,20 @@ Conv1DNwcWcfOp>>(target); if (succeeded(windowedNhwc)) { results.push_back(*windowedNhwc); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } FailureOr windowedNchw = tryApply>(target); if (succeeded(windowedNchw)) { results.push_back(*windowedNchw); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } FailureOr depthwise = tryApply(target); if (succeeded(depthwise)) { results.push_back(*depthwise); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); @@ -206,7 +206,8 @@ return tileConsumerAndFuseProducerGreedilyUsingSCFForOp( rewriter, tilingInterfaceOp, tileAndFuseOptions); }); - return DiagnosedSilenceableFailure(result); + return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() + : DiagnosedSilenceableFailure::success(); } ParseResult transform::FuseOp::parse(OpAsmParser &parser, @@ -568,12 +569,12 @@ // Exit early if no transformation is needed. if (isa(target)) { results.push_back(target); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } FailureOr generic = tryApply(target); if (succeeded(generic)) { results.push_back(generic->getOperation()); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); @@ -592,7 +593,7 @@ // Exit early if no transformation is needed. if (interchangeVector.empty()) { results.push_back(target); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } TrivialPatternRewriter rewriter(target->getContext()); FailureOr res = @@ -600,7 +601,7 @@ if (failed(res)) return DiagnosedSilenceableFailure::definiteFailure(); results.push_back(res->getOperation()); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } LogicalResult transform::InterchangeOp::verify() { @@ -639,8 +640,7 @@ ArrayRef payloadOps = state.getPayloadOps(getTarget()); if (payloadOps.size() != 1) { results.set(getResult().cast(), {}); - return DiagnosedSilenceableFailure( - this->emitOpError("requires exactly one target handle")); + return emitDefiniteFailure("requires exactly one target handle"); } SmallVector res; @@ -687,7 +687,7 @@ payloadOps.front()->walk(matchFun); results.set(getResult().cast(), res); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===---------------------------------------------------------------------===// @@ -792,7 +792,7 @@ tryApply(target, paddingOptions); if (succeeded(result)) { results.push_back(result->getOperation()); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } results.assign(1, nullptr); @@ -866,15 +866,15 @@ promotionOptions = promotionOptions.setAlignment(*getAlignment()); if (failed(promoteSubviewsPrecondition(target, promotionOptions))) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultDefiniteFailure(target); TrivialPatternRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultDefiniteFailure(target); results.push_back(target); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -909,7 +909,7 @@ replacements.push_back(replacement); } transformResults.set(getReplacement().cast(), replacements); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } void transform::ReplaceOp::getEffects( @@ -972,10 +972,10 @@ FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(target.getOperation()), tilingOptions); if (failed(maybeTilingResult)) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultDefiniteFailure(target); results.append(maybeTilingResult->tiledOps); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -1171,13 +1171,13 @@ ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) : splitReduction(rewriter, target, splitFn, getUseAlloc()); if (failed(splitResult)) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultDefiniteFailure(target); results.push_back(splitResult->initOrAlloc); results.push_back(splitResult->fillOp); results.push_back(splitResult->splitLinalgOp); results.push_back(splitResult->resultCombiningLinalgOp); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -1200,12 +1200,12 @@ sizes); if (failed(result)) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultSilenceableFailure(target); results.push_back(result->loops.front()); results.push_back(result->initialOp); results.push_back(result->parallelTiledOp); results.push_back(result->mergeOp); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -1235,7 +1235,7 @@ results.push_back(result->initialOp); results.push_back(result->parallelTiledOp); results.push_back(result->mergeOp); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -1523,7 +1523,7 @@ } } - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( @@ -1533,7 +1533,7 @@ ArrayRef mixedTileSizes, Optional mapping, SmallVector &tileOps, SmallVector &tiledOps) { if (targets.empty()) - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); // getMixedNumThreads are OpFoldResults[index attributes or PDL operation]. // Convert to OpFoldResults[index attributes or payload op]. @@ -1577,7 +1577,7 @@ tileOps.push_back(tilingResult->tileOp); tiledOps.push_back(tilingResult->tiledOp); } - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply( @@ -1604,7 +1604,7 @@ transformResults.set(getForeachThreadOp().cast(), tileOps); transformResults.set(getTiledOp().cast(), tiledOps); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } void transform::TileToForeachThreadOp::getEffects( @@ -1852,10 +1852,10 @@ linalg::populatePadOpVectorizationPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) - return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + return emitDefaultDefiniteFailure(target); results.push_back(target); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -33,7 +33,7 @@ } results.push_back(newBuffer.value()); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -103,10 +103,8 @@ FailureOr outlined = outlineSingleBlockRegion( rewriter, location, exec.getRegion(), getFuncName(), &call); - if (failed(outlined)) { - (void)reportUnknownTransformError(target); - return DiagnosedSilenceableFailure::definiteFailure(); - } + if (failed(outlined)) + return emitDefaultDefiniteFailure(target); if (symbolTableOp) { SymbolTable &symbolTable = @@ -139,7 +137,7 @@ scf::peelAndCanonicalizeForLoop(rewriter, target, result); // TODO: Return both the peeled loop and the remainder loop. results.push_back(failed(status) ? target : result); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===// @@ -200,7 +198,7 @@ pattern.returningMatchAndRewrite(target, rewriter); if (succeeded(patternResult)) { results.push_back(*patternResult); - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } results.assign(1, nullptr); return emitDefaultSilenceableFailure(target); @@ -225,7 +223,7 @@ diag << "Op failed to unroll"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } - return DiagnosedSilenceableFailure(success()); + return DiagnosedSilenceableFailure::success(); } //===----------------------------------------------------------------------===//