diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -18,7 +18,7 @@ /// The result of a transform IR operation application. This can have one of the /// three states: /// - success; -/// - silencable (recoverable) failure with yet-unreported diagnostic; +/// - silenceable (recoverable) failure with yet-unreported diagnostic; /// - definite failure. /// Silenceable failure is intended to communicate information about /// transformations that did not apply but in a way that supports recovery, @@ -26,10 +26,10 @@ /// predictable way. They are associated with a Diagnostic that provides more /// details on the failure. Silenceable failure can be discarded, turning the /// result into success, or "reported", emitting the diagnostic and turning the -/// result into definite failure. Transform IR operations containing other -/// operations are allowed to do either with the results of the nested -/// transformations, but must propagate definite failures as their diagnostics -/// have been already reported to the user. +/// result into definite failure. +/// Transform IR operations containing other operations are allowed to do either +/// with the results of the nested transformations, but must propagate definite +/// failures as their diagnostics have been already reported to the user. class LLVM_NODISCARD DiagnosedSilenceableFailure { public: explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {} @@ -51,10 +51,10 @@ return DiagnosedSilenceableFailure(::mlir::failure()); } - /// Constructs a DiagnosedSilenceableFailure in the silencable failure state, + /// Constructs a DiagnosedSilenceableFailure in the silenceable failure state, /// ready to emit the given diagnostic.
This is considered a failure /// regardless of the diagnostic severity. - static DiagnosedSilenceableFailure silencableFailure(Diagnostic &&diag) { + static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag) { return DiagnosedSilenceableFailure(std::forward(diag)); } @@ -74,7 +74,7 @@ return result; } - /// Returns `true` if this is a silencable failure. + /// Returns `true` if this is a silenceable failure. bool isSilenceableFailure() const { return diagnostic.hasValue(); } /// Returns `true` if this is a success. @@ -83,10 +83,10 @@ } /// Returns the diagnostic message without emitting it. Expects this object - /// to be a silencable failure. + /// to be a silenceable failure. std::string getMessage() const { return diagnostic->str(); } - /// Converts silencable failure into LogicalResult success without reporting + /// Converts silenceable failure into LogicalResult success without reporting /// the diagnostic, preserves the other states. LogicalResult silence() { if (diagnostic) { @@ -96,12 +96,21 @@ return result; } + /// Moves the diagnostic out of this silenceable failure and silences it (see + /// `silence`). Returns by value since this object no longer holds it. + Diagnostic takeDiagnostic() { + assert(isSilenceableFailure() && "expected a diagnostic to be present"); + Diagnostic diag = std::move(*diagnostic); + (void)silence(); + return diag; + } + - /// Streams the given values into the diagnotic. Expects this object to be a - /// silencable failure. + /// Streams the given values into the diagnostic. Expects this object to be a + /// silenceable failure. template DiagnosedSilenceableFailure &operator<<(T &&value) & { assert(isSilenceableFailure() && - "can only append output in silencable failure state"); + "can only append output in silenceable failure state"); *diagnostic << std::forward(value); return *this; } @@ -110,11 +119,11 @@ return std::move(this->operator<<(std::forward(value))); } - /// Attaches a note to the diagnostic. Expects this object to be a silencable + /// Attaches a note to the diagnostic. Expects this object to be a silenceable /// failure.
Diagnostic &attachNote(Optional loc = llvm::None) { assert(isSilenceableFailure() && - "can only attach notes to silencable failures"); + "can only attach notes to silenceable failures"); return diagnostic->attachNote(loc); } @@ -123,7 +132,7 @@ : diagnostic(std::move(diagnostic)), result(failure()) {} /// The diagnostic associated with this object. If present, the object is - /// considered to be in the silencable failure state regardless of the + /// considered to be in the silenceable failure state regardless of the /// `result` field. Optional diagnostic; @@ -581,22 +590,29 @@ /// Trait implementing the TransformOpInterface for operations applying a /// transformation to a single operation handle and producing one or multiple /// operation handles. -/// The op must implement a method with one of the following signatures: -/// - FailureOr applyToOne(OpTy, state) -/// - FailureOr>applyToOne(OpTy, state) -/// - LogicalResult applyToOne(OpTy, state) +/// The op must implement a method with the following signature: +/// - DiagnosedSilenceableFailure applyToOne(OpTy, +/// SmallVector<Operation *> &newOps, state) /// to perform a transformation that is applied in turn to all payload IR /// operations that correspond to the handle of the transform IR operation. /// In the functions above, OpTy is either Operation * or a concrete payload IR /// Op class that the transformation is applied to (NOT the class of the -/// transform IR op). The op is expected to have a single operand. +/// transform IR op). +/// `applyToOne` may fill `newOps` with nullptr elements, or leave it empty +/// while returning a silenceable failure, to signify that the transformation +/// did not apply to the given payload IR operation. +/// The op is expected to have a single operand.
template class TransformEachOpTrait : public OpTrait::TraitBase { public: /// Calls `applyToOne` for every payload operation associated with the operand - /// of this transform IR op. If `applyToOne` returns ops, associates them with - /// the result of this transform op. + /// of this transform IR op and associates the ops it produces with the + /// results of this transform op. If any `applyToOne` returns a definite + /// failure, the transformation fails definitely. Otherwise, if some + /// `applyToOne` return a silenceable failure, the results are still set and + /// the silenceable failure is propagated. If all `applyToOne` succeed, the + /// transformation succeeds. DiagnosedSilenceableFailure apply(TransformResults &transformResults, TransformState &state); @@ -714,65 +728,26 @@ namespace mlir { namespace transform { namespace detail { -/// Appends `result` to the vector assuming it corresponds to the success state -/// in `FailureOr`. If `result` is just a -/// `LogicalResult`, appends an empy vector. -template -std::enable_if_t::value, LogicalResult> -appendTransformResultToVector( - Ty result, SmallVectorImpl> &results) { - results.push_back(SmallVector()); - return result; -} - -template -std::enable_if_t< - llvm::conjunction< - llvm::negation>, - std::is_convertible>::value, - LogicalResult> -appendTransformResultToVector( - Ty result, SmallVectorImpl> &results) { - if (failed(result)) - return failure(); - results.push_back(SmallVector{*result}); - return success(); -} - -template -std::enable_if_t< - llvm::conjunction< - llvm::negation>, - llvm::negation>>::value, - LogicalResult> -appendTransformResultToVector( - ContainerTy resultContainer, - SmallVectorImpl> &results) { - if (failed(resultContainer)) - return failure(); - results.push_back(*resultContainer); - return success(); -} /// Applies a one-to-one or a one-to-many transform to each of the given -/// targets. Puts the results of transforms, if any, in `results` in the same -/// order. Fails if any of the application fails.
Individual transforms must be -/// callable with one of the following signatures: -/// - FailureOr(OpTy) -/// - LogicalResult(OpTy) -/// - FailureOr>( -/// SmallVectorImpl) -/// - LogicalResult(SmallVectorImpl) +/// targets. Appends to `newOps` one entry per target, in the same order, +/// holding the ops the transform produced for that target (possibly none). +/// Fails definitely as soon as one application fails definitely. Otherwise, +/// if any application fails silenceably, returns a silenceable failure +/// carrying the first diagnostic, with the other ones attached as notes, once +/// all targets have been processed. Individual transforms must be callable +/// with the following signature: +/// - DiagnosedSilenceableFailure(OpTy, SmallVector<Operation *> &newOps) /// where OpTy is either /// - Operation *, in which case the transform is always applied; /// - a concrete Op class, in which case a check is performed whether -/// `targets` contains operations of the same class and a silencable failure -/// is reported if it does not. +/// `targets` contains operations of the same class and a silenceable failure +/// is recorded, and the transform skipped, for each target that does not. template DiagnosedSilenceableFailure applyTransformToEach(ArrayRef targets, - SmallVectorImpl> &results, + SmallVectorImpl> &newOps, FnTy transform) { + SmallVector silenceableStack; using OpTy = typename llvm::function_traits::template arg_t<0>; static_assert(std::is_convertible::value, "expected transform function to take an operation"); @@ -781,21 +756,39 @@ "expected transform function to return LogicalResult or " "FailureOr"); for (Operation *target : targets) { + // Record one entry per target, in order, even if the transform does not + // apply to it. + newOps.push_back(SmallVector()); + auto specificOp = dyn_cast(target); if (!specificOp) { Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error); diag << "attempted to apply transform to the wrong op kind"; - return DiagnosedSilenceableFailure::silencableFailure(std::move(diag)); + // Record a silenceable failure, leave the entry of this target empty and + // keep processing the remaining targets.
+ silenceableStack.push_back(std::move(diag)); + continue; } - auto result = transform(specificOp); - if (failed(appendTransformResultToVector(result, results))) - return DiagnosedSilenceableFailure::definiteFailure(); + DiagnosedSilenceableFailure result = transform(specificOp, newOps.back()); + // A definite failure has already been reported: propagate it immediately. + if (!result.succeeded() && !result.isSilenceableFailure()) + return result; + if (result.isSilenceableFailure()) + silenceableStack.push_back(result.takeDiagnostic()); } + // Report the first silenceable failure, if any, with the diagnostics of the + // other ones attached to it as notes. + if (!silenceableStack.empty()) { + Diagnostic &first = silenceableStack.front(); + for (Diagnostic &other : llvm::drop_begin(silenceableStack)) + first.attachNote(other.getLocation()) << other.str(); + return DiagnosedSilenceableFailure::silenceableFailure(std::move(first)); + } return DiagnosedSilenceableFailure::success(); } /// Helper function to transform M ops with N results into N results of M ops. static inline SmallVector> transposeResults(const SmallVector, 1> &m) { SmallVector> res; @@ -824,74 +817,81 @@ decltype(&OpTy::applyToOne)>::template arg_t<0>; ArrayRef targets = state.getPayloadOps(this->getOperation()->getOperand(0)); - // Handle the corner case where no target is specified. + // Step 1. Handle the corner case where no target is specified. // This is typically the case when the matcher fails to apply and we need to // propagate gracefully. // In this case, we fill all results with an empty vector. if (targets.empty()) { - SmallVector emptyResult; + SmallVector empty; for (auto r : this->getOperation()->getResults()) - transformResults.set(r.template cast(), emptyResult); + transformResults.set(r.template cast(), empty); return DiagnosedSilenceableFailure::success(); } - SmallVector, 1> results; - // In the multi-result case, collect the number of results each transform - // produced. + // Step 2. Call applyToOne on each target and record newly produced ops in + // its corresponding newOps entry.
+ SmallVector, 1> newOps; DiagnosedSilenceableFailure result = detail::applyTransformToEach( - targets, results, [&](TransformOpType specificOp) { - return static_cast(this)->applyToOne(specificOp, state); + targets, newOps, + [&](TransformOpType specificOp, SmallVector &partialResult) { + return static_cast(this)->applyToOne(specificOp, partialResult, + state); }); - // Propagate the failure (definite or silencable) if any. - if (!result.succeeded()) + + // Step 3. Propagate the definite failure, if any, and bail out. + if (!result.succeeded() && !result.isSilenceableFailure()) return result; - // Legitimately no results, bail early. - if (results.empty() && OpTy::template hasTrait()) - return DiagnosedSilenceableFailure::success(); - // Ensure all applications return the same number of results. - // Variadic cases are much trickier to handle in a generic fashion. - int64_t nRes = results.empty() ? 0 : results[0].size(); - if (llvm::any_of(results, [&](const auto &r) { - return static_cast(r.size()) != nRes; - })) { - return static_cast(this)->emitSilenceableError() - << "expected all applications of " << OpTy::getOperationName() - << " to produce " << nRes - << " results.\n If you need variadic results, consider using a " - "generic `apply` instead of the specialized `applyToOne`"; - } - // Ensure the number of results agrees with what the transform op expects. - // Unless we see empty results, in which case we just want to propagate the - // emptiness. - if (this->getOperation()->getNumResults() != nRes) { - InFlightDiagnostic diag = static_cast(this)->emitError() - << "unexpected number of results (got " << nRes - << " expected " - << this->getOperation()->getNumResults() << ")"; + // Step 4. Ensure every application produced one op per result of this + // transform op. If the overall result is a silenceable failure, some + // applications may instead have produced no ops at all (e.g., for targets of + // the wrong op kind): pad their entries with nullptr so that every result + // still gets one, possibly null, op per target.
+ // + // Note: variadic cases are not supported here as they are much trickier to + // handle in a generic fashion. + unsigned nRes = this->getOperation()->getNumResults(); + if (result.isSilenceableFailure()) { + for (auto &ops : newOps) + if (ops.empty()) + ops.resize(nRes, nullptr); + } + if (llvm::any_of(newOps, + [&](const auto &ops) { return ops.size() != nRes; })) { + static_cast<OpTy *>(this)->emitError() + << "expected all applications of " << OpTy::getOperationName() + << " to produce 0 (silenceableFailure case) or " << nRes + << " (success case) newOps." + << "\nIf you need variadic newOps, consider using a " + "generic `apply` instead of the specialized `applyToOne`"; return DiagnosedSilenceableFailure::definiteFailure(); } - // Perform transposition of M applications producing N results each into N - // results for each of the M applications. + // Step 5. Transpose the M applications producing N ops each into N lists of + // M ops, one per result of this transform op. SmallVector> transposedResults = - detail::transposeResults(results); - // Single result applies to M ops produces one single M-result. + detail::transposeResults(newOps); + + // Step 6. A single-result transform op associates all M ops with its result. if (OpTy::template hasTrait()) { assert(transposedResults.size() == 1 && "Expected single result"); transformResults.set( this->getOperation()->getResult(0).template cast(), transposedResults[0]); - return DiagnosedSilenceableFailure::success(); + // `applyToOne` may have returned a silenceable failure: propagate it. + return result; } - // M ops, N results each. + + // Step 7. Associate each of the N results with its M ops.
for (const auto &it : llvm::zip(this->getOperation()->getResults(), transposedResults)) { transformResults.set(std::get<0>(it).template cast(), std::get<1>(it)); } - return DiagnosedSilenceableFailure::success(); + + // Step 8. `applyToOne` may have returned a silenceable failure: propagate it. + return result; } template diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -62,11 +62,11 @@ return diag; } - /// Creates the silencable failure object with a diagnostic located at the + /// Creates the silenceable failure object with a diagnostic located at the /// current operation. DiagnosedSilenceableFailure emitSilenceableError() { Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error); - return DiagnosedSilenceableFailure::silencableFailure(std::move(diag)); + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } }]; }