diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -151,6 +151,38 @@ transform operations can return _new_ handles that can be read or consumed by subsequent operations. + ## Execution Model + + The transformation starts at the specified top-level transform IR operation + and applies to some payload IR scope, identified by the payload IR op that + contains the IR to transform. It is the responsibility of the user to + properly select the scope and/or to prevent the transformations from + modifying the IR outside of it. The top-level transform IR operation may + contain further transform operations and execute them in the desired order. + + Transformation application functions produce a tri-state status: + + - success; + - recoverable (silencable) failure; + - irrecoverable failure. + + Transformation container operations may intercept recoverable failures and + perform the required recovery steps, thus succeeding themselves. On + the other hand, they must propagate irrecoverable failures. For such + failures, the diagnostics are emitted immediately whereas their emission is + postponed for recoverable failures. Transformation container operations may + also fail to recover from a theoretically recoverable failure, in which case + they are expected to emit the diagnostic and turn the failure into an + irrecoverable one. A recoverable failure produced by applying the top-level + transform IR operation is considered irrecoverable. + + Transformation container operations are allowed to "step over" some nested + operations if the application of some previous operation produced a failure. + This can be conceptually thought of as having a global "recoverable error + register" that is read/write accessed by each transform operation as a side + effect. 
The transformation is skipped if the register already contains an + error description, and the control flow proceeds to the following operation. + ## Handle Invalidation The execution model of the transform dialect expects that a payload IR diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -14,6 +14,129 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" namespace mlir { + +/// The result of a transform IR operation application. This can have one of the +/// three states: +/// - success; +/// - silencable (recoverable) failure with yet-unreported diagnostic; +/// - definite failure. +/// Silencable failure is intended to communicate information about +/// transformations that did not apply but in a way that supports recovery, +/// for example, they did not modify the payload IR or modified it in some +/// predictable way. They are associated with a Diagnostic that provides more +/// details on the failure. Silencable failure can be discarded, turning the +/// result into success, or "reported", emitting the diagnostic and turning the +/// result into definite failure. Transform IR operations containing other +/// operations are allowed to do either with the results of the nested +/// transformations, but must propagate definite failures as their diagnostics +/// have been already reported to the user. 
+class LLVM_NODISCARD DiagnosedSilencableFailure { +public: + explicit DiagnosedSilencableFailure(LogicalResult result) : result(result) {} + DiagnosedSilencableFailure(const DiagnosedSilencableFailure &) = delete; + DiagnosedSilencableFailure & + operator=(const DiagnosedSilencableFailure &) = delete; + DiagnosedSilencableFailure(DiagnosedSilencableFailure &&) = default; + DiagnosedSilencableFailure & + operator=(DiagnosedSilencableFailure &&) = default; + + /// Constructs a DiagnosedSilencableFailure in the success state. + static DiagnosedSilencableFailure success() { + return DiagnosedSilencableFailure(::mlir::success()); + } + + /// Constructs a DiagnosedSilencableFailure in the failure state. Typically, + /// a diagnostic has been emitted before this. + static DiagnosedSilencableFailure definiteFailure() { + return DiagnosedSilencableFailure(::mlir::failure()); + } + + /// Constructs a DiagnosedSilencableFailure in the silencable failure state, + /// ready to emit the given diagnostic. This is considered a failure + /// regardless of the diagnostic severity. + static DiagnosedSilencableFailure silencableFailure(Diagnostic &&diag) { + return DiagnosedSilencableFailure(std::forward(diag)); + } + + /// Converts all kinds of failure into a LogicalResult failure, emitting the + /// diagnostic if necessary. Must not be called more than once. + LogicalResult checkAndReport() { +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + assert(!reported && "attempting to report a diagnostic more than once"); + reported = true; +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + if (diagnostic) { + diagnostic->getLocation().getContext()->getDiagEngine().emit( + std::move(*diagnostic)); + diagnostic.reset(); + result = ::mlir::failure(); + } + return result; + } + + /// Returns `true` if this is a silencable failure. + bool isSilencableFailure() const { return diagnostic.hasValue(); } + + /// Returns `true` if this is a success. 
+ bool succeeded() const { + return !diagnostic.hasValue() && ::mlir::succeeded(result); + } + + /// Returns the diagnostic message without emitting it. Expects this object + /// to be a silencable failure. + std::string getMessage() const { return diagnostic->str(); } + + /// Converts silencable failure into LogicalResult success without reporting + /// the diagnostic; preserves the other states. + LogicalResult silence() { + if (diagnostic) { + diagnostic.reset(); + result = ::mlir::success(); + } + return result; + } + + /// Streams the given values into the diagnostic. Expects this object to be + /// a silencable failure. + template DiagnosedSilencableFailure &operator<<(T &&value) & { + assert(isSilencableFailure() && + "can only append output in silencable failure state"); + *diagnostic << std::forward(value); + return *this; + } + template DiagnosedSilencableFailure &&operator<<(T &&value) && { + return std::move(this->operator<<(std::forward(value))); + } + + /// Attaches a note to the diagnostic. Expects this object to be a silencable + /// failure. + Diagnostic &attachNote(Optional loc = llvm::None) { + assert(isSilencableFailure() && + "can only attach notes to silencable failures"); + return diagnostic->attachNote(loc); + } + +private: + explicit DiagnosedSilencableFailure(Diagnostic &&diagnostic) + : diagnostic(std::move(diagnostic)), result(failure()) {} + + /// The diagnostic associated with this object. If present, the object is + /// considered to be in the silencable failure state regardless of the + /// `result` field. + Optional diagnostic; + + /// The "definite" logical state, either success or failure. Ignored if the + /// diagnostic message is present. + LogicalResult result; + +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + /// Whether the associated diagnostic has been reported. Diagnostic reporting + /// consumes the diagnostic, so we need a mechanism to differentiate a + /// reported diagnostic from a state where it was never created.
+ bool reported = false; +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS +}; + namespace transform { class TransformOpInterface; @@ -103,7 +226,7 @@ /// Applies the transformation specified by the given transform op and updates /// the state accordingly. - LogicalResult applyTransform(TransformOpInterface transform); + DiagnosedSilencableFailure applyTransform(TransformOpInterface transform); /// Records the mapping between a block argument in the transform IR and a /// list of operations in the payload IR. The arguments must be defined in @@ -401,7 +524,7 @@ /// the payload IR, depending on what is available in the context. LogicalResult mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, - Operation *op); + Operation *op, unsigned region); /// Verification hook for PossibleTopLevelTransformOpTrait. LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); @@ -411,7 +534,7 @@ /// can be standalone top-level transforms. Such operations typically contain /// other Transform dialect operations that can be executed following some /// control flow logic specific to the current operation. The operations with -/// this trait are expected to have exactly one single-block region with one +/// this trait are expected to have at least one single-block region with one /// argument of PDL Operation type. The operations are also expected to be valid /// without operands, in which case they are considered top-level, and with one /// or more arguments, in which case they are considered nested. Top-level @@ -430,16 +553,18 @@ return detail::verifyPossibleTopLevelTransformOpTrait(op); } - /// Returns the single block of the op's only region. - Block *getBodyBlock() { return &this->getOperation()->getRegion(0).front(); } + /// Returns the single block of the given region. 
+ Block *getBodyBlock(unsigned region = 0) { + return &this->getOperation()->getRegion(region).front(); + } - /// Sets up the mapping between the entry block of the only region of this op + /// Sets up the mapping between the entry block of the given region of this op /// and the relevant list of Payload IR operations in the given state. The /// state is expected to be already scoped at the region of this operation. /// Returns failure if the mapping failed, e.g., the value is already mapped. - LogicalResult mapBlockArguments(TransformState &state) { + LogicalResult mapBlockArguments(TransformState &state, unsigned region = 0) { return detail::mapPossibleTopLevelTransformOpBlockArguments( - state, this->getOperation()); + state, this->getOperation(), region); } }; @@ -461,8 +586,8 @@ /// Calls `applyToOne` for every payload operation associated with the operand /// of this transform IR op. If `applyToOne` returns ops, associates them with /// the result of this transform op. - LogicalResult apply(TransformResults &transformResults, - TransformState &state); + DiagnosedSilencableFailure apply(TransformResults &transformResults, + TransformState &state); /// Checks that the op matches the expectations of this trait. static LogicalResult verifyTrait(Operation *op); @@ -497,22 +622,22 @@ StringRef getName() override { return "transform.payload_ir"; } }; -/// Trait implementing the MemoryEffectOpInterface for single-operand operations -/// that "consume" their operand and produce a new result. +/// Trait implementing the MemoryEffectOpInterface for operations that "consume" +/// their operands and produce new results. template class FunctionalStyleTransformOpTrait : public OpTrait::TraitBase { public: - /// This op "consumes" the operand by reading and freeing it, "produces" the - /// results by allocating and writing it and reads/writes the payload IR in - /// the process. 
+ /// This op "consumes" the operands by reading and freeing them, "produces" + /// the results by allocating and writing them and reads/writes the payload + /// IR in the process. void getEffects(SmallVectorImpl &effects) { - effects.emplace_back(MemoryEffects::Read::get(), - this->getOperation()->getOperand(0), - TransformMappingResource::get()); - effects.emplace_back(MemoryEffects::Free::get(), - this->getOperation()->getOperand(0), - TransformMappingResource::get()); + for (Value operand : this->getOperation()->getOperands()) { + effects.emplace_back(MemoryEffects::Read::get(), operand, + TransformMappingResource::get()); + effects.emplace_back(MemoryEffects::Free::get(), operand, + TransformMappingResource::get()); + } for (Value result : this->getOperation()->getResults()) { effects.emplace_back(MemoryEffects::Allocate::get(), result, TransformMappingResource::get()); @@ -525,8 +650,6 @@ /// Checks that the op matches the expectations of this trait. static LogicalResult verifyTrait(Operation *op) { - static_assert(OpTy::template hasTrait(), - "expected single-operand op"); if (!op->getName().getInterface()) { op->emitError() << "FunctionalStyleTransformOpTrait should only be attached to ops " @@ -612,12 +735,12 @@ /// where OpTy is either /// - Operation *, in which case the transform is always applied; /// - a concrete Op class, in which case a check is performed whether -/// `targets` contains operations of the same class and a failure is reported -/// if it does not. +/// `targets` contains operations of the same class and a silencable failure +/// is reported if it does not.
template -LogicalResult applyTransformToEach(ArrayRef targets, - SmallVectorImpl &results, - FnTy transform) { +DiagnosedSilencableFailure +applyTransformToEach(ArrayRef targets, + SmallVectorImpl &results, FnTy transform) { using OpTy = typename llvm::function_traits::template arg_t<0>; static_assert(std::is_convertible::value, "expected transform function to take an operation"); @@ -627,37 +750,43 @@ "FailureOr"); for (Operation *target : targets) { auto specificOp = dyn_cast(target); - if (!specificOp) - return failure(); + if (!specificOp) { + Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error); + diag << "attempted to apply transform to the wrong op kind"; + return DiagnosedSilencableFailure::silencableFailure(std::move(diag)); + } auto result = transform(specificOp); if (failed(appendTransformResultToVector(result, results))) - return failure(); + return DiagnosedSilencableFailure::definiteFailure(); } - return success(); + return DiagnosedSilencableFailure::success(); } } // namespace detail } // namespace transform } // namespace mlir template -mlir::LogicalResult mlir::transform::TransformEachOpTrait::apply( +mlir::DiagnosedSilencableFailure +mlir::transform::TransformEachOpTrait::apply( TransformResults &transformResults, TransformState &state) { using TransformOpType = typename llvm::function_traits< decltype(&OpTy::applyToOne)>::template arg_t<0>; ArrayRef targets = state.getPayloadOps(this->getOperation()->getOperand(0)); SmallVector results; - if (failed(detail::applyTransformToEach( - targets, results, [&](TransformOpType specificOp) { - return static_cast(this)->applyToOne(specificOp); - }))) - return failure(); + DiagnosedSilencableFailure result = detail::applyTransformToEach( + targets, results, [&](TransformOpType specificOp) { + return static_cast(this)->applyToOne(specificOp); + }); + if (!result.succeeded()) + return result; + if (OpTy::template hasTrait()) { transformResults.set( this->getOperation()->getResult(0).template cast(), 
results); } - return success(); + return DiagnosedSilencableFailure::success(); } template diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -38,9 +38,11 @@ accepts as arguments the object that must be populated with results of the current transformation and a transformation state object that can be used for queries, e.g., to obtain the list of operations on which the - transformation represented by the current op is targeted. + transformation represented by the current op is targeted. Returns a + special status object indicating whether the transformation succeeded + or failed, and, if it failed, whether the failure is recoverable or not. }], - /*returnType=*/"::mlir::LogicalResult", + /*returnType=*/"::mlir::DiagnosedSilencableFailure", /*name=*/"apply", /*arguments=*/(ins "::mlir::transform::TransformResults &":$transformResults, @@ -59,6 +61,13 @@ diag.attachNote(target->getLoc()) << "attempted to apply to this op"; return diag; } + + /// Creates the silencable failure object with a diagnostic located at the + /// current operation. 
+ DiagnosedSilencableFailure emitSilencableError() { + Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error); + return DiagnosedSilencableFailure::silencableFailure(std::move(diag)); + } }]; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -17,6 +17,83 @@ include "mlir/Dialect/Transform/IR/TransformEffects.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +def AlternativesOp : TransformDialectOp<"alternatives", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, + IsolatedFromAbove, PossibleTopLevelTransformOpTrait, + SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { + let summary = "Attempts sequences of transforms until one succeeds"; + let description = [{ + This op may have an arbitrary number of regions, each of which represents a + sequence of transform operations to be applied to the same payload IR. The + regions are visited in order of appearance, and transforms in them are + applied in their respective order of appearance. If one of these transforms + fails to apply, the remaining ops in the same region are skipped and the next + region is attempted. If all transformations in a region succeed, the + remaining regions are skipped and the entire "alternatives" transformation + succeeds. If all regions contained a failing transformation, the entire + "alternatives" transformation fails. + + It is up to the nested operations to define which errors are "recoverable" + (or "silencable") and allow other alternatives to be attempted, and which + errors should be propagated without attempting the other alternatives.
+ + The single operand of this operation is the scope in which the alternative + transformation sequences are attempted, that is, an operation in the payload + IR that contains all the other operations that may be modified by the + transformations. There is no check that the transforms are indeed scoped + as their "apply" methods can be arbitrarily complex. Therefore it is the + responsibility of the user to ensure that the transforms are scoped + correctly, or to produce an irrecoverable error and thus abort the execution + without attempting the remaining alternatives. Note that the payload IR + outside of the given scope is not necessarily in a valid state, or even + accessible to the transformation. + + The changes to the IR within the scope performed by transforms in the failed + alternative region are reverted before attempting the next region. + Practically, this is achieved by cloning the scope. Therefore it is advised + to limit the scope as much as possible and place the most likely + alternatives early in the region list. The operation is also isolated from + above and requires rediscovering the operations within the given scope to + avoid additional handle invalidation. The latter restriction may be lifted + in the future. + + Each of the regions may yield transform IR handles. The handles of the first + successful alternative region are returned as the results of the + "alternatives" op. Therefore, each alternative region must yield the same + number of results, which should also match the number and the types of the + "alternatives" op results. + + Remark: this op allows one to implement a simple "try" construct as follows: + + ```mlir + %result = transform.alternatives %scope -> !pdl.operation { + ^bb0(%arg0: !pdl.operation): + // Try a fallible transformation. + %0 = transform.fallible %arg0 // ... + // If it succeeded, yield the result of the transformation.
+ transform.yield %0 : !pdl.operation + }, { + ^bb0(%arg0: !pdl.operation): + // Otherwise, the second alternative is tried and it always succeeds by + // returning the original handle. + transform.yield %arg0 : !pdl.operation + } + ``` + }]; + + let arguments = (ins Optional:$scope); + let results = (outs Variadic:$results); + let regions = (region VariadicRegion>:$alternatives); + + let assemblyFormat = + "($scope^)? (`->` type($results)^)? attr-dict-with-keyword regions"; + let hasVerifier = 1; +} + def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent", [DeclareOpInterfaceMethods, NavigationTransformOpTrait, MemoryEffectsOpInterface]> { diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -24,7 +24,7 @@ // OneShotBufferizeOp //===----------------------------------------------------------------------===// -LogicalResult +DiagnosedSilencableFailure transform::OneShotBufferizeOp::apply(TransformResults &transformResults, TransformState &state) { OneShotBufferizationOptions options; @@ -39,19 +39,19 @@ for (Operation *target : payloadOps) { auto moduleOp = dyn_cast(target); if (getTargetIsModule() && !moduleOp) - return emitError("expected ModuleOp target"); + return emitSilencableError() << "expected ModuleOp target"; if (options.bufferizeFunctionBoundaries) { if (!moduleOp) - return emitError("expected ModuleOp target"); + return emitSilencableError() << "expected ModuleOp target"; if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options))) - return emitError("bufferization failed"); + return emitSilencableError() << "bufferization failed"; } else { if (failed(bufferization::runOneShotBufferize(target, options))) - return emitError("bufferization
failed"); + return emitSilencableError() << "bufferization failed"; } } - return success(); + return DiagnosedSilencableFailure::success(); } void transform::OneShotBufferizeOp::getEffects( diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -13,10 +13,8 @@ #include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Parser/Parser.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/Support/FormatVariadic.h" using namespace mlir; using namespace mlir::linalg; @@ -166,14 +164,14 @@ return success(); } -LogicalResult +DiagnosedSilencableFailure transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { LinalgTilingAndFusionOptions fusionOptions; fusionOptions.tileSizes = extractI64Array(getTileSizes()); fusionOptions.tileInterchange = extractI64Array(getTileInterchange()); - return applyTilingToAll( + LogicalResult result = applyTilingToAll( getOperation(), getTarget(), fusionOptions.tileSizes, transformResults, state, [&](LinalgOp linalgOp) -> FailureOr { LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions); @@ -190,6 +188,8 @@ tileLoopNest->getLoopOps().end()}; return tiledLinalgOp; }); + return failed(result) ? 
DiagnosedSilencableFailure::definiteFailure() + : DiagnosedSilencableFailure::success(); } ParseResult transform::FuseOp::parse(OpAsmParser &parser, @@ -398,8 +398,9 @@ // TileOp //===----------------------------------------------------------------------===// -LogicalResult transform::TileOp::apply(TransformResults &transformResults, - TransformState &state) { +DiagnosedSilencableFailure +transform::TileOp::apply(TransformResults &transformResults, + TransformState &state) { LinalgTilingOptions tilingOptions; SmallVector tileSizes = extractI64Array(getSizes()); @@ -408,12 +409,13 @@ tilingOptions.setInterchange(extractUIntArray(getInterchange())); LinalgTilingPattern pattern(getContext(), tilingOptions); - return applyTilingToAll(getOperation(), getTarget(), tileSizes, - transformResults, state, [&](LinalgOp linalgOp) { - SimpleRewriter rewriter(linalgOp.getContext()); - return pattern.returningMatchAndRewrite(linalgOp, - rewriter); - }); + LogicalResult result = applyTilingToAll( + getOperation(), getTarget(), tileSizes, transformResults, state, + [&](LinalgOp linalgOp) { + SimpleRewriter rewriter(linalgOp.getContext()); + return pattern.returningMatchAndRewrite(linalgOp, rewriter); + }); + return DiagnosedSilencableFailure(result); } ParseResult transform::TileOp::parse(OpAsmParser &parser, diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/SCF/Utils/Utils.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" using namespace mlir; @@ -30,7 +31,7 @@ // GetParentForOp //===----------------------------------------------------------------------===// -LogicalResult 
+DiagnosedSilencableFailure transform::GetParentForOp::apply(transform::TransformResults &results, transform::TransformState &state) { SetVector parents; @@ -40,9 +41,10 @@ for (unsigned i = 0, e = getNumLoops(); i < e; ++i) { loop = current->getParentOfType(); if (!loop) { - InFlightDiagnostic diag = emitError() << "could not find an '" - << scf::ForOp::getOperationName() - << "' parent"; + DiagnosedSilencableFailure diag = emitSilencableError() + << "could not find an '" + << scf::ForOp::getOperationName() + << "' parent"; diag.attachNote(target->getLoc()) << "target op"; return diag; } @@ -51,7 +53,7 @@ parents.insert(loop); } results.set(getResult().cast(), parents.getArrayRef()); - return success(); + return DiagnosedSilencableFailure::success(); } //===----------------------------------------------------------------------===// @@ -83,7 +85,7 @@ return executeRegionOp; } -LogicalResult +DiagnosedSilencableFailure transform::LoopOutlineOp::apply(transform::TransformResults &results, transform::TransformState &state) { SmallVector transformed; @@ -94,7 +96,8 @@ SimpleRewriter rewriter(getContext()); scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); if (!exec) { - InFlightDiagnostic diag = emitError() << "failed to outline"; + DiagnosedSilencableFailure diag = emitSilencableError() + << "failed to outline"; diag.attachNote(target->getLoc()) << "target op"; return diag; } @@ -102,8 +105,10 @@ FailureOr outlined = outlineSingleBlockRegion( rewriter, location, exec.getRegion(), getFuncName(), &call); - if (failed(outlined)) - return reportUnknownTransformError(target); + if (failed(outlined)) { + (void)reportUnknownTransformError(target); + return DiagnosedSilencableFailure::definiteFailure(); + } if (symbolTableOp) { SymbolTable &symbolTable = @@ -115,7 +120,7 @@ transformed.push_back(*outlined); } results.set(getTransformed().cast(), transformed); - return success(); + return DiagnosedSilencableFailure::success(); } 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -10,8 +10,10 @@ #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" -#include "llvm/ADT/ScopeExit.h" -#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "transform-dialect" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") using namespace mlir; @@ -186,16 +188,18 @@ return success(); } -LogicalResult +DiagnosedSilencableFailure transform::TransformState::applyTransform(TransformOpInterface transform) { + LLVM_DEBUG(DBGS() << "applying: " << transform << "\n"); if (options.getExpensiveChecksEnabled() && failed(checkAndRecordHandleInvalidation(transform))) { - return failure(); + return DiagnosedSilencableFailure::definiteFailure(); } transform::TransformResults results(transform->getNumResults()); - if (failed(transform.apply(results, *this))) - return failure(); + DiagnosedSilencableFailure result(transform.apply(results, *this)); + if (!result.succeeded()) + return result; // Remove the mapping for the operand if it is consumed by the operation. This // allows us to catch use-after-free with assertions later on. 
@@ -219,10 +223,10 @@ "payload IR association for a value other than the result of the " "current transform op"); if (failed(setPayloadOps(result, results.get(result.getResultNumber())))) - return failure(); + return DiagnosedSilencableFailure::definiteFailure(); } - return success(); + return DiagnosedSilencableFailure::success(); } //===----------------------------------------------------------------------===// @@ -273,14 +277,14 @@ //===----------------------------------------------------------------------===// LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( - TransformState &state, Operation *op) { + TransformState &state, Operation *op, unsigned region) { SmallVector targets; if (op->getNumOperands() != 0) llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); else targets.push_back(state.getTopLevel()); - return state.mapBlockArguments(op->getRegion(0).front().getArgument(0), + return state.mapBlockArguments(op->getRegion(region).front().getArgument(0), targets); } @@ -293,8 +297,8 @@ "should implement TransformOpInterface to have " "PossibleTopLevelTransformOpTrait"); - if (op->getNumRegions() != 1) - return op->emitOpError() << "expects one region"; + if (op->getNumRegions() < 1) + return op->emitOpError() << "expects at least one region"; Region *bodyRegion = &op->getRegion(0); if (!llvm::hasNItems(*bodyRegion, 1)) diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -10,13 +10,16 @@ #include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include 
"mlir/Rewrite/PatternApplicator.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "transform-dialect" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") using namespace mlir; @@ -115,34 +118,174 @@ } } // namespace +//===----------------------------------------------------------------------===// +// AlternativesOp +//===----------------------------------------------------------------------===// + +OperandRange +transform::AlternativesOp::getSuccessorEntryOperands(unsigned index) { + if (getOperation()->getNumOperands() == 1) + return getOperation()->getOperands(); + return OperandRange(getOperation()->operand_end(), + getOperation()->operand_end()); +} + +void transform::AlternativesOp::getSuccessorRegions( + Optional index, ArrayRef operands, + SmallVectorImpl ®ions) { + for (Region &alternative : + llvm::drop_begin(getAlternatives(), index.hasValue() ? *index + 1 : 0)) { + regions.emplace_back(&alternative, !getOperands().empty() + ? alternative.getArguments() + : Block::BlockArgListType()); + } + if (index.hasValue()) + regions.emplace_back(getOperation()->getResults()); +} + +void transform::AlternativesOp::getRegionInvocationBounds( + ArrayRef operands, SmallVectorImpl &bounds) { + (void)operands; + // The region corresponding to the first alternative is always executed, the + // remaining may or may not be executed. 
+ bounds.reserve(getNumRegions()); + bounds.emplace_back(1, 1); + bounds.resize(getNumRegions(), InvocationBounds(0, 1)); +} + +static void forwardTerminatorOperands(Block *block, + transform::TransformState &state, + transform::TransformResults &results) { + for (const auto &pair : llvm::zip(block->getTerminator()->getOperands(), + block->getParentOp()->getOpResults())) { + Value terminatorOperand = std::get<0>(pair); + OpResult result = std::get<1>(pair); + results.set(result, state.getPayloadOps(terminatorOperand)); + } +} + +DiagnosedSilencableFailure +transform::AlternativesOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + SmallVector originals; + if (Value scopeHandle = getScope()) + llvm::append_range(originals, state.getPayloadOps(scopeHandle)); + else + originals.push_back(state.getTopLevel()); + + for (Operation *original : originals) { + if (original->isAncestor(getOperation())) { + InFlightDiagnostic diag = + emitError() << "scope must not contain the transforms being applied"; + diag.attachNote(original->getLoc()) << "scope"; + return DiagnosedSilencableFailure::definiteFailure(); + } + } + + for (Region ® : getAlternatives()) { + // Clone the scope operations and make the transforms in this alternative + // region apply to them by virtue of mapping the block argument (the only + // visible handle) to the cloned scope operations. This effectively prevents + // the transformation from accessing any IR outside the scope. 
+ auto scope = state.make_region_scope(reg); + auto clones = llvm::to_vector( + llvm::map_range(originals, [](Operation *op) { return op->clone(); })); + if (failed(state.mapBlockArguments(reg.front().getArgument(0), clones))) + return DiagnosedSilencableFailure::definiteFailure(); + auto deleteClones = llvm::make_scope_exit([&] { + for (Operation *clone : clones) + clone->erase(); + }); + + bool failed = false; + for (Operation &transform : reg.front().without_terminator()) { + DiagnosedSilencableFailure result = + state.applyTransform(cast(transform)); + if (result.isSilencableFailure()) { + LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() + << "\n"); + failed = true; + break; + } + + if (::mlir::failed(result.silence())) + return DiagnosedSilencableFailure::definiteFailure(); + } + + // If all operations in the given alternative succeeded, no need to consider + // the rest. Replace the original scoping operation with the clone on which + // the transformations were performed. + if (!failed) { + // We will be using the clones, so cancel their scheduled deletion. 
+ deleteClones.release(); + IRRewriter rewriter(getContext()); + for (const auto &kvp : llvm::zip(originals, clones)) { + Operation *original = std::get<0>(kvp); + Operation *clone = std::get<1>(kvp); + original->getBlock()->getOperations().insert(original->getIterator(), + clone); + rewriter.replaceOp(original, clone->getResults()); + } + forwardTerminatorOperands(®.front(), state, results); + return DiagnosedSilencableFailure::success(); + } + } + return emitSilencableError() << "all alternatives failed"; +} + +LogicalResult transform::AlternativesOp::verify() { + for (Region &alternative : getAlternatives()) { + Block &block = alternative.front(); + if (block.getNumArguments() != 1 || + !block.getArgument(0).getType().isa()) { + return emitOpError() + << "expects region blocks to have one operand of type " + << pdl::OperationType::get(getContext()); + } + + Operation *terminator = block.getTerminator(); + if (terminator->getOperands().getTypes() != getResults().getTypes()) { + InFlightDiagnostic diag = emitOpError() + << "expects terminator operands to have the " + "same type as results of the operation"; + diag.attachNote(terminator->getLoc()) << "terminator"; + return diag; + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // GetClosestIsolatedParentOp //===----------------------------------------------------------------------===// -LogicalResult transform::GetClosestIsolatedParentOp::apply( +DiagnosedSilencableFailure transform::GetClosestIsolatedParentOp::apply( transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { Operation *parent = target->getParentWithTrait(); if (!parent) { - InFlightDiagnostic diag = - emitError() << "could not find an isolated-from-above parent op"; + DiagnosedSilencableFailure diag = + emitSilencableError() + << "could not find an isolated-from-above parent op"; 
diag.attachNote(target->getLoc()) << "target op"; return diag; } parents.insert(parent); } results.set(getResult().cast(), parents.getArrayRef()); - return success(); + return DiagnosedSilencableFailure::success(); } //===----------------------------------------------------------------------===// // PDLMatchOp //===----------------------------------------------------------------------===// -LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilencableFailure +transform::PDLMatchOp::apply(transform::TransformResults &results, + transform::TransformState &state) { auto *extension = state.getExtension(); assert(extension && "expected PatternApplicatorExtension to be attached by the parent op"); @@ -150,41 +293,38 @@ for (Operation *root : state.getPayloadOps(getRoot())) { if (failed(extension->findAllMatches( getPatternName().getLeafReference().getValue(), root, targets))) { - return emitOpError() << "could not find pattern '" << getPatternName() - << "'"; + emitOpError() << "could not find pattern '" << getPatternName() << "'"; + return DiagnosedSilencableFailure::definiteFailure(); } } results.set(getResult().cast(), targets); - return success(); + return DiagnosedSilencableFailure::success(); } //===----------------------------------------------------------------------===// // SequenceOp //===----------------------------------------------------------------------===// -LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilencableFailure +transform::SequenceOp::apply(transform::TransformResults &results, + transform::TransformState &state) { // Map the entry block argument to the list of operations. 
auto scope = state.make_region_scope(*getBodyBlock()->getParent()); if (failed(mapBlockArguments(state))) - return failure(); + return DiagnosedSilencableFailure::definiteFailure(); // Apply the sequenced ops one by one. - for (Operation &transform : getBodyBlock()->without_terminator()) - if (failed(state.applyTransform(cast(transform)))) - return failure(); + for (Operation &transform : getBodyBlock()->without_terminator()) { + DiagnosedSilencableFailure result = + state.applyTransform(cast(transform)); + if (!result.succeeded()) + return result; + } // Forward the operation mapping for values yielded from the sequence to the // values produced by the sequence op. - for (const auto &pair : - llvm::zip(getBodyBlock()->getTerminator()->getOperands(), - getOperation()->getOpResults())) { - Value terminatorOperand = std::get<0>(pair); - OpResult result = std::get<1>(pair); - results.set(result, state.getPayloadOps(terminatorOperand)); - } - - return success(); + forwardTerminatorOperands(getBodyBlock(), state, results); + return DiagnosedSilencableFailure::success(); } /// Returns `true` if the given op operand may be consuming the handle value in @@ -346,7 +486,7 @@ // WithPDLPatternsOp //===----------------------------------------------------------------------===// -LogicalResult +DiagnosedSilencableFailure transform::WithPDLPatternsOp::apply(transform::TransformResults &results, transform::TransformState &state) { OwningOpRef pdlModuleOp = @@ -365,7 +505,7 @@ auto scope = state.make_region_scope(getBody()); if (failed(mapBlockArguments(state))) - return failure(); + return DiagnosedSilencableFailure::definiteFailure(); return state.applyTransform(transformOp); } diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -81,7 +81,7 @@ // ----- -// expected-error @below {{expects one region}} +// expected-error 
@below {{expects at least one region}} "transform.test_transform_unrestricted_op_no_interface"() : () -> () // ----- @@ -153,3 +153,34 @@ } } } + +// ----- + +transform.sequence { +^bb1(%arg1: !pdl.operation): + // expected-error @below {{expects at least one region}} + transform.alternatives +} + +// ----- + +transform.sequence { +^bb1(%arg1: !pdl.operation): + // expected-error @below {{expects terminator operands to have the same type as results of the operation}} + %2 = transform.alternatives %arg1 -> !pdl.operation { + ^bb2(%arg2: !pdl.operation): + transform.yield %arg2 : !pdl.operation + }, { + ^bb2(%arg2: !pdl.operation): + // expected-note @below {{terminator}} + transform.yield + } +} + +// ----- + +// expected-error @below {{expects the entry block to have one argument of type '!pdl.operation'}} +transform.alternatives { +^bb0: + transform.yield +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -128,3 +128,223 @@ test_print_remark_at_operand %m, "parent function" } } + +// ----- + +func.func @foo() { + %0 = arith.constant 0 : i32 + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_func : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = pdl.operation "func.func"(%0 : !pdl.range) -> (%1 : !pdl.range) + pdl.rewrite %2 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + // This is necessary to run the transformation on something other than the + // top-level module, "alternatives" cannot be run on that. + %0 = pdl_match @match_func in %arg1 + transform.alternatives %0 { + ^bb2(%arg2: !pdl.operation): + %1 = transform.test_produce_param_or_forward_operand 42 + // This operation fails, which triggers the next alternative without + // reporting the error. 
+ transform.test_consume_operand_if_matches_param_or_fail %1[43] + }, { + ^bb2(%arg2: !pdl.operation): + %1 = transform.test_produce_param_or_forward_operand 42 + // expected-remark @below {{succeeded}} + transform.test_consume_operand_if_matches_param_or_fail %1[42] + } + } +} + +// ----- + +func.func private @bar() + +func.func @foo() { + call @bar() : () -> () + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_call : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = pdl.operation "func.call"(%0 : !pdl.range) -> (%1 : !pdl.range) + pdl.rewrite %2 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_call in %arg1 + %1 = get_closest_isolated_parent %0 + // expected-error @below {{all alternatives failed}} + transform.alternatives %1 { + ^bb2(%arg2: !pdl.operation): + %2 = transform.pdl_match @match_call in %arg2 + // expected-remark @below {{applying}} + transform.test_emit_remark_and_erase_operand %2, "applying" {fail_after_erase} + } + } +} + +// ----- + +func.func private @bar() + +func.func @foo() { + // expected-remark @below {{still here}} + call @bar() : () -> () + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_call : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = pdl.operation "func.call"(%0 : !pdl.range) -> (%1 : !pdl.range) + pdl.rewrite %2 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_call in %arg1 + %1 = get_closest_isolated_parent %0 + transform.alternatives %1 { + ^bb2(%arg2: !pdl.operation): + %2 = transform.pdl_match @match_call in %arg2 + // expected-remark @below {{applying}} + transform.test_emit_remark_and_erase_operand %2, "applying" {fail_after_erase} + }, { + ^bb2(%arg2: !pdl.operation): + %2 = transform.pdl_match @match_call in %arg2 + transform.test_print_remark_at_operand %2, "still here" + // This 
alternative succeeds. + }, { + ^bb2(%arg2: !pdl.operation): + // This alternative is never run, so we must not have a remark here. + %2 = transform.pdl_match @match_call in %arg2 + transform.test_emit_remark_and_erase_operand %2, "should not happen" {fail_after_erase} + } + } +} + +// ----- + +func.func private @bar() + +// CHECK-LABEL: @erase_call +func.func @erase_call() { + // CHECK-NOT: call @bar + call @bar() : () -> () + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_call : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = pdl.operation "func.call"(%0 : !pdl.range) -> (%1 : !pdl.range) + pdl.rewrite %2 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_call in %arg1 + %1 = get_closest_isolated_parent %0 + transform.alternatives %1 { + ^bb2(%arg2: !pdl.operation): + %2 = transform.pdl_match @match_call in %arg2 + // expected-remark @below {{applying}} + transform.test_emit_remark_and_erase_operand %2, "applying" {fail_after_erase} + }, { + ^bb2(%arg2: !pdl.operation): + %2 = transform.pdl_match @match_call in %arg2 + // expected-remark @below {{applying second time}} + transform.test_emit_remark_and_erase_operand %2, "applying second time" + } + } +} + +// ----- + +func.func private @bar() + +func.func @foo() { + call @bar() : () -> () + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @match_call : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = pdl.operation "func.call"(%0 : !pdl.range) -> (%1 : !pdl.range) + pdl.rewrite %2 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + %0 = pdl_match @match_call in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.alternatives %1 -> !pdl.operation { + ^bb2(%arg2: !pdl.operation): + %3 = transform.pdl_match @match_call in %arg2 + // expected-remark @below {{applying}} + 
transform.test_emit_remark_and_erase_operand %3, "applying" {fail_after_erase} + %4 = transform.test_produce_param_or_forward_operand 43 + transform.yield %4 : !pdl.operation + }, { + ^bb2(%arg2: !pdl.operation): + %4 = transform.test_produce_param_or_forward_operand 42 + transform.yield %4 : !pdl.operation + } + // The first alternative failed, so the returned value is taken from the + // second alternative. + // expected-remark @below {{succeeded}} + transform.test_consume_operand_if_matches_param_or_fail %2[42] + } +} + +// ----- + +// expected-note @below {{scope}} +module { + func.func @foo() { + %0 = arith.constant 0 : i32 + return + } + + func.func @bar() { + %0 = arith.constant 0 : i32 + %1 = arith.constant 1 : i32 + return + } + + transform.sequence { + ^bb1(%arg1: !pdl.operation): + // expected-error @below {{scope must not contain the transforms being applied}} + transform.alternatives %arg1 { + ^bb2(%arg2: !pdl.operation): + %0 = transform.test_produce_param_or_forward_operand 42 + transform.test_consume_operand_if_matches_param_or_fail %0[43] + }, { + ^bb2(%arg2: !pdl.operation): + %0 = transform.test_produce_param_or_forward_operand 42 + transform.test_consume_operand_if_matches_param_or_fail %0[42] + } + } +} + diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -38,13 +38,13 @@ return llvm::StringLiteral("transform.test_transform_op"); } - LogicalResult apply(transform::TransformResults &results, - transform::TransformState &state) { + DiagnosedSilencableFailure apply(transform::TransformResults &results, + transform::TransformState &state) { InFlightDiagnostic remark = emitRemark() << "applying transformation"; if (Attribute message = getMessage()) remark << " " << message; - return success(); + return 
DiagnosedSilencableFailure::success(); } Attribute getMessage() { return getOperation()->getAttr("message"); } @@ -91,9 +91,9 @@ "transform.test_transform_unrestricted_op_no_interface"); } - LogicalResult apply(transform::TransformResults &results, - transform::TransformState &state) { - return success(); + DiagnosedSilencableFailure apply(transform::TransformResults &results, + transform::TransformState &state) { + return DiagnosedSilencableFailure::success(); } // No side effects. @@ -101,7 +101,8 @@ }; } // namespace -LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply( +DiagnosedSilencableFailure +mlir::test::TestProduceParamOrForwardOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { results.set(getResult().cast(), @@ -110,7 +111,7 @@ results.set(getResult().cast(), reinterpret_cast(*getParameter())); } - return success(); + return DiagnosedSilencableFailure::success(); } LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() { @@ -119,48 +120,50 @@ return success(); } -LogicalResult +DiagnosedSilencableFailure mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, transform::TransformState &state) { - return success(); + return DiagnosedSilencableFailure::success(); } -LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( +DiagnosedSilencableFailure +mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( transform::TransformResults &results, transform::TransformState &state) { ArrayRef payload = state.getPayloadOps(getOperand()); assert(payload.size() == 1 && "expected a single target op"); auto value = reinterpret_cast(payload[0]); if (static_cast(value) != getParameter()) { - return emitOpError() << "expected the operand to be associated with " - << getParameter() << " got " << value; + return emitSilencableError() + << "op expected the operand to be associated with " << getParameter() + << " got " 
<< value; } emitRemark() << "succeeded"; - return success(); + return DiagnosedSilencableFailure::success(); } -LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply( +DiagnosedSilencableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( transform::TransformResults &results, transform::TransformState &state) { ArrayRef payload = state.getPayloadOps(getOperand()); for (Operation *op : payload) op->emitRemark() << getMessage(); - return success(); + return DiagnosedSilencableFailure::success(); } -LogicalResult +DiagnosedSilencableFailure mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, transform::TransformState &state) { state.addExtension(getMessageAttr()); - return success(); + return DiagnosedSilencableFailure::success(); } -LogicalResult mlir::test::TestCheckIfTestExtensionPresentOp::apply( +DiagnosedSilencableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply( transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); if (!extension) { emitRemark() << "extension absent"; - return success(); + return DiagnosedSilencableFailure::success(); } InFlightDiagnostic diag = emitRemark() @@ -172,40 +175,56 @@ "operations"); } - return success(); + return DiagnosedSilencableFailure::success(); } -LogicalResult mlir::test::TestRemapOperandPayloadToSelfOp::apply( +DiagnosedSilencableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); - if (!extension) - return emitError() << "TestTransformStateExtension missing"; + if (!extension) { + emitError() << "TestTransformStateExtension missing"; + return DiagnosedSilencableFailure::definiteFailure(); + } - return extension->updateMapping(state.getPayloadOps(getOperand()).front(), - getOperation()); + if (failed(extension->updateMapping(state.getPayloadOps(getOperand()).front(), + getOperation()))) + return 
DiagnosedSilencableFailure::definiteFailure(); + return DiagnosedSilencableFailure::success(); } -LogicalResult mlir::test::TestRemoveTestExtensionOp::apply( +DiagnosedSilencableFailure mlir::test::TestRemoveTestExtensionOp::apply( transform::TransformResults &results, transform::TransformState &state) { state.removeExtension(); - return success(); + return DiagnosedSilencableFailure::success(); } -LogicalResult mlir::test::TestTransformOpWithRegions::apply( +DiagnosedSilencableFailure mlir::test::TestTransformOpWithRegions::apply( transform::TransformResults &results, transform::TransformState &state) { - return success(); + return DiagnosedSilencableFailure::success(); } void mlir::test::TestTransformOpWithRegions::getEffects( SmallVectorImpl &effects) {} -LogicalResult mlir::test::TestBranchingTransformOpTerminator::apply( +DiagnosedSilencableFailure +mlir::test::TestBranchingTransformOpTerminator::apply( transform::TransformResults &results, transform::TransformState &state) { - return success(); + return DiagnosedSilencableFailure::success(); } void mlir::test::TestBranchingTransformOpTerminator::getEffects( SmallVectorImpl &effects) {} +DiagnosedSilencableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + emitRemark() << getRemark(); + for (Operation *op : state.getPayloadOps(getTarget())) + op->erase(); + + if (getFailAfterErase()) + return emitSilencableError() << "silencable error"; + return DiagnosedSilencableFailure::success(); +} + namespace { /// Test extension of the Transform dialect. 
Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -118,4 +118,14 @@ let cppNamespace = "::mlir::test"; } +def TestEmitRemarkAndEraseOperandOp + : Op, + MemoryEffectsOpInterface, FunctionalStyleTransformOpTrait]> { + let arguments = (ins PDL_Operation:$target, StrAttr:$remark, + UnitAttr:$fail_after_erase); + let assemblyFormat = "$target `,` $remark attr-dict"; + let cppNamespace = "::mlir::test"; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -47,7 +47,7 @@ enableExpensiveChecks)); for (auto op : module.getBody()->getOps()) { - if (failed(state.applyTransform(op))) + if (failed(state.applyTransform(op).checkAndReport())) return signalPassFailure(); } }