diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -15,6 +15,7 @@ namespace mlir { namespace linalg { +class GenericOp; class LinalgOp; } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -22,9 +22,14 @@ let description = [{ Decomposes named complex operations, such as higher-dimensional (depthwise) convolutions, into combinations of lower-dimensional equivalents - when possible. The operand handle must point to a list of such operations. - The returning handle points to the main produced computational operation, - such as the lower-dimensional convolution. + when possible. + + Return modes: + ============= + This operation always succeeds. + + The returned handle points to the main produced computational operation. + If the operation is not decomposable, the returned handle is set to null. }]; let arguments = (ins PDL_Operation:$target); @@ -32,8 +37,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -61,11 +68,15 @@ TransformOpInterface, TransformEachOpTrait]> { let description = [{ Transforms a named structued operation into the generic form with the - explicit attached region.
The operand handle must point to a list of - structured operations, it is consumed by the transformation and is not - expected to be used afterwards. The resulting handle points to the list - of equivalent generic operations, in the same order as the original named - operations. + explicit attached region. + + Return modes: + ============= + This operation always succeeds. + + The returned handle points to the equivalent generic operation; this may be + the original op if it was already in generic form. + If the operation is not generalizable, the returned handle is set to null. }]; let arguments = (ins PDL_Operation:$target); @@ -73,8 +84,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -84,6 +97,15 @@ let description = [{ Interchanges the iterators of the operations pointed to by the target handle using the iterator interchange attribute. + + Return modes: + ============= + This operation may produce a definiteFailure if the interchange attribute + is invalid. + If the transform is applied to the wrong op kind, the result is null and + the operation produces a silenceableFailure. + Otherwise, it succeeds and returns the handle to the rewritten generic op. + This handle may be the original op if the permutation is the identity.
}]; let arguments = @@ -95,8 +117,10 @@ let hasVerifier = 1; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::GenericOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -106,6 +130,12 @@ let description = [{ Pads the operations pointed to by the target handle using the options provides as operation attributes. + + Return modes: + ============= + This operation may produce a definiteFailure if the padding fails for any + reason. + Otherwise, it succeeds and returns the handle to the rewritten padded op. }]; let arguments = @@ -123,8 +153,10 @@ let hasVerifier = 1; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -135,8 +167,12 @@ Indicates that ops of a specific kind in the given function should be scalarized (i.e. their dynamic dimensions tiled by 1). - This operation returns the tiled op but not the loops. + Return modes: + ============= + This operation produces `definiteFailure` if the scalarization fails for any + reason. + This operation returns the tiled op but not the loops. We make this design choice because it is hard to know ahead of time the number of loops that will be produced (it depends on the number of dynamic dimensions after multiple transformations have been applied).
@@ -148,8 +184,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -174,6 +212,11 @@ - use_alloc: whether to use an alloc op to allocate the temporary tensor (default: do not use alloc op) + Return modes: + ============= + This operation produces `definiteFailure` if the splitting fails for any + reason. + This op returns 4 handles to: - the init op (or tensor_alloc op if use_alloc = true), - the fill op used to initialize the neutral element, @@ -291,8 +334,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -340,6 +385,14 @@ Note that this transformation is invalidating the handles to any payload IR operation that is contained inside the vectorization target. + + Return modes: + ============= + This operation produces `definiteFailure` if vectorization fails for any + reason. + + The operation returns the handle to the target op that is expected to be + isolated from above.
}]; let arguments = (ins PDL_Operation:$target, @@ -349,8 +402,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr applyToOne( - ::mlir::Operation *target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -211,6 +211,8 @@ /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be /// integers, in the range 0..`op.rank` without duplications /// (i.e. `[1,1,2]` is an invalid permutation). +/// +/// Return failure if the permutation is not valid. FailureOr interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef interchangeVector); diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -50,7 +50,7 @@ outlined into a separate function. The provided name is used as a _base_ for forming actual function names following SymbolTable auto-renaming scheme to avoid duplicate symbols. Expects that all ops in the Payload IR - have a SymbolTable ancestor (typically true because of the top-level + have a SymbolTable ancestor (typically true because of the top-level module). Returns the handle to the list of outlined functions in the same order as the operand handle.
}]; @@ -68,28 +68,38 @@ let summary = "Peels the last iteration of the loop"; let description = [{ Updates the given loop so that its step evenly divides its range and puts - the remaining iteration into a separate loop or a conditional. Note that - even though the Payload IR modification may be performed in-place, this - operation consumes the operand handle and produces a new one. Applies to - each loop associated with the operand handle individually. The results - follow the same order as the operand. - - Note: If it can be proven statically that the step already evenly divides - the range, this op is a no-op. In the absence of sufficient static - information, this op may peel a loop, even if the step always divides the - range evenly at runtime. + the remaining iteration into a separate loop or a conditional. + + In the absence of sufficient static information, this op may peel a loop, + even if the step always divides the range evenly at runtime. + + Return modes: + ============= + This operation always succeeds and returns the scf::ForOp with the + postcondition: "the loop trip count is divisible by the step". + This operation may return the same unmodified loop handle when peeling did + not modify the IR (i.e. the loop trip count was already divisible). + + Note that even though the Payload IR modification may be performed + in-place, this operation consumes the operand handle and produces a new + one. The results follow the same order as the operand handle. + + TODO: Return both the peeled loop and the remainder loop. }]; let arguments = (ins PDL_Operation:$target, DefaultValuedAttr:$fail_if_already_divisible); + // TODO: Return both the peeled loop and the remainder loop.
let results = (outs PDL_Operation:$transformed); let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne( - ::mlir::scf::ForOp loop, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::scf::ForOp target, + ::llvm::SmallVector<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -102,10 +112,19 @@ each of them. That is, performs some amount of reads from memory before the loop rather than inside the loop, the same amount of writes into memory after the loop, and updates each iteration to read the data for a following - iteration rather than the current one. The amount is specified by the - attributes. The values read and about to be stored are transferred as loop - iteration arguments. Currently supports memref and vector transfer - operations as memory reads/writes. + iteration rather than the current one. + + The amount is specified by the attributes. + + The values read and about to be stored are transferred as loop iteration + arguments. Currently supports memref and vector transfer operations as + memory reads/writes. + + Return modes: + ============= + This operation succeeds and returns the pipelined loop when possible. + Otherwise, it succeeds and returns null to allow chaining in the cases where + further transformations are applied to only the pipelined loops.
}]; let arguments = (ins PDL_Operation:$target, @@ -116,8 +135,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne( - ::mlir::scf::ForOp loop, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::scf::ForOp target, + ::llvm::SmallVector<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -126,11 +147,18 @@ TransformOpInterface, TransformEachOpTrait]> { let summary = "Unrolls the given loop with the given unroll factor"; let description = [{ - Unrolls each loop associated with the given handle to have up to the given - number of loop body copies per iteration. If the unroll factor is larger - than the loop trip count, the latter is used as the unroll factor instead. - Does not produce a new handle as the operation may result in the loop being - removed after a full unrolling. + Unrolls each loop associated with the given handle to have up to the given + number of loop body copies per iteration. If the unroll factor is larger + than the loop trip count, the latter is used as the unroll factor instead. + + Return modes: + ============= + This operation succeeds when loop unrolling occurred. + Otherwise, it produces a silenceableFailure to denote that the particular + loop failed to unroll. + + Does not return handles as the operation may result in the loop being + removed after a full unrolling.
}]; let arguments = (ins PDL_Operation:$target, @@ -139,8 +167,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::LogicalResult applyToOne( - ::mlir::scf::ForOp loop, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::scf::ForOp target, + ::llvm::SmallVector<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -12,13 +12,14 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "llvm/ADT/ScopeExit.h" namespace mlir { /// The result of a transform IR operation application. This can have one of the /// three states: /// - success; -/// - silencable (recoverable) failure with yet-unreported diagnostic; +/// - silenceable (recoverable) failure with yet-unreported diagnostic; /// - definite failure. /// Silenceable failure is intended to communicate information about /// transformations that did not apply but in a way that supports recovery, @@ -26,10 +27,10 @@ /// predictable way. They are associated with a Diagnostic that provides more /// details on the failure. Silenceable failure can be discarded, turning the /// result into success, or "reported", emitting the diagnostic and turning the -/// result into definite failure. Transform IR operations containing other -/// operations are allowed to do either with the results of the nested -/// transformations, but must propagate definite failures as their diagnostics -/// have been already reported to the user. +/// result into definite failure.
+/// Transform IR operations containing other operations are allowed to do either +/// with the results of the nested transformations, but must propagate definite +/// failures as their diagnostics have been already reported to the user. class LLVM_NODISCARD DiagnosedSilenceableFailure { public: explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {} @@ -51,10 +52,10 @@ return DiagnosedSilenceableFailure(::mlir::failure()); } - /// Constructs a DiagnosedSilenceableFailure in the silencable failure state, + /// Constructs a DiagnosedSilenceableFailure in the silenceable failure state, /// ready to emit the given diagnostic. This is considered a failure /// regardless of the diagnostic severity. - static DiagnosedSilenceableFailure silencableFailure(Diagnostic &&diag) { + static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag) { return DiagnosedSilenceableFailure(std::forward(diag)); } @@ -74,7 +75,10 @@ return result; } - /// Returns `true` if this is a silencable failure. + /// Returns `true` if this is a definite failure. + bool isDefiniteFailure() const { return !diagnostic && result.failed(); } + + /// Returns `true` if this is a silenceable failure. bool isSilenceableFailure() const { return diagnostic.hasValue(); } /// Returns `true` if this is a success. @@ -83,10 +87,19 @@ } /// Returns the diagnostic message without emitting it. Expects this object - /// to be a silencable failure. + /// to be a silenceable failure. std::string getMessage() const { return diagnostic->str(); } - /// Converts silencable failure into LogicalResult success without reporting + /// Returns a string representation of the failure mode (for error reporting).
+ std::string getStatusString() const { + if (succeeded()) + return "success"; + if (isSilenceableFailure()) + return "silenceable failure"; + return "definite failure"; + } + + /// Converts silenceable failure into LogicalResult success without reporting /// the diagnostic, preserves the other states. LogicalResult silence() { if (diagnostic) { @@ -96,12 +109,19 @@ return result; } + /// Takes the diagnostic out of this object and silences the failure. + Diagnostic takeDiagnostic() { + assert(diagnostic && "expected a diagnostic to be present"); + auto guard = llvm::make_scope_exit([&]() { (void)silence(); }); + return std::move(*diagnostic); + } + /// Streams the given values into the diagnotic. Expects this object to be a - /// silencable failure. + /// silenceable failure. template DiagnosedSilenceableFailure &operator<<(T &&value) & { assert(isSilenceableFailure() && - "can only append output in silencable failure state"); + "can only append output in silenceable failure state"); *diagnostic << std::forward(value); return *this; } @@ -110,11 +130,11 @@ return std::move(this->operator<<(std::forward(value))); } - /// Attaches a note to the diagnostic. Expects this object to be a silencable + /// Attaches a note to the diagnostic. Expects this object to be a silenceable /// failure. Diagnostic &attachNote(Optional loc = llvm::None) { assert(isSilenceableFailure() && - "can only attach notes to silencable failures"); + "can only attach notes to silenceable failures"); return diagnostic->attachNote(loc); } @@ -123,7 +143,7 @@ : diagnostic(std::move(diagnostic)), result(failure()) {} /// The diagnostic associated with this object. If present, the object is - /// considered to be in the silencable failure state regardless of the + /// considered to be in the silenceable failure state regardless of the /// `result` field.
Optional diagnostic; @@ -581,22 +601,29 @@ /// Trait implementing the TransformOpInterface for operations applying a /// transformation to a single operation handle and producing one or multiple /// operation handles. -/// The op must implement a method with one of the following signatures: -/// - FailureOr applyToOne(OpTy, state) -/// - FailureOr>applyToOne(OpTy, state) -/// - LogicalResult applyToOne(OpTy, state) +/// The op must implement a method with the following signature: +/// - DiagnosedSilenceableFailure applyToOne(OpTy, +/// SmallVector &results, state) /// to perform a transformation that is applied in turn to all payload IR /// operations that correspond to the handle of the transform IR operation. -/// In the functions above, OpTy is either Operation * or a concrete payload IR +/// In the functions above, OpTy is either Operation* or a concrete payload IR /// Op class that the transformation is applied to (NOT the class of the -/// transform IR op). The op is expected to have a single operand. +/// transform IR op). +/// The `applyToOne` method is allowed to fill the `results` vector with +/// null elements to signify that the transformation did not apply to the +/// corresponding payload IR operation. +/// The op is expected to have a single operand. template class TransformEachOpTrait : public OpTrait::TraitBase { public: /// Calls `applyToOne` for every payload operation associated with the operand - /// of this transform IR op. If `applyToOne` returns ops, associates them with - /// the result of this transform op. + /// of this transform IR op. If `applyToOne` produces ops, associate them with + /// the result of this transform op. If any `applyToOne` returns + /// definiteFailure, the transformation is considered failed. If all + /// `applyToOne` return success, the transformation is considered succeeded. + /// If some `applyToOne` return silenceableFailure, the transformation is + /// considered silenceable.
DiagnosedSilenceableFailure apply(TransformResults &transformResults, TransformState &state); @@ -714,88 +741,57 @@ namespace mlir { namespace transform { namespace detail { -/// Appends `result` to the vector assuming it corresponds to the success state -/// in `FailureOr`. If `result` is just a -/// `LogicalResult`, appends an empy vector. -template -std::enable_if_t::value, LogicalResult> -appendTransformResultToVector( - Ty result, SmallVectorImpl> &results) { - results.push_back(SmallVector()); - return result; -} - -template -std::enable_if_t< - llvm::conjunction< - llvm::negation>, - std::is_convertible>::value, - LogicalResult> -appendTransformResultToVector( - Ty result, SmallVectorImpl> &results) { - if (failed(result)) - return failure(); - results.push_back(SmallVector{*result}); - return success(); -} - -template -std::enable_if_t< - llvm::conjunction< - llvm::negation>, - llvm::negation>>::value, - LogicalResult> -appendTransformResultToVector( - ContainerTy resultContainer, - SmallVectorImpl> &results) { - if (failed(resultContainer)) - return failure(); - results.push_back(*resultContainer); - return success(); -} /// Applies a one-to-one or a one-to-many transform to each of the given /// targets. Puts the results of transforms, if any, in `results` in the same /// order. Fails if any of the application fails.
Individual transforms must be -/// callable with one of the following signatures: -/// - FailureOr(OpTy) -/// - LogicalResult(OpTy) -/// - FailureOr>( -/// SmallVectorImpl) -/// - LogicalResult(SmallVectorImpl) +/// callable with the following signature: +/// - DiagnosedSilenceableFailure(OpTy, +/// SmallVector &results) /// where OpTy is either /// - Operation *, in which case the transform is always applied; /// - a concrete Op class, in which case a check is performed whether -/// `targets` contains operations of the same class and a silencable failure +/// `targets` contains operations of the same class and a silenceable failure /// is reported if it does not. template DiagnosedSilenceableFailure -applyTransformToEach(ArrayRef targets, +applyTransformToEach(Location loc, ArrayRef targets, SmallVectorImpl> &results, FnTy transform) { + SmallVector silenceableStack; using OpTy = typename llvm::function_traits::template arg_t<0>; static_assert(std::is_convertible::value, "expected transform function to take an operation"); - using RetTy = typename llvm::function_traits::result_t; - static_assert(std::is_convertible::value, - "expected transform function to return LogicalResult or " - "FailureOr"); for (Operation *target : targets) { + // Emplace back a placeholder for the returned new ops. This may remain + // empty if the op fails to apply. + results.push_back(SmallVector()); + auto specificOp = dyn_cast(target); if (!specificOp) { - Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error); - diag << "attempted to apply transform to the wrong op kind"; - return DiagnosedSilenceableFailure::silencableFailure(std::move(diag)); + Diagnostic diag(loc, DiagnosticSeverity::Warning); + diag << "transform applied to the wrong op kind, null handle returned"; + diag.attachNote(target->getLoc()) << "when applied to this op"; + // Producing 0-results when successful paths produce a fixed number of + // results is a reasonable silenceableFailure mode.
+ silenceableStack.push_back(std::move(diag)); + continue; } - auto result = transform(specificOp); - if (failed(appendTransformResultToVector(result, results))) - return DiagnosedSilenceableFailure::definiteFailure(); + DiagnosedSilenceableFailure result = transform(specificOp, results.back()); + if (result.isDefiniteFailure()) + return result; + if (result.isSilenceableFailure()) + silenceableStack.push_back(result.takeDiagnostic()); + } + // TODO: return/merge the stack of diagnostics. + if (!silenceableStack.empty()) { + return DiagnosedSilenceableFailure::silenceableFailure( + std::move(silenceableStack.back())); } return DiagnosedSilenceableFailure::success(); } -/// Helper function to transform M ops with N results into N results of M ops. +/// Helper function: transpose MxN into NxM; assumes all rows have equal size. static inline SmallVector> transposeResults(const SmallVector, 1> &m) { SmallVector> res; @@ -824,74 +820,95 @@ decltype(&OpTy::applyToOne)>::template arg_t<0>; ArrayRef targets = state.getPayloadOps(this->getOperation()->getOperand(0)); - // Handle the corner case where no target is specified. + + // Step 1. Handle the corner case where no target is specified. // This is typically the case when the matcher fails to apply and we need to // propagate gracefully. // In this case, we fill all results with an empty vector. if (targets.empty()) { - SmallVector emptyResult; + SmallVector empty; for (auto r : this->getOperation()->getResults()) - transformResults.set(r.template cast(), emptyResult); + transformResults.set(r.template cast(), empty); return DiagnosedSilenceableFailure::success(); } + // Step 2. Call applyToOne on each target and record newly produced ops in + // its corresponding results entry. SmallVector, 1> results; - // In the multi-result case, collect the number of results each transform - // produced.
DiagnosedSilenceableFailure result = detail::applyTransformToEach( - targets, results, [&](TransformOpType specificOp) { - return static_cast(this)->applyToOne(specificOp, state); + this->getOperation()->getLoc(), targets, results, + [&](TransformOpType specificOp, SmallVector &partialResult) { + return static_cast(this)->applyToOne(specificOp, partialResult, + state); }); - // Propagate the failure (definite or silencable) if any. - if (!result.succeeded()) + + // Step 3. Propagate the definite failure if any and bail out. + if (result.isDefiniteFailure()) return result; - // Legitimately no results, bail early. + // Step 4. If there are legitimately no results, return early. if (results.empty() && OpTy::template hasTrait()) return DiagnosedSilenceableFailure::success(); - // Ensure all applications return the same number of results. - // Variadic cases are much trickier to handle in a generic fashion. - int64_t nRes = results.empty() ? 0 : results[0].size(); - if (llvm::any_of(results, [&](const auto &r) { - return static_cast(r.size()) != nRes; - })) { - return static_cast(this)->emitSilenceableError() - << "expected all applications of " << OpTy::getOperationName() - << " to produce " << nRes - << " results.\n If you need variadic results, consider using a " - "generic `apply` instead of the specialized `applyToOne`"; - } - // Ensure the number of results agrees with what the transform op expects. - // Unless we see empty results, in which case we just want to propagate the - // emptiness. - if (this->getOperation()->getNumResults() != nRes) { - InFlightDiagnostic diag = static_cast(this)->emitError() - << "unexpected number of results (got " << nRes - << " expected " - << this->getOperation()->getNumResults() << ")"; + // Step 5. Ensure all applications return the same number of results. + // Also allow a particular application of applyToOne to return no results, + // which must mean that we have a silenceableFailure result.
+ // In this case, we normalize the empty results to `nRes` null handles. + // + // Note: variadic cases are not supported here as they are much trickier to + // handle in a generic fashion. + int nRes = this->getOperation()->getNumResults(); + int produced = -1; + bool normalized = llvm::all_of(results, [&](auto &r) { + if (nRes > 0 && r.empty()) { + if (!result.isSilenceableFailure()) { + // Empty results are only allowed with a silenceableFailure. + produced = 0; + return false; + } + r.resize(nRes, nullptr); + } + if (static_cast(r.size()) != nRes) { + produced = static_cast(r.size()); + return false; + } + return true; + }); + if (!normalized) { + mlir::emitError(this->getOperation()->getLoc(), "applications of ") + << OpTy::getOperationName() + << " expected to produce 0 (silenceableFailure case) or " << nRes + << " (success case) results (actually produced " << produced + << " with state '" << result.getStatusString() << "')." + << "\nIf you need variadic results, consider using a " + "generic `apply` instead of the specialized `applyToOne`"; return DiagnosedSilenceableFailure::definiteFailure(); } - // Perform transposition of M applications producing N results each into N - // results for each of the M applications. + // Step 6. Perform transposition of M applications producing N results each + // into N results for each of the M applications. SmallVector> transposedResults = detail::transposeResults(results); - // Single result applies to M ops produces one single M-result. + + // Step 7. Single result applies to M ops produces one single M-result. if (OpTy::template hasTrait()) { assert(transposedResults.size() == 1 && "Expected single result"); transformResults.set( this->getOperation()->getResult(0).template cast(), transposedResults[0]); - return DiagnosedSilenceableFailure::success(); + // ApplyToOne may have returned silenceableFailure, propagate it. + return result; } - // M ops, N results each. + + // Step 8.
Set the transformResults: M ops, N results each. for (const auto &it : llvm::zip(this->getOperation()->getResults(), transposedResults)) { transformResults.set(std::get<0>(it).template cast(), std::get<1>(it)); } - return DiagnosedSilenceableFailure::success(); + + // Step 9. ApplyToOne may have returned silenceableFailure, propagate it. + return result; } template diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -62,11 +62,11 @@ return diag; } - /// Creates the silencable failure object with a diagnostic located at the + /// Creates the silenceable failure object with a diagnostic located at the /// current operation. DiagnosedSilenceableFailure emitSilenceableError() { Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error); - return DiagnosedSilenceableFailure::silencableFailure(std::move(diag)); + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } }]; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -76,19 +76,21 @@ // DecomposeOp //===----------------------------------------------------------------------===// -FailureOr transform::DecomposeOp::applyToOne(LinalgOp target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::DecomposeOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { FailureOr windowed = tryApply(target); - if (succeeded(windowed)) - return windowed; - + if (succeeded(windowed)) { + results.push_back(*windowed); + return DiagnosedSilenceableFailure(success()); + } FailureOr depthwise = tryApply(target); - if (succeeded(depthwise)) - return depthwise; - - return reportUnknownTransformError(target); + // If the decomposition did not happen, push a nullptr. + results.push_back(succeeded(depthwise) ? depthwise->getOperation() : nullptr); + return DiagnosedSilenceableFailure(success());
} //===----------------------------------------------------------------------===// @@ -221,41 +223,44 @@ // GeneralizeOp //===----------------------------------------------------------------------===// -FailureOr transform::GeneralizeOp::applyToOne(LinalgOp target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { // Exit early if no transformation is needed. - if (isa(target)) - return target; - + if (isa(target)) { + results.push_back(target); + return DiagnosedSilenceableFailure(success()); + } FailureOr generic = tryApply(target); - if (succeeded(generic)) - return generic; - - return reportUnknownTransformError(target); + // If the generalization did not happen, push a nullptr. + results.push_back(succeeded(generic) ? generic->getOperation() : nullptr); + return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // InterchangeOp //===----------------------------------------------------------------------===// -FailureOr -transform::InterchangeOp::applyToOne(LinalgOp target, TransformState &state) { +DiagnosedSilenceableFailure +transform::InterchangeOp::applyToOne(linalg::GenericOp target, + SmallVectorImpl &results, + transform::TransformState &state) { SmallVector interchangeVector = extractUIntArray(getIteratorInterchange()); // Exit early if no transformation is needed.
- if (interchangeVector.empty()) - return target; - - auto genericTarget = dyn_cast(target.getOperation()); - if (!genericTarget) { - InFlightDiagnostic diag = emitOpError() - << "applies to " << GenericOp::getOperationName() - << " ops"; - diag.attachNote(target.getLoc()) << "attempted to apply to this op"; - return diag; + if (interchangeVector.empty()) { + results.push_back(target); + return DiagnosedSilenceableFailure(success()); } - return tryApply(target, interchangeVector); + SimpleRewriter rewriter(target->getContext()); + FailureOr res = + interchangeGenericOp(rewriter, target, interchangeVector); + if (failed(res)) + return DiagnosedSilenceableFailure(this->emitOpError("interchange failed")); + results.push_back(res->getOperation()); + return DiagnosedSilenceableFailure(success()); } LogicalResult transform::InterchangeOp::verify() { @@ -275,8 +280,10 @@ // PadOp //===---------------------------------------------------------------------===// -FailureOr transform::PadOp::applyToOne(LinalgOp target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::PadOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { // Convert the integer packing flags to booleans. SmallVector packPaddings; for (int64_t packPadding : extractI64Array(getPackPaddings())) @@ -293,21 +300,19 @@ paddingValues.push_back( parseAttribute(attr.cast(), elementType)); if (!paddingValues.back()) { - InFlightDiagnostic diag = emitOpError() - << "expects a padding value that parses to " - << elementType << ", got " << std::get<0>(it); + auto diag = this->emitOpError("expects a padding that parses to ") + << elementType << ", got " << std::get<0>(it); diag.attachNote(target.getLoc()) << "when applied to this op"; - return diag; + return DiagnosedSilenceableFailure::definiteFailure(); } continue; } // Otherwise, add the attribute directly.
if (attr.getType() != elementType) { - InFlightDiagnostic diag = emitOpError() - << "expects a padding value of type " - << elementType << ", got " << attr; + auto diag = this->emitOpError("expects a padding value of type ") + << elementType << ", got " << attr; diag.attachNote(target.getLoc()) << "when applied to this op"; - return diag; + return DiagnosedSilenceableFailure::definiteFailure(); } paddingValues.push_back(attr); } @@ -327,13 +332,14 @@ FailureOr result = tryApply(target, paddingOptions); - if (succeeded(result)) - return result; + if (succeeded(result)) { + results.push_back(result->getOperation()); + return DiagnosedSilenceableFailure(success()); + } - InFlightDiagnostic diag = emitError() - << "failed to apply pattern to target op"; + auto diag = this->emitOpError("failed to apply pattern to target op"); diag.attachNote(target.getLoc()) << "target op"; - return diag; + return DiagnosedSilenceableFailure::definiteFailure(); } LogicalResult transform::PadOp::verify() { @@ -381,8 +387,10 @@ // ScalarizeOp //===----------------------------------------------------------------------===// -FailureOr transform::ScalarizeOp::applyToOne(LinalgOp target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { LinalgTilingOptions tilingOptions; tilingOptions.scalarizeDynamicDims(); // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile @@ -393,19 +401,24 @@ rewriter.setInsertionPoint(target); FailureOr result = pattern.returningMatchAndRewrite(target, rewriter); - if (failed(result)) - return failure(); + if (failed(result)) { + auto diag = this->emitOpError("failed"); + diag.attachNote(target.getLoc()) << "when applied to this op"; + return DiagnosedSilenceableFailure::definiteFailure(); + } - return result->op; + results.push_back(result->op); + return DiagnosedSilenceableFailure(success()); }
//===----------------------------------------------------------------------===// // SplitReductionOp //===----------------------------------------------------------------------===// -FailureOr> -transform::SplitReductionOp::applyToOne(LinalgOp target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return std::pair(getSplitFactor(), getInsertSplitDimension()); @@ -416,11 +429,15 @@ (getUseScalingAlgorithm()) ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) : splitReduction(rewriter, target, splitFn, getUseAlloc()); - if (failed(splitResult)) - return getOperation()->emitError("failed to apply"); - return SmallVector{splitResult->fillOp, - splitResult->splitLinalgOp, - splitResult->resultCombiningLinalgOp}; + if (failed(splitResult)) { + auto diag = this->emitOpError("failed "); + diag.attachNote(target.getLoc()) << "when applied to this op"; + return DiagnosedSilenceableFailure::definiteFailure(); + } + results.push_back(splitResult->fillOp); + results.push_back(splitResult->splitLinalgOp); + results.push_back(splitResult->resultCombiningLinalgOp); + return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// @@ -463,13 +480,14 @@ // VectorizeOp //===----------------------------------------------------------------------===// -FailureOr VectorizeOp::applyToOne(Operation *target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::VectorizeOp::applyToOne(Operation *target, + SmallVectorImpl &results, + transform::TransformState &state) { if (!target->hasTrait()) { - InFlightDiagnostic diag = emitOpError() - << "applies only to isolated-from-above targets"; + auto diag = this->emitOpError("requires isolated-from-above targets"); diag.attachNote(target->getLoc()) << "non-isolated 
target"; - return diag; + return DiagnosedSilenceableFailure::definiteFailure(); } MLIRContext *ctx = getContext(); @@ -487,8 +505,10 @@ linalg::populatePadOpVectorizationPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) - return reportUnknownTransformError(target); - return target; + return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + + results.push_back(target); + return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -127,18 +127,21 @@ // LoopPeelOp //===----------------------------------------------------------------------===// -FailureOr transform::LoopPeelOp::applyToOne(scf::ForOp loop, - TransformState &state) { +DiagnosedSilenceableFailure +transform::LoopPeelOp::applyToOne(scf::ForOp target, + SmallVector &results, + transform::TransformState &state) { scf::ForOp result; - IRRewriter rewriter(loop->getContext()); + IRRewriter rewriter(target->getContext()); + // This helper returns failure when peeling does not occur (i.e. when the IR + // is not modified). This is not a failure for the op as the postcondition: + // "the loop trip count is divisible by the step" + // is valid. LogicalResult status = - scf::peelAndCanonicalizeForLoop(rewriter, loop, result); - if (failed(status)) { - if (getFailIfAlreadyDivisible()) - return reportUnknownTransformError(loop); - return loop; - } - return result; + scf::peelAndCanonicalizeForLoop(rewriter, target, result); + // TODO: Return both the peeled loop and the remainder loop. + results.push_back(failed(status) ? 
target : result); + return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// @@ -181,8 +184,10 @@ } } -FailureOr -transform::LoopPipelineOp::applyToOne(scf::ForOp loop, TransformState &state) { +DiagnosedSilenceableFailure +transform::LoopPipelineOp::applyToOne(scf::ForOp target, + SmallVector &results, + transform::TransformState &state) { scf::PipeliningOption options; options.getScheduleFn = [this](scf::ForOp forOp, @@ -191,25 +196,29 @@ getReadLatency()); }; - scf::ForLoopPipeliningPattern pattern(options, loop->getContext()); + scf::ForLoopPipeliningPattern pattern(options, target->getContext()); SimpleRewriter rewriter(getContext()); - rewriter.setInsertionPoint(loop); + rewriter.setInsertionPoint(target); FailureOr patternResult = - pattern.returningMatchAndRewrite(loop, rewriter); - if (failed(patternResult)) - return reportUnknownTransformError(loop); - return patternResult; + pattern.returningMatchAndRewrite(target, rewriter); + results.push_back(failed(patternResult) ? 
scf::ForOp() : *patternResult); + return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// // LoopUnrollOp //===----------------------------------------------------------------------===// -LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop, - TransformState &state) { - if (failed(loopUnrollByFactor(loop, getFactor()))) - return reportUnknownTransformError(loop); - return success(); +DiagnosedSilenceableFailure +transform::LoopUnrollOp::applyToOne(scf::ForOp target, + SmallVector &results, + transform::TransformState &state) { + if (failed(loopUnrollByFactor(target, getFactor()))) { + Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note); + diag << "op failed to unroll"; + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + } + return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir --- a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir @@ -37,7 +37,7 @@ // ----- func.func @interchange_matmul(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // expected-note @below {{attempted to apply to this op}} + // expected-note @below {{when applied to this op}} %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor } @@ -54,7 +54,7 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @match_generic in %arg1 - // expected-error @below {{applies to linalg.generic ops}} + // expected-warning @below {{transform applied to the wrong op kind, null handle returned}} transform.structured.interchange %0 { iterator_interchange = [1, 0]} } } diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir 
b/mlir/test/Dialect/Linalg/transform-op-pad.mlir --- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir @@ -99,7 +99,7 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @pdl_target in %arg1 - // expected-error @below {{expects a padding value that parses to 'f32', got "foo"}} + // expected-error @below {{expects a padding that parses to 'f32', got "foo"}} %1 = transform.structured.pad %0 {padding_values=["foo", 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]} } } diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir --- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir @@ -176,7 +176,7 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @pdl_target in %arg1 - // expected-error @below {{applies only to isolated-from-above targets}} + // expected-error @below {{op requires isolated-from-above targets}} %2 = transform.structured.vectorize %0 } } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -382,7 +382,7 @@ transform.sequence { ^bb0(%arg0: !pdl.operation): - // expected-error @below {{unexpected number of results (got 0 expected 3)}} + // expected-error @below {{applications of transform.test_wrong_number_of_results expected to produce 0 (silenceableFailure case) or 3 (success case) results (actually produced 1 with state 'success')}} transform.test_wrong_number_of_results %arg0 } @@ -406,7 +406,7 @@ transform.sequence %arg0 { ^bb0(%arg1: !pdl.operation): %0 = pdl_match @some in %arg1 - // expected-error @below {{expected all applications of transform.test_wrong_number_of_multi_results to produce 1 results}} + // expected-error @below 
{{applications of transform.test_wrong_number_of_multi_results expected to produce 0 (silenceableFailure case) or 1 (success case) results (actually produced 0 with state 'success')}} transform.test_wrong_number_of_multi_results %0 } } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -226,28 +226,35 @@ return DiagnosedSilenceableFailure::success(); } -FailureOr> -mlir::test::TestWrongNumberOfResultsOp::applyToOne( - Operation *, transform::TransformState &state) { - return SmallVector{}; +DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( + Operation *target, SmallVectorImpl &results, + transform::TransformState &state) { + OperationState opState(target->getLoc(), "foo"); + results.push_back(OpBuilder(target).create(opState)); + return DiagnosedSilenceableFailure::success(); } -FailureOr> +DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( - Operation *op, transform::TransformState &state) { + Operation *target, SmallVectorImpl &results, + transform::TransformState &state) { static int count = 0; - if (count++ > 0) - return SmallVector{}; - OperationState opState(op->getLoc(), "foo"); - return SmallVector{OpBuilder(op).create(opState)}; + if (count++ == 0) { + OperationState opState(target->getLoc(), "foo"); + results.push_back(OpBuilder(target).create(opState)); + } + return DiagnosedSilenceableFailure::success(); } -FailureOr> +DiagnosedSilenceableFailure mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( - Operation *op, transform::TransformState &state) { - OperationState opState(op->getLoc(), "foo"); - return SmallVector{OpBuilder(op).create(opState), - OpBuilder(op).create(opState)}; + Operation *target, SmallVectorImpl &results, + 
transform::TransformState &state) { + OperationState opState(target->getLoc(), "foo"); + results.push_back(OpBuilder(target).create(opState)); + results.push_back(OpBuilder(target).create(opState)); + return DiagnosedSilenceableFailure::success(); + return DiagnosedSilenceableFailure::success(); } namespace { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -139,8 +139,10 @@ let assemblyFormat = "$target attr-dict"; let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( - ::mlir::Operation *target, transform::TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation * target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -153,8 +155,10 @@ let assemblyFormat = "$target attr-dict"; let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( - ::mlir::Operation *target, transform::TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation * target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -168,8 +172,10 @@ let assemblyFormat = "$target attr-dict"; let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( - ::mlir::Operation *target, transform::TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation * target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; }