diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -15,6 +15,7 @@ namespace mlir { namespace linalg { +class GenericOp; class LinalgOp; } // namespace linalg } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -22,9 +22,15 @@ let description = [{ Decomposes named complex operations, such as higher-dimensional (depthwise) convolutions, into combinations of lower-dimensional equivalents - when possible. The operand handle must point to a list of such operations. - The returning handle points to the main produced computational operation, - such as the lower-dimensional convolution. + when possible. + + Return modes: + ============= + This operation ignores non-Linalg ops and drops them in the return. + If all the operations referred to by the `target` PDLOperation decompose + properly, the transform succeeds. Otherwise the transform silently fails. + The return handle points to only the subset of successfully produced + computational operations, which can be empty. 
}]; let arguments = (ins PDL_Operation:$target); @@ -32,8 +38,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -61,11 +69,16 @@ TransformOpInterface, TransformEachOpTrait]> { let description = [{ Transforms a named structued operation into the generic form with the - explicit attached region. The operand handle must point to a list of - structured operations, it is consumed by the transformation and is not - expected to be used afterwards. The resulting handle points to the list - of equivalent generic operations, in the same order as the original named - operations. + explicit attached region. + + Return modes: + ============= + This operation ignores non-Linalg ops and drops them in the return. + If all the operations referred to by the `target` PDLOperation generalize + properly, the transform succeeds. Otherwise the transform silently fails. + The return handle points to only the subset of successfully produced + equivalent generic operations, which can be empty or contain the original + ops if they were already in generic form. 
}]; let arguments = (ins PDL_Operation:$target); @@ -73,8 +86,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -84,6 +99,16 @@ let description = [{ Interchanges the iterators of the operations pointed to by the target handle using the iterator interchange attribute. + + Return modes: + ============= + This operation ignores non-linalg::Generic ops and drops them in the return. + This operation fails if the interchange attribute is invalid. + If all the operations referred to by the `target` PDLOperation interchange + properly, the transform succeeds. + If any interchange fails, the transform definitely fails. + The return handle points to only the subset of successfully produced + interchanged operations, which can be empty. }]; let arguments = @@ -95,8 +120,10 @@ let hasVerifier = 1; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::GenericOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -106,6 +133,16 @@ let description = [{ Pads the operations pointed to by the target handle using the options provides as operation attributes. + + Return modes: + ============= + This operation ignores non-Linalg ops and drops them in the return. + This operation may produce a definiteFailure if the padding fails for any + reason. + If all the operations referred to by the `target` PDLOperation pad + properly, the transform succeeds. Otherwise the transform silently fails. 
+ The return handle points to only the subset of successfully produced + padded operations, which can be empty. }]; let arguments = @@ -123,8 +160,10 @@ let hasVerifier = 1; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -135,11 +174,23 @@ Indicates that ops of a specific kind in the given function should be scalarized (i.e. their dynamic dimensions tiled by 1). - This operation returns the tiled op but not the loops. + Return modes: + ============= + This operation ignores non-Linalg ops and drops them in the return. + This operation produces `definiteFailure` if the scalarization fails for any + reason. + If all the operations referred to by the `target` PDLOperation scalarize + properly, the transform succeeds. Otherwise the transform silently fails. + The return handle points to only the subset of successfully produced + tiled-by-1 operations, which can be empty. + + This operation does not return handles to the tiled loop. We make this design choice because it is hard to know ahead of time the number of loops that will be produced (it depends on the number of dynamic dimensions after multiple transformations have been applied). + Loops can always be recovered by navigating from the tiled operations if + needed. 
}]; let arguments = (ins PDL_Operation:$target); @@ -148,8 +199,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -206,7 +259,17 @@ - use_alloc: whether to use an alloc op to allocate the temporary tensor (default: do not use alloc op) - This op returns 4 handles to: + Return modes: + ============= + This operation ignores non-Linalg ops and drops them in the return. + This operation produces `definiteFailure` if the splitting fails for any + reason. + + If all the operations referred to by the `target` PDLOperation split + properly, the transform succeeds. Otherwise the transform silently fails. + The 4 returned handles point to only the subset of successfully produced + computational operations, which can all be empty.
+ This 4 returned handles point to: - the init op (or tensor_alloc op if use_alloc = true), - the fill op used to initialize the neutral element, - the split op and @@ -316,15 +379,18 @@ DefaultValuedAttr:$insert_split_dimension, UnitAttr:$use_scaling_algorithm, UnitAttr:$use_alloc); - let results = (outs PDL_Operation:$fill_op, + let results = (outs PDL_Operation:$init_or_alloc_op, + PDL_Operation:$fill_op, PDL_Operation:$split_linalg_op, PDL_Operation:$combining_linalg_op); let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( - ::mlir::linalg::LinalgOp target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -372,6 +438,13 @@ Note that this transformation is invalidating the handles to any payload IR operation that is contained inside the vectorization target. + + Return modes: + ============= + This operation produces `definiteFailure` if vectorization fails for any + reason. + The operation always returns the handle to the target op that is expected + to be isolated from above. }]; let arguments = (ins PDL_Operation:$target, @@ -381,8 +454,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr applyToOne( - ::mlir::Operation *target, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -239,6 +239,8 @@ /// `interchangeVector = [1,2,0]`. 
All values in `interchangeVector` must be /// integers, in the range 0..`op.rank` without duplications /// (i.e. `[1,1,2]` is an invalid permutation). +/// +/// Return failure if the permutation is not valid. FailureOr interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef interchangeVector); diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -50,7 +50,7 @@ outlined into a separate function. The provided name is used as a _base_ for forming actual function names following SymbolTable auto-renaming scheme to avoid duplicate symbols. Expects that all ops in the Payload IR - have a SymbolTable ancestor (typically true because of the top-level + have a SymbolTable ancestor (typically true because of the top-level module). Returns the handle to the list of outlined functions in the same order as the operand handle. }]; @@ -68,28 +68,40 @@ let summary = "Peels the last iteration of the loop"; let description = [{ Updates the given loop so that its step evenly divides its range and puts - the remaining iteration into a separate loop or a conditional. Note that - even though the Payload IR modification may be performed in-place, this - operation consumes the operand handle and produces a new one. Applies to - each loop associated with the operand handle individually. The results - follow the same order as the operand. - - Note: If it can be proven statically that the step already evenly divides - the range, this op is a no-op. In the absence of sufficient static - information, this op may peel a loop, even if the step always divides the - range evenly at runtime. + the remaining iteration into a separate loop or a conditional. 
+ + In the absence of sufficient static information, this op may peel a loop, + even if the step always divides the range evenly at runtime. + + Return modes: + ============= + This operation ignores non-scf::ForOp ops and drops them in the return. + + This operation always succeeds and returns the scf::ForOp with the + postcondition: "the loop trip count is divisible by the step". + This operation may return the same unmodified loop handle when peeling did + not modify the IR (i.e. the loop trip count was already divisible). + + Note that even though the Payload IR modification may be performed + in-place, this operation consumes the operand handle and produces a new + one. + + TODO: Return both the peeled loop and the remainder loop. }]; let arguments = (ins PDL_Operation:$target, DefaultValuedAttr:$fail_if_already_divisible); + // TODO: Return both the peeled loop and the remainder loop. let results = (outs PDL_Operation:$transformed); let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne( - ::mlir::scf::ForOp loop, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::scf::ForOp target, + ::llvm::SmallVector<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -102,10 +114,21 @@ each of them. That is, performs some amount of reads from memory before the loop rather than inside the loop, the same amount of writes into memory after the loop, and updates each iteration to read the data for a following - iteration rather than the current one. The amount is specified by the - attributes. The values read and about to be stored are transferred as loop - iteration arguments. Currently supports memref and vector transfer - operations as memory reads/writes. + iteration rather than the current one. + + The amount is specified by the attributes. + + The values read and about to be stored are transferred as loop iteration + arguments. 
Currently supports memref and vector transfer operations as + memory reads/writes. + + Return modes: + ============= + This operation ignores non-scf::For ops and drops them in the return. + If all the operations referred to by the `target` PDLOperation pipeline + properly, the transform succeeds. Otherwise the transform silently fails. + The return handle points to only the subset of successfully produced + pipelined loops, which can be empty. }]; let arguments = (ins PDL_Operation:$target, @@ -116,8 +139,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::mlir::scf::ForOp> applyToOne( - ::mlir::scf::ForOp loop, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::scf::ForOp target, + ::llvm::SmallVector<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -126,11 +151,18 @@ TransformOpInterface, TransformEachOpTrait]> { let summary = "Unrolls the given loop with the given unroll factor"; let description = [{ - Unrolls each loop associated with the given handle to have up to the given - number of loop body copies per iteration. If the unroll factor is larger - than the loop trip count, the latter is used as the unroll factor instead. - Does not produce a new handle as the operation may result in the loop being - removed after a full unrolling. + Unrolls each loop associated with the given handle to have up to the given + number of loop body copies per iteration. If the unroll factor is larger + than the loop trip count, the latter is used as the unroll factor instead. + + Return modes: + ============== + This operation ignores non-scf::For ops and drops them in the return. + If all the operations referred to by the `target` PDLOperation unroll + properly, the transform succeeds. Otherwise the transform silently fails. + + Does not return handles as the operation may result in the loop being + removed after a full unrolling. 
}]; let arguments = (ins PDL_Operation:$target, @@ -139,8 +171,10 @@ let assemblyFormat = "$target attr-dict"; let extraClassDeclaration = [{ - ::mlir::LogicalResult applyToOne( - ::mlir::scf::ForOp loop, TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::scf::ForOp target, + ::llvm::SmallVector<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -12,13 +12,14 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "llvm/ADT/ScopeExit.h" namespace mlir { /// The result of a transform IR operation application. This can have one of the /// three states: /// - success; -/// - silencable (recoverable) failure with yet-unreported diagnostic; +/// - silenceable (recoverable) failure with yet-unreported diagnostic; /// - definite failure. /// Silenceable failure is intended to communicate information about /// transformations that did not apply but in a way that supports recovery, @@ -26,10 +27,10 @@ /// predictable way. They are associated with a Diagnostic that provides more /// details on the failure. Silenceable failure can be discarded, turning the /// result into success, or "reported", emitting the diagnostic and turning the -/// result into definite failure. Transform IR operations containing other -/// operations are allowed to do either with the results of the nested -/// transformations, but must propagate definite failures as their diagnostics -/// have been already reported to the user. +/// result into definite failure. 
+/// Transform IR operations containing other operations are allowed to do either +/// with the results of the nested transformations, but must propagate definite +/// failures as their diagnostics have been already reported to the user. class LLVM_NODISCARD DiagnosedSilenceableFailure { public: explicit DiagnosedSilenceableFailure(LogicalResult result) : result(result) {} @@ -51,12 +52,17 @@ return DiagnosedSilenceableFailure(::mlir::failure()); } - /// Constructs a DiagnosedSilenceableFailure in the silencable failure state, + /// Constructs a DiagnosedSilenceableFailure in the silenceable failure state, /// ready to emit the given diagnostic. This is considered a failure /// regardless of the diagnostic severity. - static DiagnosedSilenceableFailure silencableFailure(Diagnostic &&diag) { + static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag) { return DiagnosedSilenceableFailure(std::forward<Diagnostic>(diag)); } + static DiagnosedSilenceableFailure + silenceableFailure(SmallVector<Diagnostic> &&diag) { + return DiagnosedSilenceableFailure( + std::forward<SmallVector<Diagnostic>>(diag)); + } /// Converts all kinds of failure into a LogicalResult failure, emitting the /// diagnostic if necessary. Must not be called more than once. @@ -65,44 +71,72 @@ assert(!reported && "attempting to report a diagnostic more than once"); reported = true; #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - if (diagnostic) { - diagnostic->getLocation().getContext()->getDiagEngine().emit( - std::move(*diagnostic)); - diagnostic.reset(); + if (!diagnostics.empty()) { + for (auto &&diagnostic : diagnostics) { + diagnostic.getLocation().getContext()->getDiagEngine().emit( + std::move(diagnostic)); + } + diagnostics.clear(); + result = ::mlir::failure(); } return result; } - /// Returns `true` if this is a silencable failure. - bool isSilenceableFailure() const { return diagnostic.hasValue(); } + /// Returns `true` if this is a definite failure (failed, no diagnostics).
+ bool isDefiniteFailure() const { return diagnostics.empty() && result.failed(); } + + /// Returns `true` if this is a silenceable failure. + bool isSilenceableFailure() const { return !diagnostics.empty(); } /// Returns `true` if this is a success. bool succeeded() const { - return !diagnostic.hasValue() && ::mlir::succeeded(result); + return diagnostics.empty() && ::mlir::succeeded(result); } /// Returns the diagnostic message without emitting it. Expects this object - /// to be a silencable failure. - std::string getMessage() const { return diagnostic->str(); } + /// to be a silenceable failure. + std::string getMessage() const { + std::string res; + for (auto &diagnostic : diagnostics) { + res.append(diagnostic.str()); + res.append("\n"); + } + return res; + } + + /// Returns a string representation of the failure mode (for error reporting). + std::string getStatusString() const { + if (succeeded()) + return "success"; + if (isSilenceableFailure()) + return "silenceable failure"; + return "definite failure"; + } - /// Converts silencable failure into LogicalResult success without reporting + /// Converts silenceable failure into LogicalResult success without reporting /// the diagnostic, preserves the other states. LogicalResult silence() { - if (diagnostic) { - diagnostic.reset(); + if (!diagnostics.empty()) { + diagnostics.clear(); result = ::mlir::success(); } return result; } - /// Streams the given values into the diagnotic. Expects this object to be a - /// silencable failure. + /// Moves the diagnostics out by value and leaves this object silenced. + SmallVector<Diagnostic> takeDiagnostics() { + assert(!diagnostics.empty() && "expected a diagnostic to be present"); + result = ::mlir::success(); + return std::exchange(diagnostics, {}); + } + + /// Streams the given values into the last diagnostic.
template DiagnosedSilenceableFailure &operator<<(T &&value) & { assert(isSilenceableFailure() && - "can only append output in silencable failure state"); - *diagnostic << std::forward(value); + "can only append output in silenceable failure state"); + diagnostics.back() << std::forward(value); return *this; } template @@ -110,31 +144,36 @@ return std::move(this->operator<<(std::forward(value))); } - /// Attaches a note to the diagnostic. Expects this object to be a silencable - /// failure. + /// Attaches a note to the last diagnostic. + /// Expects this object to be a silenceable failure. Diagnostic &attachNote(Optional loc = llvm::None) { assert(isSilenceableFailure() && - "can only attach notes to silencable failures"); - return diagnostic->attachNote(loc); + "can only attach notes to silenceable failures"); + return diagnostics.back().attachNote(loc); } private: explicit DiagnosedSilenceableFailure(Diagnostic &&diagnostic) - : diagnostic(std::move(diagnostic)), result(failure()) {} + : diagnostics(), result(failure()) { + diagnostics.emplace_back(std::move(diagnostic)); + } + explicit DiagnosedSilenceableFailure(SmallVector &&diagnostics) + : diagnostics(std::move(diagnostics)), result(failure()) {} - /// The diagnostic associated with this object. If present, the object is - /// considered to be in the silencable failure state regardless of the + /// The diagnostics associated with this object. If non-empty, the object is + /// considered to be in the silenceable failure state regardless of the /// `result` field. - Optional diagnostic; + SmallVector diagnostics; - /// The "definite" logical state, either success or failure. Ignored if the - /// diagnostic message is present. + /// The "definite" logical state, either success or failure. + /// Ignored if the diagnostics message is present. LogicalResult result; #if LLVM_ENABLE_ABI_BREAKING_CHECKS - /// Whther the associated diagnostic have been reported. 
Diagnostic reporting - /// consumes the diagnostic, so we need a mechanism to differentiate a - /// reported diagnostic from a state where it was never created. + /// Whether the associated diagnostics have been reported. + /// Reporting the diagnostics consumes them, so we need a mechanism to + /// differentiate reported diagnostics from a state where they were never + /// created. bool reported = false; #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS }; @@ -579,24 +618,45 @@ }; /// Trait implementing the TransformOpInterface for operations applying a -/// transformation to a single operation handle and producing one or multiple -/// operation handles. -/// The op must implement a method with one of the following signatures: -/// - FailureOr applyToOne(OpTy, state) -/// - FailureOr>applyToOne(OpTy, state) -/// - LogicalResult applyToOne(OpTy, state) +/// transformation to a single operation handle and producing zero, one or +/// multiple operation handles. +/// The op must implement a method with the following signature: +/// - DiagnosedSilenceableFailure applyToOne(OpTy, +/// SmallVector<Operation *> &results, state) /// to perform a transformation that is applied in turn to all payload IR /// operations that correspond to the handle of the transform IR operation. -/// In the functions above, OpTy is either Operation * or a concrete payload IR -/// Op class that the transformation is applied to (NOT the class of the -/// transform IR op). The op is expected to have a single operand. +/// In `applyToOne`, OpTy is either Operation* or a concrete payload IR Op class +/// that the transformation is applied to (and NOT the class of the transform IR +/// op). +/// The `applyToOne` method takes an empty `results` vector that it fills with +/// zero, one or multiple operations depending on the number of results expected +/// by the transform op. +/// The number of results must match the number of results of the transform op.
+/// `applyToOne` is allowed to fill the `results` with all null elements to +/// signify that the transformation did not apply to the payload IR operations. +/// Such null elements are filtered out from results before return. +/// +/// The transform op having this trait is expected to have a single operand. template class TransformEachOpTrait : public OpTrait::TraitBase { public: /// Calls `applyToOne` for every payload operation associated with the operand - /// of this transform IR op. If `applyToOne` returns ops, associates them with - /// the result of this transform op. + /// of this transform IR op, the following case disjunction happens: + /// 1. If no target payload ops are associated with the operand, map every + /// result of this transform op to an empty list of payload ops and return + /// success. This is the corner case handling that allows propagating + /// the "no-op" case gracefully to improve usability. + /// 2. If any `applyToOne` returns definiteFailure, the transformation is + /// immediately considered definitely failed and we return. + /// 3. All applications of `applyToOne` are checked to return a number of + /// results expected by the transform IR op. If not, this is a definite + /// failure and we return early. + /// 4. If `applyToOne` produces ops, associate them with the result of this + /// transform op. + /// 5. If any `applyToOne` returns silenceableFailure, the transformation is + /// considered silenceable. + /// 6. Otherwise the transformation is considered successful. DiagnosedSilenceableFailure apply(TransformResults &transformResults, TransformState &state); @@ -714,88 +774,58 @@ namespace mlir { namespace transform { namespace detail { -/// Appends `result` to the vector assuming it corresponds to the success state -/// in `FailureOr`. If `result` is just a -/// `LogicalResult`, appends an empy vector.
-template -std::enable_if_t::value, LogicalResult> -appendTransformResultToVector( - Ty result, SmallVectorImpl> &results) { - results.push_back(SmallVector()); - return result; -} - -template -std::enable_if_t< - llvm::conjunction< - llvm::negation>, - std::is_convertible>::value, - LogicalResult> -appendTransformResultToVector( - Ty result, SmallVectorImpl> &results) { - if (failed(result)) - return failure(); - results.push_back(SmallVector{*result}); - return success(); -} - -template -std::enable_if_t< - llvm::conjunction< - llvm::negation>, - llvm::negation>>::value, - LogicalResult> -appendTransformResultToVector( - ContainerTy resultContainer, - SmallVectorImpl> &results) { - if (failed(resultContainer)) - return failure(); - results.push_back(*resultContainer); - return success(); -} /// Applies a one-to-one or a one-to-many transform to each of the given /// targets. Puts the results of transforms, if any, in `results` in the same /// order. Fails if any of the application fails. Individual transforms must be -/// callable with one of the following signatures: -/// - FailureOr(OpTy) -/// - LogicalResult(OpTy) -/// - FailureOr>( -/// SmallVectorImpl) -/// - LogicalResult(SmallVectorImpl) +/// callable with the following signature: +/// - DiagnosedSilenceableFailure(OpTy, +/// SmallVector &results, state) /// where OpTy is either /// - Operation *, in which case the transform is always applied; /// - a concrete Op class, in which case a check is performed whether -/// `targets` contains operations of the same class and a silencable failure +/// `targets` contains operations of the same class and a silenceable failure /// is reported if it does not. 
template -DiagnosedSilenceableFailure -applyTransformToEach(ArrayRef targets, - SmallVectorImpl> &results, - FnTy transform) { +DiagnosedSilenceableFailure applyTransformToEach( + Location loc, int expectedNumResults, ArrayRef targets, + SmallVectorImpl> &results, FnTy transform) { + SmallVector silenceableStack; using OpTy = typename llvm::function_traits::template arg_t<0>; static_assert(std::is_convertible::value, "expected transform function to take an operation"); - using RetTy = typename llvm::function_traits::result_t; - static_assert(std::is_convertible::value, - "expected transform function to return LogicalResult or " - "FailureOr"); for (Operation *target : targets) { + // Emplace back a placeholder for the returned new ops. + // This is filled with `expectedNumResults` if the op fails to apply. + results.push_back(SmallVector()); + auto specificOp = dyn_cast(target); if (!specificOp) { - Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error); - diag << "attempted to apply transform to the wrong op kind"; - return DiagnosedSilenceableFailure::silencableFailure(std::move(diag)); + Diagnostic diag(loc, DiagnosticSeverity::Error); + diag << "transform applied to the wrong op kind"; + diag.attachNote(target->getLoc()) << "when applied to this op"; + // Producing `expectedNumResults` nullptr is a silenceableFailure mode. + // TODO: encode this implicit `expectedNumResults` nullptr == + // silenceableFailure with a proper trait. 
+ results.back().assign(expectedNumResults, nullptr); + silenceableStack.push_back(std::move(diag)); + continue; } - auto result = transform(specificOp); - if (failed(appendTransformResultToVector(result, results))) - return DiagnosedSilenceableFailure::definiteFailure(); + DiagnosedSilenceableFailure result = transform(specificOp, results.back()); + if (result.isDefiniteFailure()) + return result; + if (result.isSilenceableFailure()) + for (auto &&diag : result.takeDiagnostics()) + silenceableStack.push_back(std::move(diag)); + } + if (!silenceableStack.empty()) { + return DiagnosedSilenceableFailure::silenceableFailure( + std::move(silenceableStack)); } return DiagnosedSilenceableFailure::success(); } -/// Helper function to transform M ops with N results into N results of M ops. +/// Helper function: transpose MxN into NxM; assumes that the input is a valid. static inline SmallVector> transposeResults(const SmallVector, 1> &m) { SmallVector> res; @@ -824,74 +854,95 @@ decltype(&OpTy::applyToOne)>::template arg_t<0>; ArrayRef targets = state.getPayloadOps(this->getOperation()->getOperand(0)); - // Handle the corner case where no target is specified. + + // Step 1. Handle the corner case where no target is specified. // This is typically the case when the matcher fails to apply and we need to // propagate gracefully. // In this case, we fill all results with an empty vector. if (targets.empty()) { - SmallVector emptyResult; + SmallVector empty; for (auto r : this->getOperation()->getResults()) - transformResults.set(r.template cast(), emptyResult); + transformResults.set(r.template cast(), empty); return DiagnosedSilenceableFailure::success(); } + // Step 2. Call applyToOne on each target and record newly produced ops in its + // corresponding results entry. + int expectedNumResults = this->getOperation()->getNumResults(); SmallVector, 1> results; - // In the multi-result case, collect the number of results each transform - // produced. 
DiagnosedSilenceableFailure result = detail::applyTransformToEach( - targets, results, [&](TransformOpType specificOp) { - return static_cast(this)->applyToOne(specificOp, state); + this->getOperation()->getLoc(), expectedNumResults, targets, results, + [&](TransformOpType specificOp, SmallVector &partialResult) { + auto res = static_cast(this)->applyToOne(specificOp, + partialResult, state); + if (res.isDefiniteFailure()) + return res; + + // TODO: encode this implicit must always produce `expectedNumResults` + // and nullptr is fine with a proper trait. + if (static_cast(partialResult.size()) != expectedNumResults) { + auto loc = this->getOperation()->getLoc(); + auto diag = mlir::emitError(loc, "applications of ") + << OpTy::getOperationName() << " expected to produce " + << expectedNumResults << " results (actually produced " + << partialResult.size() << ")."; + diag.attachNote(loc) + << "If you need variadic results, consider a generic `apply` " + << "instead of the specialized `applyToOne`."; + diag.attachNote(loc) + << "Producing " << expectedNumResults << " null results is " + << "allowed if the use case warrants it."; + diag.attachNote(specificOp->getLoc()) << "when applied to this op"; + return DiagnosedSilenceableFailure::definiteFailure(); + } + // Check that all is null or none is null + // TODO: relax this behavior and encode with a proper trait. + if (llvm::any_of(partialResult, [](Operation *op) { return op; }) && + llvm::any_of(partialResult, [](Operation *op) { return !op; })) { + auto loc = this->getOperation()->getLoc(); + auto diag = mlir::emitError(loc, "unexpected application of ") + << OpTy::getOperationName() + << " produces both null and non null results."; + diag.attachNote(specificOp->getLoc()) << "when applied to this op"; + return DiagnosedSilenceableFailure::definiteFailure(); + } + return res; }); - // Propagate the failure (definite or silencable) if any. - if (!result.succeeded()) + + // Step 3. 
Propagate the definite failure if any and bail out. + if (result.isDefiniteFailure()) return result; - // Legitimately no results, bail early. - if (results.empty() && OpTy::template hasTrait()) - return DiagnosedSilenceableFailure::success(); + // Step 4. If there are no results, return early. + if (OpTy::template hasTrait()) + return result; - // Ensure all applications return the same number of results. - // Variadic cases are much trickier to handle in a generic fashion. - int64_t nRes = results.empty() ? 0 : results[0].size(); - if (llvm::any_of(results, [&](const auto &r) { - return static_cast(r.size()) != nRes; - })) { - return static_cast(this)->emitSilenceableError() - << "expected all applications of " << OpTy::getOperationName() - << " to produce " << nRes - << " results.\n If you need variadic results, consider using a " - "generic `apply` instead of the specialized `applyToOne`"; - } - // Ensure the number of results agrees with what the transform op expects. - // Unless we see empty results, in which case we just want to propagate the - // emptiness. - if (this->getOperation()->getNumResults() != nRes) { - InFlightDiagnostic diag = static_cast(this)->emitError() - << "unexpected number of results (got " << nRes - << " expected " - << this->getOperation()->getNumResults() << ")"; - return DiagnosedSilenceableFailure::definiteFailure(); - } - - // Perform transposition of M applications producing N results each into N - // results for each of the M applications. + // Step 5. Perform transposition of M applications producing N results each + // into N results for each of the M applications. SmallVector> transposedResults = detail::transposeResults(results); - // Single result applies to M ops produces one single M-result. + + // Step 6. Single result applies to M ops produces one single M-result. 
if (OpTy::template hasTrait()) { assert(transposedResults.size() == 1 && "Expected single result"); transformResults.set( this->getOperation()->getResult(0).template cast(), transposedResults[0]); - return DiagnosedSilenceableFailure::success(); + // ApplyToOne may have returned silenceableFailure, propagate it. + return result; } - // M ops, N results each. + + // Step 7. Filter out empty results and set the transformResults. for (const auto &it : llvm::zip(this->getOperation()->getResults(), transposedResults)) { - transformResults.set(std::get<0>(it).template cast(), - std::get<1>(it)); + SmallVector filtered; + llvm::copy_if(std::get<1>(it), std::back_inserter(filtered), + [](Operation *op) { return op; }); + transformResults.set(std::get<0>(it).template cast(), filtered); } - return DiagnosedSilenceableFailure::success(); + + // Step 8. ApplyToOne may have returned silenceableFailure, propagate it. + return result; } template diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -62,11 +62,21 @@ return diag; } - /// Creates the silencable failure object with a diagnostic located at the + /// Creates the silenceable failure object with a diagnostic located at the /// current operation. DiagnosedSilenceableFailure emitSilenceableError() { Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error); - return DiagnosedSilenceableFailure::silencableFailure(std::move(diag)); + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + } + + /// Creates the default silenceable failure for a transform op that failed + /// to properly apply to a target. 
+ DiagnosedSilenceableFailure emitDefaultSilenceableFailure( + Operation *target) { + Diagnostic diag($_op->getLoc(), DiagnosticSeverity::Error); + diag << $_op->getName() << " failed to apply"; + diag.attachNote(target->getLoc()) << "when applied to this op"; + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } }]; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -76,19 +76,24 @@ // DecomposeOp //===----------------------------------------------------------------------===// -FailureOr transform::DecomposeOp::applyToOne(LinalgOp target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::DecomposeOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { FailureOr windowed = tryApply(target); - if (succeeded(windowed)) - return windowed; - + if (succeeded(windowed)) { + results.push_back(*windowed); + return DiagnosedSilenceableFailure(success()); + } FailureOr depthwise = tryApply(target); - if (succeeded(depthwise)) - return depthwise; - - return reportUnknownTransformError(target); + if (succeeded(depthwise)) { + results.push_back(*depthwise); + return DiagnosedSilenceableFailure(success()); + } + results.assign(1, nullptr); + return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// @@ -221,41 +226,46 @@ // GeneralizeOp //===----------------------------------------------------------------------===// -FailureOr transform::GeneralizeOp::applyToOne(LinalgOp target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { // Exit early if no transformation is needed. 
- if (isa(target)) - return target; - + if (isa(target)) { + results.push_back(target); + return DiagnosedSilenceableFailure(success()); + } FailureOr generic = tryApply(target); - if (succeeded(generic)) - return generic; - - return reportUnknownTransformError(target); + if (succeeded(generic)) { + results.push_back(generic->getOperation()); + return DiagnosedSilenceableFailure(success()); + } + results.assign(1, nullptr); + return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// // InterchangeOp //===----------------------------------------------------------------------===// -FailureOr -transform::InterchangeOp::applyToOne(LinalgOp target, TransformState &state) { +DiagnosedSilenceableFailure +transform::InterchangeOp::applyToOne(linalg::GenericOp target, + SmallVectorImpl &results, + transform::TransformState &state) { SmallVector interchangeVector = extractUIntArray(getIteratorInterchange()); // Exit early if no transformation is needed. 
- if (interchangeVector.empty()) - return target; - - auto genericTarget = dyn_cast(target.getOperation()); - if (!genericTarget) { - InFlightDiagnostic diag = emitOpError() - << "applies to " << GenericOp::getOperationName() - << " ops"; - diag.attachNote(target.getLoc()) << "attempted to apply to this op"; - return diag; + if (interchangeVector.empty()) { + results.push_back(target); + return DiagnosedSilenceableFailure(success()); } - - return tryApply(target, interchangeVector); + SimpleRewriter rewriter(target->getContext()); + FailureOr res = + interchangeGenericOp(rewriter, target, interchangeVector); + if (failed(res)) + return DiagnosedSilenceableFailure::definiteFailure(); + results.push_back(res->getOperation()); + return DiagnosedSilenceableFailure(success()); } LogicalResult transform::InterchangeOp::verify() { @@ -275,8 +285,10 @@ // PadOp //===---------------------------------------------------------------------===// -FailureOr transform::PadOp::applyToOne(LinalgOp target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::PadOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { // Convert the integer packing flags to booleans. SmallVector packPaddings; for (int64_t packPadding : extractI64Array(getPackPaddings())) @@ -293,21 +305,19 @@ paddingValues.push_back( parseAttribute(attr.cast(), elementType)); if (!paddingValues.back()) { - InFlightDiagnostic diag = emitOpError() - << "expects a padding value that parses to " - << elementType << ", got " << std::get<0>(it); + auto diag = this->emitOpError("expects a padding that parses to ") + << elementType << ", got " << std::get<0>(it); diag.attachNote(target.getLoc()) << "when applied to this op"; - return diag; + return DiagnosedSilenceableFailure::definiteFailure(); } continue; } // Otherwise, add the attribute directly. 
if (attr.getType() != elementType) { - InFlightDiagnostic diag = emitOpError() - << "expects a padding value of type " - << elementType << ", got " << attr; + auto diag = this->emitOpError("expects a padding value of type ") + << elementType << ", got " << attr; diag.attachNote(target.getLoc()) << "when applied to this op"; - return diag; + return DiagnosedSilenceableFailure::definiteFailure(); } paddingValues.push_back(attr); } @@ -327,13 +337,13 @@ FailureOr result = tryApply(target, paddingOptions); - if (succeeded(result)) - return result; + if (succeeded(result)) { + results.push_back(result->getOperation()); + return DiagnosedSilenceableFailure(success()); + } - InFlightDiagnostic diag = emitError() - << "failed to apply pattern to target op"; - diag.attachNote(target.getLoc()) << "target op"; - return diag; + results.assign(1, nullptr); + return emitDefaultSilenceableFailure(target); } LogicalResult transform::PadOp::verify() { @@ -381,8 +391,10 @@ // ScalarizeOp //===----------------------------------------------------------------------===// -FailureOr transform::ScalarizeOp::applyToOne(LinalgOp target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { LinalgTilingOptions tilingOptions; tilingOptions.scalarizeDynamicDims(); // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile @@ -394,9 +406,10 @@ FailureOr result = pattern.returningMatchAndRewrite(target, rewriter); if (failed(result)) - return failure(); + return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); - return result->op; + results.push_back(result->op); + return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// @@ -558,9 +571,10 @@ // SplitReductionOp //===----------------------------------------------------------------------===// -FailureOr> 
-transform::SplitReductionOp::applyToOne(LinalgOp target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, + SmallVectorImpl &results, + transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return std::pair(getSplitFactor(), getInsertSplitDimension()); @@ -572,10 +586,13 @@ ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) : splitReduction(rewriter, target, splitFn, getUseAlloc()); if (failed(splitResult)) - return getOperation()->emitError("failed to apply"); - return SmallVector{splitResult->fillOp, - splitResult->splitLinalgOp, - splitResult->resultCombiningLinalgOp}; + return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + + results.push_back(splitResult->initOrAlloc); + results.push_back(splitResult->fillOp); + results.push_back(splitResult->splitLinalgOp); + results.push_back(splitResult->resultCombiningLinalgOp); + return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// @@ -618,13 +635,14 @@ // VectorizeOp //===----------------------------------------------------------------------===// -FailureOr VectorizeOp::applyToOne(Operation *target, - TransformState &state) { +DiagnosedSilenceableFailure +transform::VectorizeOp::applyToOne(Operation *target, + SmallVectorImpl &results, + transform::TransformState &state) { if (!target->hasTrait()) { - InFlightDiagnostic diag = emitOpError() - << "applies only to isolated-from-above targets"; + auto diag = this->emitOpError("requires isolated-from-above targets"); diag.attachNote(target->getLoc()) << "non-isolated target"; - return diag; + return DiagnosedSilenceableFailure::definiteFailure(); } MLIRContext *ctx = getContext(); @@ -642,8 +660,10 @@ linalg::populatePadOpVectorizationPatterns(patterns); if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) - return 
reportUnknownTransformError(target); - return target; + return DiagnosedSilenceableFailure(reportUnknownTransformError(target)); + + results.push_back(target); + return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -127,18 +127,21 @@ // LoopPeelOp //===----------------------------------------------------------------------===// -FailureOr transform::LoopPeelOp::applyToOne(scf::ForOp loop, - TransformState &state) { +DiagnosedSilenceableFailure +transform::LoopPeelOp::applyToOne(scf::ForOp target, + SmallVector &results, + transform::TransformState &state) { scf::ForOp result; - IRRewriter rewriter(loop->getContext()); + IRRewriter rewriter(target->getContext()); + // This helper returns failure when peeling does not occur (i.e. when the IR + // is not modified). This is not a failure for the op as the postcondition: + // "the loop trip count is divisible by the step" + // is valid. LogicalResult status = - scf::peelAndCanonicalizeForLoop(rewriter, loop, result); - if (failed(status)) { - if (getFailIfAlreadyDivisible()) - return reportUnknownTransformError(loop); - return loop; - } - return result; + scf::peelAndCanonicalizeForLoop(rewriter, target, result); + // TODO: Return both the peeled loop and the remainder loop. + results.push_back(failed(status) ? 
target : result); + return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// @@ -181,8 +184,10 @@ } } -FailureOr -transform::LoopPipelineOp::applyToOne(scf::ForOp loop, TransformState &state) { +DiagnosedSilenceableFailure +transform::LoopPipelineOp::applyToOne(scf::ForOp target, + SmallVector &results, + transform::TransformState &state) { scf::PipeliningOption options; options.getScheduleFn = [this](scf::ForOp forOp, @@ -190,26 +195,33 @@ loopScheduling(forOp, schedule, getIterationInterval(), getReadLatency()); }; - - scf::ForLoopPipeliningPattern pattern(options, loop->getContext()); + scf::ForLoopPipeliningPattern pattern(options, target->getContext()); SimpleRewriter rewriter(getContext()); - rewriter.setInsertionPoint(loop); + rewriter.setInsertionPoint(target); FailureOr patternResult = - pattern.returningMatchAndRewrite(loop, rewriter); - if (failed(patternResult)) - return reportUnknownTransformError(loop); - return patternResult; + pattern.returningMatchAndRewrite(target, rewriter); + if (succeeded(patternResult)) { + results.push_back(*patternResult); + return DiagnosedSilenceableFailure(success()); + } + results.assign(1, nullptr); + return emitDefaultSilenceableFailure(target); } //===----------------------------------------------------------------------===// // LoopUnrollOp //===----------------------------------------------------------------------===// -LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop, - TransformState &state) { - if (failed(loopUnrollByFactor(loop, getFactor()))) - return reportUnknownTransformError(loop); - return success(); +DiagnosedSilenceableFailure +transform::LoopUnrollOp::applyToOne(scf::ForOp target, + SmallVector &results, + transform::TransformState &state) { + if (failed(loopUnrollByFactor(target, getFactor()))) { + Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note); + diag << "op failed to unroll"; + return 
DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + } + return DiagnosedSilenceableFailure(success()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir --- a/mlir/test/Dialect/Linalg/transform-op-interchange.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-interchange.mlir @@ -37,7 +37,7 @@ // ----- func.func @interchange_matmul(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { - // expected-note @below {{attempted to apply to this op}} + // expected-note @below {{when applied to this op}} %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor } @@ -54,7 +54,7 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @match_generic in %arg1 - // expected-error @below {{applies to linalg.generic ops}} + // expected-error @below {{transform applied to the wrong op kind}} transform.structured.interchange %0 { iterator_interchange = [1, 0]} } } diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir --- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir @@ -99,7 +99,7 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @pdl_target in %arg1 - // expected-error @below {{expects a padding value that parses to 'f32', got "foo"}} + // expected-error @below {{expects a padding that parses to 'f32', got "foo"}} %1 = transform.structured.pad %0 {padding_values=["foo", 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]} } } @@ -109,7 +109,7 @@ func.func @pad(%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { - // expected-note @below {{target op}} + // expected-note @below {{when applied to this op}} %0 = linalg.matmul ins(%arg0, %arg1 : 
tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> func.return %0 : tensor<24x25xf32> } @@ -127,7 +127,7 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @pdl_target in %arg1 - // expected-error @below {{failed to apply pattern to target op}} + // expected-error @below {{transform.structured.pad failed to apply}} %1 = transform.structured.pad %0 {padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], padding_dimensions=[0, 1, 2], pack_paddings=[1, 1, 0]} } } diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir --- a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir @@ -31,7 +31,7 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @pdl_target in %arg1 - %1:3 = transform.structured.split_reduction %0 + %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2, use_scaling_algorithm, use_alloc} } } diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir --- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir @@ -30,6 +30,6 @@ transform.sequence %arg0 { ^bb1(%arg1: !pdl.operation): %0 = pdl_match @pdl_target in %arg1 - %1:3 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2} + %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2} } } diff --git a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir --- a/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-vectorize.mlir @@ -176,7 +176,7 @@ transform.sequence %arg0 { ^bb1(%arg1: 
!pdl.operation): %0 = pdl_match @pdl_target in %arg1 - // expected-error @below {{applies only to isolated-from-above targets}} + // expected-error @below {{op requires isolated-from-above targets}} %2 = transform.structured.vectorize %0 } } diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -380,16 +380,36 @@ } // ----- -transform.sequence { +func.func @foo() { + // expected-note @below {{when applied to this op}} + "op" () : () -> () + return +} + +transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): - // expected-error @below {{unexpected number of results (got 0 expected 3)}} - transform.test_wrong_number_of_results %arg0 + pdl.pattern @some : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = pdl.operation "op"(%0 : !pdl.range) -> (%1 : !pdl.range) + pdl.rewrite %2 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @some in %arg1 + // expected-error @below {{applications of transform.test_wrong_number_of_results expected to produce 3 results (actually produced 1).}} + // expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}} + // expected-note @below {{Producing 3 null results is allowed if the use case warrants it.}} + transform.test_wrong_number_of_results %0 + } } // ----- func.func @foo() { "op" () : () -> () + // expected-note @below {{when applied to this op}} "op" () : () -> () return } @@ -406,7 +426,9 @@ transform.sequence %arg0 { ^bb0(%arg1: !pdl.operation): %0 = pdl_match @some in %arg1 - // expected-error @below {{expected all applications of transform.test_wrong_number_of_multi_results to produce 1 results}} + // expected-error @below {{applications of transform.test_wrong_number_of_multi_results expected to produce 1 results (actually 
produced 0)}} + // expected-note @below {{If you need variadic results, consider a generic `apply` instead of the specialized `applyToOne`.}} + // expected-note @below {{Producing 1 null results is allowed if the use case warrants it.}} transform.test_wrong_number_of_multi_results %0 } } @@ -463,6 +485,31 @@ // ----- +func.func @foo() { + // expected-note @below {{when applied to this op}} + "op" () : () -> () + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @some : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = pdl.operation "op"(%0 : !pdl.range) -> (%1 : !pdl.range) + pdl.rewrite %2 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @some in %arg1 + // expected-error @below {{unexpected application of transform.test_mixed_null_and_non_null_results produces both null and non null results.}} + transform.test_mixed_null_and_non_null_results %0 + } +} + +// ----- + // Expecting to match all operations by merging the handles that matched addi // and subi separately. 
func.func @foo(%arg0: index) { @@ -498,4 +545,3 @@ test_print_remark_at_operand %2, "matched" } } - diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -226,28 +226,44 @@ return DiagnosedSilenceableFailure::success(); } -FailureOr> -mlir::test::TestWrongNumberOfResultsOp::applyToOne( - Operation *, transform::TransformState &state) { - return SmallVector{}; +DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( + Operation *target, SmallVectorImpl &results, + transform::TransformState &state) { + OperationState opState(target->getLoc(), "foo"); + results.push_back(OpBuilder(target).create(opState)); + return DiagnosedSilenceableFailure::success(); } -FailureOr> +DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( - Operation *op, transform::TransformState &state) { + Operation *target, SmallVectorImpl &results, + transform::TransformState &state) { static int count = 0; - if (count++ > 0) - return SmallVector{}; - OperationState opState(op->getLoc(), "foo"); - return SmallVector{OpBuilder(op).create(opState)}; + if (count++ == 0) { + OperationState opState(target->getLoc(), "foo"); + results.push_back(OpBuilder(target).create(opState)); + } + return DiagnosedSilenceableFailure::success(); } -FailureOr> +DiagnosedSilenceableFailure mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( - Operation *op, transform::TransformState &state) { - OperationState opState(op->getLoc(), "foo"); - return SmallVector{OpBuilder(op).create(opState), - OpBuilder(op).create(opState)}; + Operation *target, SmallVectorImpl &results, + transform::TransformState &state) { + OperationState opState(target->getLoc(), "foo"); + results.push_back(OpBuilder(target).create(opState)); + 
results.push_back(OpBuilder(target).create(opState)); + return DiagnosedSilenceableFailure::success(); +} + +DiagnosedSilenceableFailure +mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( + Operation *target, SmallVectorImpl &results, + transform::TransformState &state) { + OperationState opState(target->getLoc(), "foo"); + results.push_back(nullptr); + results.push_back(OpBuilder(target).create(opState)); + return DiagnosedSilenceableFailure::success(); } namespace { diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -139,8 +139,10 @@ let assemblyFormat = "$target attr-dict"; let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( - ::mlir::Operation *target, transform::TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation * target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -153,8 +155,10 @@ let assemblyFormat = "$target attr-dict"; let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( - ::mlir::Operation *target, transform::TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation * target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; } @@ -168,8 +172,27 @@ let assemblyFormat = "$target attr-dict"; let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ - ::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne( - ::mlir::Operation *target, transform::TransformState &state); + ::mlir::DiagnosedSilenceableFailure applyToOne( + 
::mlir::Operation * target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); + }]; +} + +def TestMixedNullAndNonNullResultsOp + : Op { + let arguments = (ins PDL_Operation:$target); + let results = (outs PDL_Operation:$null, + PDL_Operation:$non_null); + let assemblyFormat = "$target attr-dict"; + let cppNamespace = "::mlir::test"; + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation * target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); }]; }