diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -93,7 +93,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::tensor::EmptyOp target, - ::llvm::SmallVector<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td --- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td @@ -109,7 +109,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation *target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -130,7 +130,7 @@ The operation searches top level `scf.foreach_thread` ops under `gpu_launch` and maps each such op to GPU blocks. Mapping is one-to-one and the induction variables of `scf.foreach_thread` are - rewritten to gpu.block_id according to the `thread_dim_apping` attribute. + rewritten to gpu.block_id according to the `thread_dim_mapping` attribute. Dynamic, `scf.foreach_thread` trip counts are currently not supported. Dynamic block dim sizes are currently not supported. @@ -167,7 +167,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation *target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h @@ -23,6 +23,7 @@ } // namespace linalg namespace transform { +class TransformHandleTypeInterface; // Types needed for builders. 
struct TileSizesSpec {}; struct NumThreadsSpec {}; diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -47,7 +47,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::LinalgOp target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -160,7 +160,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::LinalgOp target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -202,7 +202,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::GenericOp target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -336,7 +336,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::LinalgOp target, - ::llvm::SmallVector<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, TransformState &state); }]; } @@ -380,7 +380,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::LinalgOp target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -423,7 +423,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::LinalgOp target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -492,7 +492,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::LinalgOp target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -700,7 +700,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::LinalgOp target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -803,7 +803,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::LinalgOp target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -911,7 +911,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::linalg::LinalgOp target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; @@ -1228,7 +1228,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation *target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } diff --git 
a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
--- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td
@@ -44,7 +44,7 @@
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         memref::AllocOp target,
-        ::llvm::SmallVector<::mlir::Operation *> &results,
+        ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
 }
diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -109,7 +109,7 @@
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::scf::ForOp target,
-        ::llvm::SmallVector<::mlir::Operation *> &results,
+        ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
 }
@@ -151,7 +151,7 @@
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::scf::ForOp target,
-        ::llvm::SmallVector<::mlir::Operation *> &results,
+        ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
 }
@@ -184,7 +184,7 @@
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
         ::mlir::Operation *target,
-        ::llvm::SmallVector<::mlir::Operation *> &results,
+        ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
 }
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -542,8 +542,8 @@
 };
 
 /// Trait implementing the TransformOpInterface for operations applying a
-/// transformation to a single operation handle and producing zero, one or
-/// multiple operation handles.
+/// transformation to a single operation handle and producing an arbitrary
+/// number of handles and parameter values.
 /// The op must implement a method with the following signature:
 ///   - DiagnosedSilenceableFailure applyToOne(OpTy,
 ///       SmallVector<Operation *> &results, state)
@@ -732,7 +732,82 @@
 namespace mlir {
 namespace transform {
+
+/// A single result of applying a transform op with `ApplyEachOpTrait` to a
+/// single payload operation.
+using ApplyToEachResult = llvm::PointerUnion<Operation *, Attribute>;
+
+/// A list of results of applying a transform op with `ApplyEachOpTrait` to a
+/// single payload operation, co-indexed with the results of the transform op.
+class ApplyToEachResultList {
+public:
+  ApplyToEachResultList() = default;
+  explicit ApplyToEachResultList(unsigned size) : results(size) {}
+
+  /// Sets the list of results to `size` null pointers.
+  void assign(unsigned size, std::nullptr_t) { results.assign(size, nullptr); }
+
+  /// Sets the list of results to the given range of values.
+  template <typename Range>
+  void assign(Range &&range) {
+    // This is roughly the implementation of SmallVectorImpl::assign.
+    // Dispatching to it with map_range and template type inference would
+    // result in more complex code here.
+    results.clear();
+    results.reserve(llvm::size(range));
+    for (auto element : range) {
+      if constexpr (std::is_convertible_v<decltype(element), Operation *>) {
+        results.push_back(static_cast<Operation *>(element));
+      } else {
+        results.push_back(static_cast<Attribute>(element));
+      }
+    }
+  }
+
+  /// Appends an element to the list.
+  void push_back(Operation *op) { results.push_back(op); }
+  void push_back(Attribute attr) { results.push_back(attr); }
+
+  /// Reserves space for `size` elements in the list.
+  void reserve(unsigned size) { results.reserve(size); }
+
+  /// Iterators over the list.
+  auto begin() { return results.begin(); }
+  auto end() { return results.end(); }
+  auto begin() const { return results.begin(); }
+  auto end() const { return results.end(); }
+
+  /// Returns the number of elements in the list.
+  size_t size() const { return results.size(); }
+
+  /// Element access. Expects the index to be in bounds.
+  ApplyToEachResult &operator[](size_t index) { return results[index]; }
+  const ApplyToEachResult &operator[](size_t index) const {
+    return results[index];
+  }
+
+private:
+  /// Underlying storage.
+  SmallVector<ApplyToEachResult> results;
+};
+
 namespace detail {
+
+/// Check that the contents of `partialResult` match the number, kind (payload
+/// op or parameter) and nullity (either all or none) requirements of
+/// `transformOp`. Report errors and return failure otherwise.
+LogicalResult checkApplyToOne(Operation *transformOp, Location payloadOpLoc,
+                              const ApplyToEachResultList &partialResult);
+
+/// "Transpose" the results produced by individual applications, arranging them
+/// per result value of the transform op, and populate `transformResults` with
+/// that. The number, kind and nullity of per-application results are assumed to
+/// have been verified.
+void setApplyToOneResults(Operation *transformOp,
+                          TransformResults &transformResults,
+                          ArrayRef<ApplyToEachResultList> results);
+
 /// Applies a one-to-one or a one-to-many transform to each of the given
 /// targets. Puts the results of transforms, if any, in `results` in the same
 /// order. Fails if any of the application fails. Individual transforms must be
@@ -744,22 +819,28 @@
 /// - a concrete Op class, in which case a check is performed whether
 ///   `targets` contains operations of the same class and a silenceable failure
 ///   is reported if it does not.
-template <typename FnTy>
-DiagnosedSilenceableFailure applyTransformToEach(
-    Location loc, int expectedNumResults, ArrayRef<Operation *> targets,
-    SmallVectorImpl<SmallVector<Operation *>> &results, FnTy transform) {
-  SmallVector<Diagnostic> silenceableStack;
-  using OpTy = typename llvm::function_traits<FnTy>::template arg_t<0>;
+template <typename TransformOpTy>
+DiagnosedSilenceableFailure
+applyTransformToEach(TransformOpTy transformOp, ArrayRef<Operation *> targets,
+                     SmallVectorImpl<ApplyToEachResultList> &results,
+                     TransformState &state) {
+  using OpTy = typename llvm::function_traits<
+      decltype(&TransformOpTy::applyToOne)>::template arg_t<0>;
   static_assert(std::is_convertible<OpTy, Operation *>::value,
                 "expected transform function to take an operation");
+
+  SmallVector<Diagnostic> silenceableStack;
+  unsigned expectedNumResults = transformOp->getNumResults();
   for (Operation *target : targets) {
-    // Emplace back a placeholder for the returned new ops.
+    // Emplace back a placeholder for the returned new ops and params.
     // This is filled with `expectedNumResults` if the op fails to apply.
-    results.push_back(SmallVector<Operation *>());
+    ApplyToEachResultList placeholder;
+    placeholder.reserve(expectedNumResults);
+    results.push_back(std::move(placeholder));
     auto specificOp = dyn_cast<OpTy>(target);
     if (!specificOp) {
-      Diagnostic diag(loc, DiagnosticSeverity::Error);
+      Diagnostic diag(transformOp->getLoc(), DiagnosticSeverity::Error);
       diag << "transform applied to the wrong op kind";
       diag.attachNote(target->getLoc()) << "when applied to this op";
       // Producing `expectedNumResults` nullptr is a silenceableFailure mode.
@@ -770,11 +851,16 @@
       continue;
     }
-    DiagnosedSilenceableFailure result = transform(specificOp, results.back());
-    if (result.isDefiniteFailure())
-      return result;
-    if (result.isSilenceableFailure())
-      result.takeDiagnostics(silenceableStack);
+    DiagnosedSilenceableFailure res =
+        transformOp.applyToOne(specificOp, results.back(), state);
+    if (res.isDefiniteFailure() ||
+        failed(detail::checkApplyToOne(transformOp, specificOp->getLoc(),
+                                       results.back()))) {
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+    if (res.isSilenceableFailure())
+      res.takeDiagnostics(silenceableStack);
   }
   if (!silenceableStack.empty()) {
     return DiagnosedSilenceableFailure::silenceableFailure(
@@ -783,23 +869,6 @@
   return DiagnosedSilenceableFailure::success();
 }
 
-/// Helper function: transpose MxN into NxM; assumes that the input is a valid.
-static inline SmallVector<SmallVector<Operation *>>
-transposeResults(const SmallVector<SmallVector<Operation *>, 1> &m) {
-  SmallVector<SmallVector<Operation *>> res;
-  if (m.empty())
-    return res;
-  int64_t rows = m.size(), cols = m[0].size();
-  for (int64_t j = 0; j < cols; ++j)
-    res.push_back(SmallVector<Operation *>(rows, nullptr));
-  for (int64_t i = 0; i < rows; ++i) {
-    assert(static_cast<int64_t>(m[i].size()) == cols);
-    for (int64_t j = 0; j < cols; ++j) {
-      res[j][i] = m[i][j];
-    }
-  }
-  return res;
-}
 } // namespace detail
 } // namespace transform
 } // namespace mlir
@@ -808,8 +877,6 @@
 mlir::DiagnosedSilenceableFailure
 mlir::transform::TransformEachOpTrait<OpTy>::apply(
     TransformResults &transformResults, TransformState &state) {
-  using TransformOpType = typename llvm::function_traits<
-      decltype(&OpTy::applyToOne)>::template arg_t<0>;
   ArrayRef<Operation *> targets =
       state.getPayloadOps(this->getOperation()->getOperand(0));
 
@@ -818,88 +885,35 @@
   // propagate gracefully.
   // In this case, we fill all results with an empty vector.
   if (targets.empty()) {
-    SmallVector<Operation *> empty;
-    for (auto r : this->getOperation()->getResults())
-      transformResults.set(r.template cast<OpResult>(), empty);
+    SmallVector<Operation *> emptyPayload;
+    SmallVector<Attribute> emptyParams;
+    for (OpResult r : this->getOperation()->getResults()) {
+      if (r.getType().isa<TransformParamTypeInterface>())
+        transformResults.setParams(r, emptyParams);
+      else
+        transformResults.set(r, emptyPayload);
+    }
     return DiagnosedSilenceableFailure::success();
   }
 
   // Step 2. Call applyToOne on each target and record newly produced ops in its
   // corresponding results entry.
-  int expectedNumResults = this->getOperation()->getNumResults();
-  SmallVector<SmallVector<Operation *>, 1> results;
+  SmallVector<ApplyToEachResultList> results;
+  results.reserve(targets.size());
   DiagnosedSilenceableFailure result = detail::applyTransformToEach(
-      this->getOperation()->getLoc(), expectedNumResults, targets, results,
-      [&](TransformOpType specificOp, SmallVector<Operation *> &partialResult) {
-        auto res = static_cast<OpTy *>(this)->applyToOne(specificOp,
-                                                         partialResult, state);
-        if (res.isDefiniteFailure())
-          return res;
-
-        // TODO: encode this implicit must always produce `expectedNumResults`
-        // and nullptr is fine with a proper trait.
-        if (static_cast<int>(partialResult.size()) != expectedNumResults) {
-          auto loc = this->getOperation()->getLoc();
-          auto diag = mlir::emitError(loc, "applications of ")
-                      << OpTy::getOperationName() << " expected to produce "
-                      << expectedNumResults << " results (actually produced "
-                      << partialResult.size() << ").";
-          diag.attachNote(loc)
-              << "If you need variadic results, consider a generic `apply` "
-              << "instead of the specialized `applyToOne`.";
-          diag.attachNote(loc)
-              << "Producing " << expectedNumResults << " null results is "
-              << "allowed if the use case warrants it.";
-          diag.attachNote(specificOp->getLoc()) << "when applied to this op";
-          return DiagnosedSilenceableFailure::definiteFailure();
-        }
-        // Check that all is null or none is null
-        // TODO: relax this behavior and encode with a proper trait.
-        if (llvm::any_of(partialResult, [](Operation *op) { return op; }) &&
-            llvm::any_of(partialResult, [](Operation *op) { return !op; })) {
-          auto loc = this->getOperation()->getLoc();
-          auto diag = mlir::emitError(loc, "unexpected application of ")
-                      << OpTy::getOperationName()
-                      << " produces both null and non null results.";
-          diag.attachNote(specificOp->getLoc()) << "when applied to this op";
-          return DiagnosedSilenceableFailure::definiteFailure();
-        }
-        return res;
-      });
+      cast<OpTy>(this->getOperation()), targets, results, state);
 
   // Step 3. Propagate the definite failure if any and bail out.
   if (result.isDefiniteFailure())
     return result;
 
-  // Step 4. If there are no results, return early.
-  if (OpTy::template hasTrait<OpTrait::ZeroResults>())
-    return result;
-
-  // Step 5. Perform transposition of M applications producing N results each
-  // into N results for each of the M applications.
-  SmallVector<SmallVector<Operation *>> transposedResults =
-      detail::transposeResults(results);
-
-  // Step 6. Single result applies to M ops produces one single M-result.
-  if (OpTy::template hasTrait<OpTrait::OneResult>()) {
-    assert(transposedResults.size() == 1 && "Expected single result");
-    transformResults.set(
-        this->getOperation()->getResult(0).template cast<OpResult>(),
-        transposedResults[0]);
-    // ApplyToOne may have returned silenceableFailure, propagate it.
-    return result;
-  }
-
-  // Step 7. Filter out empty results and set the transformResults.
-  for (const auto &it :
-       llvm::zip(this->getOperation()->getResults(), transposedResults)) {
-    SmallVector<Operation *> filtered;
-    llvm::copy_if(std::get<1>(it), std::back_inserter(filtered),
-                  [](Operation *op) { return op; });
-    transformResults.set(std::get<0>(it).template cast<OpResult>(), filtered);
-  }
+  // Step 4. "Transpose" the results produced by individual applications,
+  // arranging them per result value of the transform op. The number, kind and
+  // nullity of per-application results have been verified by the callback
+  // above.
+  detail::setApplyToOneResults(this->getOperation(), transformResults, results);
 
-  // Step 8. ApplyToOne may have returned silenceableFailure, propagate it.
+  // Step 5. ApplyToOne may have returned silenceableFailure, propagate it.
return result; } diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -109,7 +109,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation *target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -75,7 +75,7 @@ DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target, - SmallVector &results, + ApplyToEachResultList &results, transform::TransformState &state) { IRRewriter rewriter(target->getContext()); rewriter.setInsertionPoint(target); diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -291,14 +291,14 @@ DiagnosedSilenceableFailure transform::MapForeachToBlocks::applyToOne(Operation *target, - SmallVectorImpl &results, + ApplyToEachResultList &results, transform::TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); TrivialPatternRewriter rewriter(getContext()); auto transformOp = cast(getOperation()); if (!getGenerateGpuLaunch() && !gpuLaunch) { - results.assign({target}); + results.push_back(target); DiagnosedSilenceableFailure diag = emitSilenceableError() << "Given target is not gpu.launch, set `generate_gpu_launch` " @@ -312,7 +312,7 @@ mlir::transform::gpu::findTopLevelForeachThreadOp( target, topLevelForeachThreadOp, transformOp); if (!diag.succeeded()) { - results.assign({target}); + results.push_back(target); diag.attachNote(target->getLoc()) << "when applied to this payload op"; return diag; } @@ -325,7 +325,7 @@ DiagnosedSilenceableFailure diag = createGpuLaunch(rewriter, target->getLoc(), transformOp, gpuLaunch); if (!diag.succeeded()) { - results.assign({target}); + results.push_back(target); return diag; } rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front()); @@ -352,7 +352,7 @@ gridDim[0], gridDim[1], gridDim[2]); } - results.assign({gpuLaunch}); + results.push_back(gpuLaunch); return diag; } @@ -520,14 +520,12 @@ } DiagnosedSilenceableFailure transform::MapNestedForeachToThreads::applyToOne( - ::mlir::Operation *target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, - ::mlir::transform::TransformState &state) { + Operation *target, ApplyToEachResultList &results, TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); auto transformOp = cast(getOperation()); if (!gpuLaunch) { - results.assign({target}); + results.push_back(target); return emitSilenceableError() << "Given target is not gpu.launch"; } @@ -538,7 +536,7 @@ checkGpuLimits(transformOp, std::nullopt, std::nullopt, std::nullopt, blockDim[0], blockDim[1], blockDim[2]); if (diag.isSilenceableFailure()) { - results.assign({target}); + results.push_back(target); diag.attachNote(getLoc()) << getBlockDimAttrName() << " is very large"; return diag; } @@ -562,7 +560,7 @@ blockDim[2]); } - 
results.assign({gpuLaunch}); + results.push_back(gpuLaunch.getOperation()); return diag; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -66,7 +66,7 @@ DiagnosedSilenceableFailure transform::DecomposeOp::applyToOne(linalg::LinalgOp target, - SmallVectorImpl &results, + transform::ApplyToEachResultList &results, transform::TransformState &state) { #define DOWNSCALE(trans) \ { \ @@ -577,7 +577,7 @@ DiagnosedSilenceableFailure transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, - SmallVectorImpl &results, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // Exit early if no transformation is needed. if (isa(target)) { @@ -599,7 +599,7 @@ DiagnosedSilenceableFailure transform::InterchangeOp::applyToOne(linalg::GenericOp target, - SmallVectorImpl &results, + transform::ApplyToEachResultList &results, transform::TransformState &state) { ArrayRef interchangeVector = getIteratorInterchange(); // Exit early if no transformation is needed. @@ -708,7 +708,8 @@ //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( - LinalgOp target, SmallVector &results, TransformState &state) { + LinalgOp target, transform::ApplyToEachResultList &results, + TransformState &state) { OpBuilder builder(target.getContext()); builder.setInsertionPoint(target); OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); @@ -748,7 +749,7 @@ DiagnosedSilenceableFailure transform::PadOp::applyToOne(linalg::LinalgOp target, - SmallVectorImpl &results, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // Convert the integer packing flags to booleans. 
SmallVector packPaddings; @@ -861,7 +862,7 @@ DiagnosedSilenceableFailure transform::PromoteOp::applyToOne(linalg::LinalgOp target, - SmallVectorImpl &results, + transform::ApplyToEachResultList &results, transform::TransformState &state) { LinalgPromotionOptions promotionOptions; if (!getOperandsToPromote().empty()) @@ -955,7 +956,7 @@ DiagnosedSilenceableFailure transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, - SmallVectorImpl &results, + transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) { @@ -991,7 +992,10 @@ rewriter.replaceOp(target, maybeTilingResult->replacements); else rewriter.eraseOp(target); - results.append(maybeTilingResult->tiledOps); + + results.reserve(maybeTilingResult->tiledOps.size()); + for (Operation *tiled : maybeTilingResult->tiledOps) + results.push_back(tiled); return DiagnosedSilenceableFailure::success(); } @@ -1172,10 +1176,9 @@ result.addTypes({resultType, resultType, resultType, resultType}); } -DiagnosedSilenceableFailure -transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, - SmallVectorImpl &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne( + linalg::LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return linalg::SplitReductionOptions{int64_t(getSplitFactor()), unsigned(getInsertSplitDimension()), @@ -1219,7 +1222,7 @@ } DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( - linalg::LinalgOp target, SmallVectorImpl &results, + linalg::LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); @@ -1263,7 +1266,7 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForeachThreadOp::applyToOne( - linalg::LinalgOp target, SmallVectorImpl &results, + linalg::LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { TrivialPatternRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); @@ -1952,7 +1955,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::applyToOne(Operation *target, - SmallVectorImpl &results, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->hasTrait()) { auto diag = this->emitOpError("requires isolated-from-above targets"); diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -21,10 +21,9 @@ // MemRefMultiBufferOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -transform::MemRefMultiBufferOp::applyToOne(memref::AllocOp target, - SmallVector &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::applyToOne( + memref::AllocOp target, transform::ApplyToEachResultList &results, + transform::TransformState &state) { auto newBuffer = memref::multiBuffer(target, getFactor()); if (failed(newBuffer)) { Diagnostic diag(target->getLoc(), DiagnosticSeverity::Note); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp 
b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -125,7 +125,7 @@
 DiagnosedSilenceableFailure
 transform::LoopPeelOp::applyToOne(scf::ForOp target,
-                                  SmallVector<Operation *> &results,
+                                  transform::ApplyToEachResultList &results,
                                   transform::TransformState &state) {
   scf::ForOp result;
   IRRewriter rewriter(target->getContext());
@@ -182,7 +182,7 @@
 DiagnosedSilenceableFailure
 transform::LoopPipelineOp::applyToOne(scf::ForOp target,
-                                      SmallVector<Operation *> &results,
+                                      transform::ApplyToEachResultList &results,
                                       transform::TransformState &state) {
   scf::PipeliningOption options;
   options.getScheduleFn =
@@ -210,7 +210,7 @@
 DiagnosedSilenceableFailure
 transform::LoopUnrollOp::applyToOne(Operation *op,
-                                    SmallVector<Operation *> &results,
+                                    transform::ApplyToEachResultList &results,
                                     transform::TransformState &state) {
   LogicalResult result(failure());
   if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -407,6 +407,93 @@
   return paramSegments[resultNumber].data() != nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// Utilities for TransformEachOpTrait.
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+transform::detail::checkApplyToOne(Operation *transformOp,
+                                   Location payloadOpLoc,
+                                   const ApplyToEachResultList &partialResult) {
+  Location transformOpLoc = transformOp->getLoc();
+  StringRef transformOpName = transformOp->getName().getStringRef();
+  unsigned expectedNumResults = transformOp->getNumResults();
+  // TODO: encode this implicit must always produce `expectedNumResults`
+  // and nullptr is fine with a proper trait.
+  if (partialResult.size() != expectedNumResults) {
+    auto diag = mlir::emitError(transformOpLoc, "applications of ")
+                << transformOpName << " expected to produce "
+                << expectedNumResults << " results (actually produced "
+                << partialResult.size() << ").";
+    diag.attachNote(transformOpLoc)
+        << "If you need variadic results, consider a generic `apply` "
+        << "instead of the specialized `applyToOne`.";
+    diag.attachNote(transformOpLoc)
+        << "Producing " << expectedNumResults << " null results is "
+        << "allowed if the use case warrants it.";
+    diag.attachNote(payloadOpLoc) << "when applied to this op";
+    return failure();
+  }
+
+  // Check that all is null or none is null
+  // TODO: relax this behavior and encode with a proper trait.
+  if (llvm::any_of(
+          partialResult,
+          [](llvm::PointerUnion<Operation *, Attribute> ptr) { return ptr; }) &&
+      llvm::any_of(partialResult,
+                   [](llvm::PointerUnion<Operation *, Attribute> ptr) {
+                     return !ptr;
+                   })) {
+    auto diag = mlir::emitError(transformOpLoc, "unexpected application of ")
+                << transformOpName
+                << " produces both null and non null results.";
+    diag.attachNote(payloadOpLoc) << "when applied to this op";
+    return failure();
+  }
+
+  // Check that the right kind of value was produced.
+  for (const auto &[ptr, res] :
+       llvm::zip(partialResult, transformOp->getResults())) {
+    if (ptr.is<Operation *>() &&
+        !res.getType().template isa<TransformHandleTypeInterface>()) {
+      mlir::emitError(transformOpLoc)
+          << "applications of " << transformOpName
+          << " expected to produce an Attribute for result #"
+          << res.getResultNumber();
+      return failure();
+    }
+    if (ptr.is<Attribute>() &&
+        !res.getType().template isa<TransformParamTypeInterface>()) {
+      mlir::emitError(transformOpLoc)
+          << "applications of " << transformOpName
+          << " expected to produce an Operation * for result #"
+          << res.getResultNumber();
+      return failure();
+    }
+  }
+  return success();
+}
+
+void transform::detail::setApplyToOneResults(
+    Operation *transformOp, TransformResults &transformResults,
+    ArrayRef<ApplyToEachResultList> results) {
+  for (OpResult r : transformOp->getResults()) {
+    if (r.getType().isa<TransformParamTypeInterface>()) {
+      auto params = llvm::to_vector(
+          llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
+            return oneResult[r.getResultNumber()].get<Attribute>();
+          }));
+      transformResults.setParams(r, params);
+    } else {
+      auto payloads = llvm::to_vector(
+          llvm::map_range(results, [r](const ApplyToEachResultList &oneResult) {
+            return oneResult[r.getResultNumber()].get<Operation *>();
+          }));
+      transformResults.set(r, payloads);
+    }
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Utilities for PossibleTopLevelTransformOpTrait.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -261,8 +261,7 @@
 //===----------------------------------------------------------------------===//
 
 DiagnosedSilenceableFailure
-transform::CastOp::applyToOne(Operation *target,
-                              SmallVectorImpl<Operation *> &results,
+transform::CastOp::applyToOne(Operation *target, ApplyToEachResultList &results,
                               transform::TransformState &state) {
   results.push_back(target);
   return DiagnosedSilenceableFailure::success();
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -991,3 +991,36 @@
   "test.op_a"(%arg0) { attr = 0 : i32} : (i32) -> i32
   return
 }
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{expected to produce an Operation * for result #0}}
+  transform.test_produce_transform_param_or_forward_operand %arg0
+    { first_result_is_param }
+    : (!transform.any_op) -> (!transform.any_op, !transform.param)
+}
+
+// -----
+
+// expected-note @below {{when applied to this op}}
+module {
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-error @below {{produces both null and non null results}}
+    transform.test_produce_transform_param_or_forward_operand %arg0
+      { first_result_is_null }
+      : (!transform.any_op) -> (!transform.any_op, !transform.param)
+  }
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+  // expected-error @below {{expected to produce an Attribute for result #1}}
+  transform.test_produce_transform_param_or_forward_operand %arg0
+    { second_result_is_handle }
+    : (!transform.any_op) -> (!transform.any_op, !transform.param)
+}
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -243,7 +243,7 @@ } DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( - Operation *target, SmallVectorImpl &results, + Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); @@ -252,7 +252,7 @@ DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( - Operation *target, SmallVectorImpl &results, + Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { static int count = 0; if (count++ == 0) { @@ -264,7 +264,7 @@ DiagnosedSilenceableFailure mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( - Operation *target, SmallVectorImpl &results, + Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); @@ -274,7 +274,7 @@ DiagnosedSilenceableFailure mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( - Operation *target, SmallVectorImpl &results, + Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(nullptr); @@ -284,7 +284,7 @@ DiagnosedSilenceableFailure mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( - Operation *target, SmallVectorImpl &results, + Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (target->hasAttr("target_me")) return DiagnosedSilenceableFailure::success(); @@ -429,6 +429,35 @@ return success(); } +void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects( + SmallVectorImpl &effects) { + transform::onlyReadsHandle(getIn(), effects); + transform::producesHandle(getOut(), effects); + transform::producesHandle(getParam(), effects); +} + +DiagnosedSilenceableFailure +mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( + Operation *target, ::transform::ApplyToEachResultList &results, + ::transform::TransformState &state) { + Builder builder(getContext()); + if (getFirstResultIsParam()) { + results.push_back(builder.getI64IntegerAttr(0)); + } else if (getFirstResultIsNull()) { + results.push_back(nullptr); + } else { + results.push_back(state.getPayloadOps(getIn()).front()); + } + + if (getSecondResultIsHandle()) { + results.push_back(state.getPayloadOps(getIn()).front()); + } else { + results.push_back(builder.getI64IntegerAttr(42)); + } + + return DiagnosedSilenceableFailure::success(); +} + namespace { /// Test extension of the Transform dialect. 
Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -173,7 +173,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation * target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -189,7 +189,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation * target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -206,7 +206,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation * target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -223,7 +223,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation * target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -239,7 +239,7 @@ let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( ::mlir::Operation * target, - ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; } @@ -313,4 +313,25 @@ let hasVerifier = 1; } +def TestProduceTransformParamOrForwardOperandOp + : Op, + TransformEachOpTrait, TransformOpInterface]> { + let arguments = (ins TransformHandleTypeInterface:$in, + UnitAttr:$first_result_is_param, + UnitAttr:$first_result_is_null, + UnitAttr:$second_result_is_handle); + let results = (outs TransformHandleTypeInterface:$out, + TransformParamTypeInterface:$param); + let assemblyFormat = "$in attr-dict `:` functional-type(operands, results)"; + let cppNamespace = "::mlir::test"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
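
Note (not part of the patch): for downstream transform ops built on `TransformEachOpTrait`, the change amounts to filling an `ApplyToEachResultList` instead of a `SmallVector<Operation *>`: push a payload `Operation *` for each handle-typed result and an `Attribute` for each param-typed result, in result order. A minimal usage sketch for a hypothetical `MyLowerOp` with one handle result followed by one param result, using only the APIs introduced or exercised above:

DiagnosedSilenceableFailure
MyLowerOp::applyToOne(Operation *target,
                      transform::ApplyToEachResultList &results,
                      transform::TransformState &state) {
  // Result #0 is a handle: forward the payload op itself.
  results.push_back(target);
  // Result #1 is a param: return the payload op's region count as an
  // attribute, co-indexed with the transform op's results.
  Builder builder(target->getContext());
  results.push_back(builder.getI64IntegerAttr(target->getNumRegions()));
  return DiagnosedSilenceableFailure::success();
}

checkApplyToOne then verifies that the number and kinds of entries match the op's declared results before setApplyToOneResults groups them per result value.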