diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -181,6 +181,99 @@
 #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
 };
 
+class DiagnosedDefiniteFailure;
+
+DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
+                                             const Twine &message = {});
+
+/// A compatibility class connecting `InFlightDiagnostic` to
+/// `DiagnosedSilenceableFailure` while providing an interface similar to the
+/// former. Implicitly convertible to `DiagnosedSilenceableFailure` in definite
+/// failure state and to `LogicalResult` failure. Reports the error on
+/// conversion or on destruction. Instances of this class can be created by
+/// `emitDefiniteFailure()`.
+class DiagnosedDefiniteFailure {
+  friend DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
+                                                      const Twine &message);
+
+public:
+  /// Only move-constructible because it carries an in-flight diagnostic.
+  DiagnosedDefiniteFailure(DiagnosedDefiniteFailure &&) = default;
+
+  /// Forward the message to the diagnostic.
+  template <typename T>
+  DiagnosedDefiniteFailure &operator<<(T &&value) & {
+    diag << std::forward<T>(value);
+    return *this;
+  }
+  template <typename T>
+  DiagnosedDefiniteFailure &&operator<<(T &&value) && {
+    return std::move(this->operator<<(std::forward<T>(value)));
+  }
+
+  /// Attaches a note to the error.
+  Diagnostic &attachNote(Optional<Location> loc = llvm::None) {
+    return diag.attachNote(loc);
+  }
+
+  /// Implicit conversion to DiagnosedSilenceableFailure in the definite failure
+  /// state. Reports the error.
+  operator DiagnosedSilenceableFailure() {
+    diag.report();
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  /// Implicit conversion to LogicalResult in the failure state. Reports the
+  /// error.
+  operator LogicalResult() {
+    diag.report();
+    return failure();
+  }
+
+private:
+  /// Constructs a definite failure at the given location with the given
+  /// message.
+  explicit DiagnosedDefiniteFailure(Location loc, const Twine &message)
+      : diag(emitError(loc, message)) {}
+
+  /// Copy-construction and any assignment is disallowed to prevent repeated
+  /// error reporting.
+  DiagnosedDefiniteFailure(const DiagnosedDefiniteFailure &) = delete;
+  DiagnosedDefiniteFailure &
+  operator=(const DiagnosedDefiniteFailure &) = delete;
+  DiagnosedDefiniteFailure &operator=(DiagnosedDefiniteFailure &&) = delete;
+
+  /// The in-flight error diagnostic, reported on conversion or destruction.
+  InFlightDiagnostic diag;
+};
+
+/// Emits a definite failure with the given message. The returned object allows
+/// for last-minute modification to the error message, such as attaching notes
+/// and completing the message. It will be reported when the object is
+/// destroyed or converted.
+inline DiagnosedDefiniteFailure emitDefiniteFailure(Location loc,
+                                                    const Twine &message) {
+  return DiagnosedDefiniteFailure(loc, message);
+}
+inline DiagnosedDefiniteFailure emitDefiniteFailure(Operation *op,
+                                                    const Twine &message = {}) {
+  return emitDefiniteFailure(op->getLoc(), message);
+}
+
+/// Emits a silenceable failure with the given message. A silenceable failure
+/// must be either suppressed or converted into a definite failure and reported
+/// to the user.
+inline DiagnosedSilenceableFailure
+emitSilenceableFailure(Location loc, const Twine &message = {}) {
+  Diagnostic diag(loc, DiagnosticSeverity::Error);
+  diag << message;
+  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+}
+inline DiagnosedSilenceableFailure
+emitSilenceableFailure(Operation *op, const Twine &message = {}) {
+  return emitSilenceableFailure(op->getLoc(), message);
+}
+
 namespace transform {
 
 class TransformOpInterface;
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td
@@ -63,20 +63,29 @@
     }
 
     /// Creates the silenceable failure object with a diagnostic located at the
-    /// current operation.
-    DiagnosedSilenceableFailure emitSilenceableError() {
-      Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
-      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    /// current operation. Silenceable failure must be suppressed or reported
+    /// explicitly at some later time. `message` starts the diagnostic text.
+    DiagnosedSilenceableFailure
+    emitSilenceableError(const ::llvm::Twine &message = {}) {
+      return ::mlir::emitSilenceableFailure($_op, message);
+    }
+
+    /// Creates the definite failure object with a diagnostic located at the
+    /// current operation. Definite failure will be reported when the object
+    /// is destroyed or converted.
+    DiagnosedDefiniteFailure
+    emitDefiniteFailure(const ::llvm::Twine &message = {}) {
+      return ::mlir::emitDefiniteFailure($_op, message);
     }
 
     /// Creates the default silenceable failure for a transform op that failed
     /// to properly apply to a target.
     DiagnosedSilenceableFailure emitDefaultSilenceableFailure(
         Operation *target) {
-      Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error);
+      DiagnosedSilenceableFailure diag = emitSilenceableError();
       diag << $_op->getName() << " failed to apply";
       diag.attachNote(target->getLoc()) << "when applied to this op";
-      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+      return diag;
     }
   }];
 }
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -324,8 +324,7 @@
     if (transformOp.has_value()) {
       return transformOp->emitSilenceableError() << message;
     }
-    foreachThreadOp->emitError() << message;
-    return DiagnosedSilenceableFailure::definiteFailure();
+    return emitDefiniteFailure(foreachThreadOp, message);
   };
 
   if (foreachThreadOp.getNumResults() > 0)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -470,10 +470,9 @@
   }
   ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
   if (containingOps.size() != 1) {
-    // Definite failure.
-    return DiagnosedSilenceableFailure(
-        this->emitOpError("requires exactly one containing_op handle (got ")
-        << containingOps.size() << ")");
+    return emitDefiniteFailure()
+           << "requires exactly one containing_op handle (got "
+           << containingOps.size() << ")";
   }
 
   Operation *containingOp = containingOps.front();
@@ -925,11 +924,11 @@
     }
 
     if (splitPoints.size() != payload.size()) {
-      emitError() << "expected the dynamic split point handle to point to as "
-                     "many operations ("
-                  << splitPoints.size() << ") as the target handle ("
-                  << payload.size() << ")";
-      return DiagnosedSilenceableFailure::definiteFailure();
+      return emitDefiniteFailure()
+             << "expected the dynamic split point handle to point to as "
+                "many operations ("
+             << splitPoints.size() << ") as the target handle ("
+             << payload.size() << ")";
     }
   } else {
     splitPoints.resize(payload.size(),
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -177,17 +177,16 @@
 
   for (Operation *original : originals) {
     if (original->isAncestor(getOperation())) {
-      InFlightDiagnostic diag =
-          emitError() << "scope must not contain the transforms being applied";
+      auto diag = emitDefiniteFailure()
+                  << "scope must not contain the transforms being applied";
       diag.attachNote(original->getLoc()) << "scope";
-      return DiagnosedSilenceableFailure::definiteFailure();
+      return diag;
     }
     if (!original->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
-      InFlightDiagnostic diag =
-          emitError()
-          << "only isolated-from-above ops can be alternative scopes";
+      auto diag = emitDefiniteFailure()
+                  << "only isolated-from-above ops can be alternative scopes";
       diag.attachNote(original->getLoc()) << "scope";
-      return DiagnosedSilenceableFailure(std::move(diag));
+      return diag;
     }
   }
 
@@ -523,8 +522,8 @@
   for (Operation *root : state.getPayloadOps(getRoot())) {
     if (failed(extension->findAllMatches(
             getPatternName().getLeafReference().getValue(), root, targets))) {
-      emitOpError() << "could not find pattern '" << getPatternName() << "'";
-      return DiagnosedSilenceableFailure::definiteFailure();
+      return emitDefiniteFailure()
+             << "could not find pattern '" << getPatternName() << "'";
     }
   }
   results.set(getResult().cast<OpResult>(), targets);
diff --git a/mlir/test/Dialect/Transform/transform-state-extension.mlir b/mlir/test/Dialect/Transform/transform-state-extension.mlir
--- a/mlir/test/Dialect/Transform/transform-state-extension.mlir
+++ b/mlir/test/Dialect/Transform/transform-state-extension.mlir
@@ -44,3 +44,13 @@
     test_check_if_test_extension_present %arg0
   }
 }
+
+// -----
+
+module {
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !pdl.operation):
+    // expected-error @below {{TestTransformStateExtension missing}}
+    test_remap_operand_to_self %arg0
+  }
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -188,10 +188,8 @@
 DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply(
     transform::TransformResults &results, transform::TransformState &state) {
   auto *extension = state.getExtension<TestTransformStateExtension>();
-  if (!extension) {
-    emitError() << "TestTransformStateExtension missing";
-    return DiagnosedSilenceableFailure::definiteFailure();
-  }
+  if (!extension)
+    return emitDefiniteFailure("TestTransformStateExtension missing");
 
   if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(),
                                       getOperation())))